diff --git a/scapy/layers/ipsec.py b/scapy/layers/ipsec.py
index f45fa49c15996270c4887d2c69f7939c0aac2fad..b609b9473c81424bba202a26a473982ca6bb5d09 100644
--- a/scapy/layers/ipsec.py
+++ b/scapy/layers/ipsec.py
@@ -40,6 +40,7 @@ True
"""
import socket
+import struct
from Crypto.Util.number import GCD as gcd
from scapy.data import IP_PROTOS
@@ -169,7 +170,7 @@ class CryptAlgo(object):
IPSec encryption algorithm
"""
- def __init__(self, name, cipher, mode, block_size=None, iv_size=None, key_size=None):
+ def __init__(self, name, cipher, mode, block_size=None, iv_size=None, key_size=None, icv_size=None):
"""
@param name: the name of this encryption algorithm
@param cipher: a Cipher module
@@ -185,6 +186,11 @@ class CryptAlgo(object):
self.name = name
self.cipher = cipher
self.mode = mode
+ self.icv_size = icv_size
+ self.is_aead = (hasattr(self.cipher, 'MODE_GCM') and
+ self.mode == self.cipher.MODE_GCM) or \
+ (hasattr(self.cipher, 'MODE_CCM') and
+ self.mode == self.cipher.MODE_CCM)
if block_size is not None:
self.block_size = block_size
@@ -232,8 +238,7 @@ class CryptAlgo(object):
@return: an initialized cipher object for this algo
"""
if (hasattr(self.cipher, 'MODE_CTR') and self.mode == self.cipher.MODE_CTR
- or hasattr(self.cipher, 'MODE_GCM') and self.mode == self.cipher.MODE_GCM
- or hasattr(self.cipher, 'MODE_CCM') and self.mode == self.cipher.MODE_CCM):
+ or self.is_aead):
# in counter mode, the "iv" must be incremented for each block
# it is calculated like this:
# +---------+------------------+---------+
@@ -252,8 +257,7 @@ class CryptAlgo(object):
# <--------->
# nonce_size
cipher_key, nonce = key[:-nonce_size], key[-nonce_size:]
- if (hasattr(self.cipher, 'MODE_GCM') and self.mode == self.cipher.MODE_GCM
- or hasattr(self.cipher, 'MODE_CCM') and self.mode == self.cipher.MODE_CCM):
+ if self.is_aead:
return self.cipher.new(cipher_key, self.mode, nonce + iv,
counter=Counter.new(4 * 8, prefix=nonce + iv))
@@ -307,7 +311,13 @@ class CryptAlgo(object):
if self.cipher:
self.check_key(key)
cipher = self.new_cipher(key, esp.iv)
- data = cipher.encrypt(data)
+
+ if self.is_aead:
+ cipher.update(struct.pack('!LL', esp.spi, esp.seq))
+ data = cipher.encrypt(data)
+ data += cipher.digest()[:self.icv_size]
+ else:
+ data = cipher.encrypt(data)
return ESP(spi=esp.spi, seq=esp.seq, data=esp.iv + data)
@@ -323,12 +333,19 @@ class CryptAlgo(object):
"""
self.check_key(key)
+ if self.cipher and self.is_aead:
+ icv_size = self.icv_size
+
iv = esp.data[:self.iv_size]
data = esp.data[self.iv_size:len(esp.data) - icv_size]
icv = esp.data[len(esp.data) - icv_size:]
if self.cipher:
cipher = self.new_cipher(key, iv)
+
+ if self.is_aead:
+ cipher.update(struct.pack('!LL', esp.spi, esp.seq))
+
data = cipher.decrypt(data)
# extract padlen and nh
@@ -373,12 +390,14 @@ if AES:
cipher=AES,
mode=AES.MODE_GCM,
iv_size=8,
+ icv_size=8,
key_size=(16 + 4, 24 + 4, 32 + 4))
if hasattr(AES, "MODE_CCM"):
CRYPT_ALGOS['AES-CCM'] = CryptAlgo('AES-CCM',
cipher=AES,
mode=AES.MODE_CCM,
iv_size=8,
+ icv_size=8,
key_size=(16 + 4, 24 + 4, 32 + 4))
if DES:
CRYPT_ALGOS['DES'] = CryptAlgo('DES',