From 0c892a5b12e34ebc84b15a7a77d30443c852db90 Mon Sep 17 00:00:00 2001
From: Florian Maury <florian.maury@ssi.gouv.fr>
Date: Fri, 29 Jul 2016 14:27:04 +0200
Subject: [PATCH] Add MultiFlagsField

---
 scapy/fields.py     |  91 +++++++++++++++++++++-
 test/regression.uts | 185 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 274 insertions(+), 2 deletions(-)

diff --git a/scapy/fields.py b/scapy/fields.py
index 4299827b..7f48092c 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -7,7 +7,7 @@
 Fields: basic data structures that make up parts of packets.
 """
 
-import struct,copy,socket
+import struct,copy,socket,collections
 from scapy.config import conf
 from scapy.volatile import *
 from scapy.data import *
@@ -990,7 +990,94 @@ class FlagsField(BitField):
             r = "+".join(r)
         return r
 
-            
+
+MultiFlagsEntry = collections.namedtuple('MultiFlagEntry', ['short', 'long'])
+
+
+class MultiFlagsField(BitField):
+    __slots__ = FlagsField.__slots__ + ["depends_on"]
+
+    def __init__(self, name, default, size, names, depends_on):
+        self.names = names
+        self.depends_on = depends_on
+        super(MultiFlagsField, self).__init__(name, default, size)
+
+    def any2i(self, pkt, x):
+        assert isinstance(x, (int, long, set)), 'set expected'
+
+        if pkt is not None:
+            if isinstance(x, (int, long)):
+                x = self.m2i(pkt, x)
+            else:
+                v = self.depends_on(pkt)
+                if v is not None:
+                    assert self.names.has_key(v), 'invalid dependency'
+                    these_names = self.names[v]
+                    s = set()
+                    for i in x:
+                        for j in these_names.keys():
+                            if these_names[j].short == i:
+                                s.add(i)
+                                break
+                        else:
+                            assert False, 'Unknown flag "{}" with this dependency'.format(i)
+                            continue
+                    x = s
+        return x
+
+    def i2m(self, pkt, x):
+        v = self.depends_on(pkt)
+        if v in self.names:
+            these_names = self.names[v]
+        else:
+            these_names = {}
+
+        r = 0
+        for flag_set in x:
+            for i in these_names.keys():
+                if these_names[i].short == flag_set:
+                    r |= 1 << i
+                    break
+            else:
+                r |= 1 << int(flag_set[len('bit '):])
+        return r
+
+    def m2i(self, pkt, x):
+        v = self.depends_on(pkt)
+        if v in self.names:
+            these_names = self.names[v]
+        else:
+            these_names = {}
+
+        r = set()
+        i = 0
+
+        while x:
+            if x & 1:
+                if i in these_names:
+                    r.add(these_names[i].short)
+                else:
+                    r.add('bit {}'.format(i))
+            x >>= 1
+            i += 1
+        return r
+
+    def i2repr(self, pkt, x):
+        v = self.depends_on(pkt)
+        if self.names.has_key(v):
+            these_names = self.names[v]
+        else:
+            these_names = {}
+
+        r = set()
+        for flag_set in x:
+            for i in these_names.itervalues():
+                if i.short == flag_set:
+                    r.add("{} ({})".format(i.long, i.short))
+                    break
+            else:
+                r.add(flag_set)
+        return repr(r)
 
 
 class FixedPointField(BitField):
diff --git a/test/regression.uts b/test/regression.uts
index eda698d9..c512a5cb 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -6304,3 +6304,188 @@ except:
 assert(ret)
 p = ss.recv()
 assert(p.data == 3)
+
++ Tests on MultiFlagsField
+
+= Test calls on MultiFlagsField.any2i
+~ multiflagsfield
+
+import collections
+MockPacket = collections.namedtuple('MockPacket', ['type'])
+
+f = MultiFlagsField('flags', set(), 3, {
+        0: {
+            0: MultiFlagsEntry('A', 'OptionA'),
+            1: MultiFlagsEntry('B', 'OptionB'),
+        },
+        1: {
+            0: MultiFlagsEntry('+', 'Plus'),
+            1: MultiFlagsEntry('*', 'Star'),
+        },
+    },
+    depends_on=lambda x: x.type
+)
+
+mp = MockPacket(0)
+x = f.any2i(mp, set())
+assert(isinstance(x, set))
+assert(len(x) == 0)
+x = f.any2i(mp, {'A'})
+assert(isinstance(x, set))
+assert(len(x) == 1)
+assert('A' in x)
+assert('B' not in x)
+assert('+' not in x)
+x = f.any2i(mp, {'A', 'B'})
+assert(isinstance(x, set))
+assert(len(x) == 2)
+assert('A' in x)
+assert('B' in x)
+assert('+' not in x)
+assert('*' not in x)
+x = f.any2i(mp, 3)
+assert(isinstance(x, set))
+assert(len(x) == 2)
+assert('A' in x)
+assert('B' in x)
+assert('+' not in x)
+assert('*' not in x)
+x = f.any2i(mp, 7)
+assert(isinstance(x, set))
+assert(len(x) == 3)
+assert('A' in x)
+assert('B' in x)
+assert('bit 2' in x)
+assert('+' not in x)
+assert('*' not in x)
+mp = MockPacket(1)
+x = f.any2i(mp, {'+', '*'})
+assert(isinstance(x, set))
+assert(len(x) == 2)
+assert('+' in x)
+assert('*' in x)
+assert('A' not in x)
+assert('B' not in x)
+try:
+    x = f.any2i(mp, {'A'})
+    ret = False
+except AssertionError:
+    ret = True
+
+assert(ret)
+#Following test demonstrate a non-sensical yet acceptable usage :(
+x = f.any2i(None, {'Toto'})
+assert('Toto' in x)
+
+= Test calls on MultiFlagsField.i2m
+~ multiflagsfield
+
+import collections
+MockPacket = collections.namedtuple('MockPacket', ['type'])
+
+f = MultiFlagsField('flags', set(), 3, {
+        0: {
+            0: MultiFlagsEntry('A', 'OptionA'),
+            1: MultiFlagsEntry('B', 'OptionB'),
+        },
+        1: {
+            0: MultiFlagsEntry('+', 'Plus'),
+            1: MultiFlagsEntry('*', 'Star'),
+        },
+    },
+    depends_on=lambda x: x.type
+)
+
+mp = MockPacket(0)
+x = f.i2m(mp, set())
+assert(isinstance(x, (int, long)))
+assert(x == 0)
+x = f.i2m(mp, {'A'})
+assert(isinstance(x, (int, long)))
+assert(x == 1)
+x = f.i2m(mp, {'A', 'B'})
+assert(isinstance(x, (int, long)))
+assert(x == 3)
+x = f.i2m(mp, {'A', 'B', 'bit 2'})
+assert(isinstance(x, (int, long)))
+assert(x == 7)
+try:
+    x = f.i2m(mp, {'+'})
+    ret = False
+except:
+    ret = True
+
+assert(ret)
+
+= Test calls on MultiFlagsField.m2i
+~ multiflagsfield
+
+import collections
+MockPacket = collections.namedtuple('MockPacket', ['type'])
+
+f = MultiFlagsField('flags', set(), 3, {
+        0: {
+            0: MultiFlagsEntry('A', 'OptionA'),
+            1: MultiFlagsEntry('B', 'OptionB'),
+        },
+        1: {
+            0: MultiFlagsEntry('+', 'Plus'),
+            1: MultiFlagsEntry('*', 'Star'),
+        },
+    },
+    depends_on=lambda x: x.type
+)
+
+mp = MockPacket(0)
+x = f.m2i(mp, 2)
+assert(isinstance(x, set))
+assert(len(x) == 1)
+assert('B' in x)
+assert('A' not in x)
+assert('*' not in x)
+
+x = f.m2i(mp, 7)
+assert(isinstance(x, set))
+assert('B' in x)
+assert('A' in x)
+assert('bit 2' in x)
+assert('*' not in x)
+assert('+' not in x)
+x = f.m2i(mp, 0)
+assert(len(x) == 0)
+mp = MockPacket(1)
+x = f.m2i(mp, 2)
+assert(isinstance(x, set))
+assert(len(x) == 1)
+assert('*' in x)
+assert('+' not in x)
+assert('B' not in x)
+
+= Test calls on MultiFlagsField.i2repr
+~ multiflagsfield
+
+import collections, re
+MockPacket = collections.namedtuple('MockPacket', ['type'])
+
+f = MultiFlagsField('flags', set(), 3, {
+        0: {
+            0: MultiFlagsEntry('A', 'OptionA'),
+            1: MultiFlagsEntry('B', 'OptionB'),
+        },
+        1: {
+            0: MultiFlagsEntry('+', 'Plus'),
+            1: MultiFlagsEntry('*', 'Star'),
+        },
+    },
+    depends_on=lambda x: x.type
+)
+
+mp = MockPacket(0)
+x = f.i2repr(mp, {'A', 'B'})
+assert(re.match(r'^.*OptionA \(A\).*$', x) is not None)
+assert(re.match(r'^.*OptionB \(B\).*$', x) is not None)
+mp = MockPacket(1)
+x = f.i2repr(mp, {'*', '+', 'bit 2'})
+assert(re.match(r'^.*Star \(\*\).*$', x) is not None)
+assert(re.match(r'^.*Plus \(\+\).*$', x) is not None)
+assert(re.match(r'^.*bit 2.*$', x) is not None)
-- 
GitLab