From fc6a4caaf1f7da48cb63e3906d5a07de6d49c977 Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Tue, 17 Jan 2017 02:04:07 +0100
Subject: [PATCH] Introduce FlagValue(int) objects to represent FlagsField()
 values

This makes it possible to write `if pkt[TCP].flags.S: [...]`.
---
 scapy/base_classes.py |  4 +--
 scapy/fields.py       | 82 ++++++++++++++++++++++++++++---------------
 test/regression.uts   | 40 +++++++++++++++++----
 3 files changed, 90 insertions(+), 36 deletions(-)

diff --git a/scapy/base_classes.py b/scapy/base_classes.py
index 636f28d1..9e4ee778 100644
--- a/scapy/base_classes.py
+++ b/scapy/base_classes.py
@@ -24,8 +24,8 @@ class SetGen(Gen):
         self._iterpacket=_iterpacket
         if isinstance(values, (list, BasePacketList)):
             self.values = list(values)
-        elif (type(values) is tuple) and (2 <= len(values) <= 3) and \
-             all(type(i) is int for i in values):
+        elif (isinstance(values, tuple) and (2 <= len(values) <= 3) and \
+             all(isinstance(i, int) for i in values)):
             # We use values[1] + 1 as stop value for xrange to maintain
             # the behavior of using tuples as field `values`
             self.values = [xrange(*((values[0], values[1] + 1) + values[2:]))]
diff --git a/scapy/fields.py b/scapy/fields.py
index cbe15534..ca0ebb59 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -934,6 +934,49 @@ class LEFieldLenField(FieldLenField):
         FieldLenField.__init__(self, name, default, length_of=length_of, fmt=fmt, count_of=count_of, fld=fld, adjust=adjust)
 
 
+class FlagValue(int):
+    __slots__ = ["names", "multi"]
+    @staticmethod
+    def __fixvalue(value, names):
+        if isinstance(value, basestring):
+            if isinstance(names, list):
+                value = value.split('+')
+            else:
+                value = list(value)
+        if isinstance(value, list):
+            y = 0
+            for i in value:
+                y |= 1 << names.index(i)
+            value = y
+        return value
+    def __new__(cls, value, names):
+        return super(FlagValue, cls).__new__(cls, cls.__fixvalue(value, names))
+    def __init__(self, value, names):
+        super(FlagValue, self).__init__(value)
+        self.multi = isinstance(names, list)
+        self.names = names
+    def flagrepr(self):
+        i = 0
+        r = []
+        x = int(self)
+        while x:
+            if x & 1:
+                r.append(self.names[i])
+            i += 1
+            x >>= 1
+        return ("+" if self.multi else "").join(r)
+    def __repr__(self):
+        return "<Flag %r (%s)>" % (int(self),
+                                   self.flagrepr())
+    def __deepcopy__(self, memo):
+        return self.__class__(int(self), self.names)
+    def __getattr__(self, attr):
+        try:
+            return bool((2 ** self.names.index(attr)) & int(self))
+        except ValueError:
+            return super(FlagValue, self).__getattr__(attr)
+
+
 class FlagsField(BitField):
     """ Handle Flag type field
 
@@ -957,37 +1000,20 @@ class FlagsField(BitField):
    """
     __slots__ = ["multi", "names"]
     def __init__(self, name, default, size, names):
-        self.multi = type(names) is list
-        if self.multi:
-            self.names = map(lambda x:[x], names)
-        else:
-            self.names = names
+        self.multi = isinstance(names, list)
+        self.names = names
         BitField.__init__(self, name, default, size)
     def any2i(self, pkt, x):
-        if type(x) is str:
-            if self.multi:
-                x = map(lambda y:[y], x.split("+"))
-            y = 0
-            for i in x:
-                y |= 1 << self.names.index(i)
-            x = y
-        return x
+        if isinstance(x, (list, tuple)):
+            return type(x)(None if v is None else FlagValue(v, self.names)
+                           for v in x)
+        return None if x is None else FlagValue(x, self.names)
     def i2repr(self, pkt, x):
-        if type(x) is list or type(x) is tuple:
-            return repr(x)
-        if self.multi:
-            r = []
-        else:
-            r = ""
-        i=0
-        while x:
-            if x & 1:
-                r += self.names[i]
-            i += 1
-            x >>= 1
-        if self.multi:
-            r = "+".join(r)
-        return r
+        if isinstance(x, (list, tuple)):
+            return repr(type(x)(
+                None if v is None else FlagValue(v, self.names).flagrepr()
+                for v in x))
+        return None if x is None else FlagValue(x, self.names).flagrepr()
 
 
 MultiFlagsEntry = collections.namedtuple('MultiFlagEntry', ['short', 'long'])
diff --git a/test/regression.uts b/test/regression.uts
index 11052b84..e08eabc3 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -7143,7 +7143,9 @@ assert(re.match(r'^.*Star \(\*\).*$', x) is not None)
 assert(re.match(r'^.*Plus \(\+\).*$', x) is not None)
 assert(re.match(r'^.*bit 2.*$', x) is not None)
 
-###########################################################################################################
+
+############
+############
 + Test correct conversion from binary to string of IPv6 addresses
 
 = IPv6 bin to string conversion - all zero bytes
@@ -7222,7 +7224,6 @@ assert(compressed1 == compressed2 == '1000:200:30:4:5:60:700:8000')
 
 ############
 ############
-
 + VRRP tests
 
 = VRRP - build
@@ -7236,7 +7237,6 @@ VRRP in p and p[VRRP].chksum == 0x7afd
 
 ############
 ############
-
 + L2TP tests
 
 = L2TP - build
@@ -7250,7 +7250,6 @@ L2TP in p and p[L2TP].len == 14 and p.tunnel_id == 0 and p[UDP].chksum == 0xf465
 
 ############
 ############
-
 + HSRP tests
 
 = HSRP - build & dissection
@@ -7263,7 +7262,6 @@ assert pkt[HSRPmd5].type == 4 and pkt[HSRPmd5].sourceip == defaddr
 
 ############
 ############
-
 + RIP tests
 
 = RIP - build
@@ -7277,7 +7275,6 @@ RIPEntry in p and RIPAuth in p and p[RIPAuth].password.startswith("scapy")
 
 ############
 ############
-
 + Radius tests
 
 = Radius - build
@@ -7487,3 +7484,34 @@ rek == 'b'
 random.seed(0x2807)
 rts = RandTermString(4, "scapy")
 sane(str(rts)) == "...[scapy"
+
+
+############
+############
++ Flags
+
+= IP flags
+~ IP
+
+pkt = IP(flags="MF")
+assert pkt.flags.MF
+assert not pkt.flags.DF
+assert not pkt.flags.evil
+
+pkt = IP(flags=3)
+assert pkt.flags.MF
+assert pkt.flags.DF
+assert not pkt.flags.evil
+
+= TCP flags
+~ TCP
+
+pkt = TCP(flags="SA")
+assert pkt.flags == 18
+assert pkt.flags.S
+assert pkt.flags.A
+assert not any(getattr(pkt.flags, f) for f in 'FRPUECN')
+
+pkt = TCP(flags=56)
+assert all(getattr(pkt.flags, f) for f in 'PAU')
+assert not any(getattr(pkt.flags, f) for f in 'FSRECN')
-- 
GitLab