From d530da66b23f854376dcb457e2bb85c67cf9712e Mon Sep 17 00:00:00 2001
From: Thomas Faivre <thomas.faivre@6wind.com>
Date: Mon, 27 Feb 2017 13:17:39 +0100
Subject: [PATCH] layers/ipsec: fix iv handling in special modes

There is a distinction to be made between the IV generated using
CryptAlgo.generate_iv and the IV given as argument to the cipher mode in
CryptAlgo.new_cipher.

The first one is random string which is sent with the ESP packet (first
bytes of the data field). The cipher mode only affects the size of the
string in our implementation (some modes like GCM may implement a
counter instead of pure random).

And the second is a combination of the salt, the ESP iv and possibly other
things. This can vary a lot depending on the mode.

Add an attribute to CryptAlgo to give a function computing this "second"
IV based on SA information.

Signed-off-by: Thomas Faivre <thomas.faivre@6wind.com>
---
 scapy/layers/ipsec.py | 63 ++++++++++++++++++++++++++++---------------
 1 file changed, 41 insertions(+), 22 deletions(-)

diff --git a/scapy/layers/ipsec.py b/scapy/layers/ipsec.py
index 26eea06c..1e4c7644 100644
--- a/scapy/layers/ipsec.py
+++ b/scapy/layers/ipsec.py
@@ -183,7 +183,7 @@ class CryptAlgo(object):
     """
 
     def __init__(self, name, cipher, mode, block_size=None, iv_size=None,
-                 key_size=None, icv_size=None, salt_size=None):
+                 key_size=None, icv_size=None, salt_size=None, format_mode_iv=None):
         """
         @param name: the name of this encryption algorithm
         @param cipher: a Cipher module
@@ -199,6 +199,9 @@ class CryptAlgo(object):
                          Used by Combined Mode Algorithms e.g. GCM
         @param salt_size: the length of the salt to use as the IV prefix.
                           Usually used by Counter modes e.g. CTR
+        @param format_mode_iv: function to format the Initialization Vector
+                               e.g. handle the salt value
+                               Default is the random buffer from `generate_iv`
         """
         self.name = name
         self.cipher = cipher
@@ -235,6 +238,11 @@ class CryptAlgo(object):
         else:
             self.salt_size = salt_size
 
+        if format_mode_iv is None:
+            self._format_mode_iv = lambda iv, **kw: iv
+        else:
+            self._format_mode_iv = format_mode_iv
+
     def check_key(self, key):
         """
         Check that the key length is valid.
@@ -251,17 +259,17 @@ class CryptAlgo(object):
         """
         # XXX: Handle counter modes with real counters? RFCs allow the use of
         # XXX: random bytes for counters, so it is not wrong to do it that way
-        return os.urandom(self.iv_size - self.salt_size)
+        return os.urandom(self.iv_size)
 
     @crypto_validator
-    def new_cipher(self, key, iv, digest=None):
+    def new_cipher(self, key, mode_iv, digest=None):
         """
-        @param key:    the secret key, a byte string
-        @param iv:     the initialization vector, a byte string. Used as the
-                       initial nonce in counter mode
-        @param digest: also known as tag or icv. A byte string containing the
-                       digest of the encrypted data. Only use this during
-                       decryption!
+        @param key:     the secret key, a byte string
+        @param mode_iv: the initialization vector or nonce, a byte string.
+                        Formatted by `format_mode_iv`.
+        @param digest:  also known as tag or icv. A byte string containing the
+                        digest of the encrypted data. Only use this during
+                        decryption!
 
         @return:    an initialized cipher object for this algo
         """
@@ -269,13 +277,13 @@ class CryptAlgo(object):
             # With AEAD, the mode needs the digest during decryption.
             return Cipher(
                 self.cipher(key),
-                self.mode(iv, digest, len(digest)),
+                self.mode(mode_iv, digest, len(digest)),
                 default_backend(),
             )
         else:
             return Cipher(
                 self.cipher(key),
-                self.mode(iv),
+                self.mode(mode_iv),
                 default_backend(),
             )
 
@@ -314,10 +322,11 @@ class CryptAlgo(object):
 
         return esp
 
-    def encrypt(self, esp, key):
+    def encrypt(self, sa, esp, key):
         """
         Encrypt an ESP packet
 
+        @param sa:   the SecurityAssociation associated with the ESP packet.
         @param esp:  an unencrypted _ESPPlain packet with valid padding
         @param key:  the secret key used for encryption
 
@@ -326,7 +335,8 @@ class CryptAlgo(object):
         data = esp.data_for_encryption()
 
         if self.cipher:
-            cipher = self.new_cipher(key, esp.iv)
+            mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=esp.iv)
+            cipher = self.new_cipher(key, mode_iv)
             encryptor = cipher.encryptor()
 
             if self.is_aead:
@@ -339,10 +349,11 @@ class CryptAlgo(object):
 
         return ESP(spi=esp.spi, seq=esp.seq, data=esp.iv + data)
 
-    def decrypt(self, esp, key, icv_size=None):
+    def decrypt(self, sa, esp, key, icv_size=None):
         """
         Decrypt an ESP packet
 
+        @param sa:         the SecurityAssociation associated with the ESP packet.
         @param esp:        an encrypted ESP packet
         @param key:        the secret key used for encryption
         @param icv_size:   the length of the icv used for integrity check
@@ -359,7 +370,8 @@ class CryptAlgo(object):
         icv = esp.data[len(esp.data) - icv_size:]
 
         if self.cipher:
-            cipher = self.new_cipher(key, iv, icv)
+            mode_iv = self._format_mode_iv(sa=sa, iv=iv)
+            cipher = self.new_cipher(key, mode_iv, icv)
             decryptor = cipher.decryptor()
 
             if self.is_aead:
@@ -402,20 +414,29 @@ if algorithms:
     CRYPT_ALGOS['AES-CBC'] = CryptAlgo('AES-CBC',
                                        cipher=algorithms.AES,
                                        mode=modes.CBC)
+    _aes_ctr_format_mode_iv = lambda sa, iv, **kw: sa.crypt_salt + iv + b'\x00\x00\x00\x01'
     CRYPT_ALGOS['AES-CTR'] = CryptAlgo('AES-CTR',
                                        cipher=algorithms.AES,
                                        mode=modes.CTR,
-                                       salt_size=4)
+                                       iv_size=8,
+                                       salt_size=4,
+                                       format_mode_iv=_aes_ctr_format_mode_iv)
+    _salt_format_mode_iv = lambda sa, iv, **kw: sa.crypt_salt + iv
     CRYPT_ALGOS['AES-GCM'] = CryptAlgo('AES-GCM',
                                        cipher=algorithms.AES,
                                        mode=modes.GCM,
                                        salt_size=4,
-                                       icv_size=16)
+                                       iv_size=8,
+                                       icv_size=16,
+                                       format_mode_iv=_salt_format_mode_iv)
     if hasattr(modes, 'CCM'):
         CRYPT_ALGOS['AES-CCM'] = CryptAlgo('AES-CCM',
                                            cipher=algorithms.AES,
                                            mode=modes.CCM,
-                                           icv_size=16)
+                                           iv_size=8,
+                                           salt_size=3,
+                                           icv_size=16,
+                                           format_mode_iv=_salt_format_mode_iv)
     # XXX: Flagged as weak by 'cryptography'. Kept for backward compatibility
     CRYPT_ALGOS['Blowfish'] = CryptAlgo('Blowfish',
                                         cipher=algorithms.Blowfish,
@@ -806,8 +827,6 @@ class SecurityAssociation(object):
 
         if iv is None:
             iv = self.crypt_algo.generate_iv()
-            if self.crypt_salt:
-                iv = self.crypt_salt + iv
         else:
             if len(iv) != self.crypt_algo.iv_size:
                 raise TypeError('iv length must be %s' % self.crypt_algo.iv_size)
@@ -832,7 +851,7 @@ class SecurityAssociation(object):
         esp.nh = nh
 
         esp = self.crypt_algo.pad(esp)
-        esp = self.crypt_algo.encrypt(esp, self.crypt_key)
+        esp = self.crypt_algo.encrypt(self, esp, self.crypt_key)
 
         self.auth_algo.sign(esp, self.auth_key)
 
@@ -938,7 +957,7 @@ class SecurityAssociation(object):
             self.check_spi(pkt)
             self.auth_algo.verify(encrypted, self.auth_key)
 
-        esp = self.crypt_algo.decrypt(encrypted, self.crypt_key,
+        esp = self.crypt_algo.decrypt(self, encrypted, self.crypt_key,
                                       self.crypt_algo.icv_size or
                                       self.auth_algo.icv_size)
 
-- 
GitLab