From 05841bb38338dc3414bbe1de1c23d7c7a31d268e Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Tue, 17 Jan 2017 19:33:35 +0100
Subject: [PATCH] FlagValue objects are now mutable.

It is now possible to write `pkt[TCP].flags.S = True`.
---
 scapy/base_classes.py |  5 +++--
 scapy/fields.py       | 47 ++++++++++++++++++++++++++++++++++++-------
 test/regression.uts   | 17 ++++++++++++++++
 3 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/scapy/base_classes.py b/scapy/base_classes.py
index 9e4ee778..62997889 100644
--- a/scapy/base_classes.py
+++ b/scapy/base_classes.py
@@ -25,10 +25,11 @@ class SetGen(Gen):
         if isinstance(values, (list, BasePacketList)):
             self.values = list(values)
         elif (isinstance(values, tuple) and (2 <= len(values) <= 3) and \
-             all(isinstance(i, int) for i in values)):
+             all(hasattr(i, "__int__") for i in values)):
             # We use values[1] + 1 as stop value for xrange to maintain
             # the behavior of using tuples as field `values`
-            self.values = [xrange(*((values[0], values[1] + 1) + values[2:]))]
+            self.values = [xrange(*((int(values[0]), int(values[1]) + 1)
+                                    + tuple(int(v) for v in values[2:])))]
         else:
             self.values = [values]
     def transf(self, element):
diff --git a/scapy/fields.py b/scapy/fields.py
index ca0ebb59..cecd8761 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -934,8 +934,8 @@ class LEFieldLenField(FieldLenField):
         FieldLenField.__init__(self, name, default, length_of=length_of, fmt=fmt, count_of=count_of, fld=fld, adjust=adjust)
 
 
-class FlagValue(int):
-    __slots__ = ["names", "multi"]
+class FlagValue(object):
+    __slots__ = ["value", "names", "multi"]
     @staticmethod
     def __fixvalue(value, names):
         if isinstance(value, basestring):
@@ -949,12 +949,32 @@ class FlagValue(int):
                 y |= 1 << names.index(i)
             value = y
         return value
-    def __new__(cls, value, names):
-        return super(FlagValue, cls).__new__(cls, cls.__fixvalue(value, names))
     def __init__(self, value, names):
-        super(FlagValue, self).__init__(value)
+        self.value = (value.value if isinstance(value, self.__class__)
+                      else self.__fixvalue(value, names))
         self.multi = isinstance(names, list)
         self.names = names
+    def __int__(self):
+        return self.value
+    def __cmp__(self, other):
+        if isinstance(other, self.__class__):
+            return cmp(self.value, other.value)
+        return cmp(self.value, other)
+    def __and__(self, other):
+        if isinstance(other, self.__class__):
+            return self.value & other.value
+        return self.value & other
+    __rand__ = __and__
+    def __or__(self, other):
+        if isinstance(other, self.__class__):
+            return self.value | other.value
+        return self.value | other
+    __ror__ = __or__
+    def __add__(self, other):
+        if isinstance(other, self.__class__):
+            return self.value + other.value
+        return self.value + other
+    __radd__ = __add__
     def flagrepr(self):
         i = 0
         r = []
@@ -966,15 +986,28 @@ class FlagValue(int):
             x >>= 1
         return ("+" if self.multi else "").join(r)
     def __repr__(self):
-        return "<Flag %r (%s)>" % (int(self),
-                                   self.flagrepr())
+        return "<Flag %d (%s)>" % (self, self.flagrepr())
     def __deepcopy__(self, memo):
         return self.__class__(int(self), self.names)
     def __getattr__(self, attr):
+        if attr in self.__slots__:
+            return super(FlagValue, self).__getattr__(attr)
         try:
             return bool((2 ** self.names.index(attr)) & int(self))
         except ValueError:
             return super(FlagValue, self).__getattr__(attr)
+    def __setattr__(self, attr, value):
+        if attr == "value" and not isinstance(value, (int, long)):
+            raise ValueError(value)
+        if attr in self.__slots__:
+            return super(FlagValue, self).__setattr__(attr, value)
+        if attr in self.names:
+            if value:
+                self.value |= (2 ** self.names.index(attr))
+            else:
+                self.value &= ~(2 ** self.names.index(attr))
+        else:
+            return super(FlagValue, self).__setattr__(attr, value)
 
 
 class FlagsField(BitField):
diff --git a/test/regression.uts b/test/regression.uts
index e08eabc3..393df949 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -7497,11 +7497,20 @@ pkt = IP(flags="MF")
 assert pkt.flags.MF
 assert not pkt.flags.DF
 assert not pkt.flags.evil
+pkt.flags.MF = 0
+pkt.flags.DF = 1
+assert not pkt.flags.MF
+assert pkt.flags.DF
+assert not pkt.flags.evil
 
 pkt = IP(flags=3)
 assert pkt.flags.MF
 assert pkt.flags.DF
 assert not pkt.flags.evil
+pkt.flags = 6
+assert not pkt.flags.MF
+assert pkt.flags.DF
+assert pkt.flags.evil
 
 = TCP flags
 ~ TCP
@@ -7511,7 +7520,15 @@ assert pkt.flags == 18
 assert pkt.flags.S
 assert pkt.flags.A
 assert not any(getattr(pkt.flags, f) for f in 'FRPUECN')
+pkt.flags.U = True
+pkt.flags.S = False
+assert pkt.flags.A
+assert pkt.flags.U
+assert not any(getattr(pkt.flags, f) for f in 'FSRPECN')
 
 pkt = TCP(flags=56)
 assert all(getattr(pkt.flags, f) for f in 'PAU')
 assert not any(getattr(pkt.flags, f) for f in 'FSRECN')
+pkt.flags = 50
+assert all(getattr(pkt.flags, f) for f in 'SAU')
+assert not any(getattr(pkt.flags, f) for f in 'FRPECN')
-- 
GitLab