From f0bfaf7501a1c235109f38e30fe9af666d58a855 Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Sat, 21 Jan 2017 15:03:56 +0100
Subject: [PATCH] Cleanup & test logical operations on flags

---
 scapy/fields.py     | 19 ++++++++-----------
 test/regression.uts | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/scapy/fields.py b/scapy/fields.py
index d4e5f4ca..581c28c3 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -962,20 +962,17 @@ class FlagValue(object):
             return cmp(self.value, other.value)
         return cmp(self.value, other)
     def __and__(self, other):
-        if isinstance(other, self.__class__):
-            return self.value & other.value
-        return self.value & other
+        return self.__class__(self.value & int(other), self.names)
     __rand__ = __and__
     def __or__(self, other):
-        if isinstance(other, self.__class__):
-            return self.value | other.value
-        return self.value | other
+        return self.__class__(self.value | int(other), self.names)
     __ror__ = __or__
-    def __add__(self, other):
-        if isinstance(other, self.__class__):
-            return self.value + other.value
-        return self.value + other
-    __radd__ = __add__
+    def __lshift__(self, other):
+        return self.value << int(other)
+    def __rshift__(self, other):
+        return self.value >> int(other)
+    def __nonzero__(self):
+        return bool(self.value)
     def flagrepr(self):
         i = 0
         r = []
diff --git a/test/regression.uts b/test/regression.uts
index b87cfb8a..200b389f 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -7497,20 +7497,24 @@ pkt = IP(flags="MF")
 assert pkt.flags.MF
 assert not pkt.flags.DF
 assert not pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 1 (MF)>'
 pkt.flags.MF = 0
 pkt.flags.DF = 1
 assert not pkt.flags.MF
 assert pkt.flags.DF
 assert not pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 2 (DF)>'
 
 pkt = IP(flags=3)
 assert pkt.flags.MF
 assert pkt.flags.DF
 assert not pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 3 (MF+DF)>'
 pkt.flags = 6
 assert not pkt.flags.MF
 assert pkt.flags.DF
 assert pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 6 (DF+evil)>'
 
 = TCP flags
 ~ TCP
@@ -7520,18 +7524,22 @@ assert pkt.flags == 18
 assert pkt.flags.S
 assert pkt.flags.A
 assert not any(getattr(pkt.flags, f) for f in 'FRPUECN')
+assert repr(pkt.flags) == '<Flag 18 (SA)>'
 pkt.flags.U = True
 pkt.flags.S = False
 assert pkt.flags.A
 assert pkt.flags.U
 assert not any(getattr(pkt.flags, f) for f in 'FSRPECN')
+assert repr(pkt.flags) == '<Flag 48 (AU)>'
 
 pkt = TCP(flags=56)
 assert all(getattr(pkt.flags, f) for f in 'PAU')
 assert not any(getattr(pkt.flags, f) for f in 'FSRECN')
+assert repr(pkt.flags) == '<Flag 56 (PAU)>'
 pkt.flags = 50
 assert all(getattr(pkt.flags, f) for f in 'SAU')
 assert not any(getattr(pkt.flags, f) for f in 'FRPECN')
+assert repr(pkt.flags) == '<Flag 50 (SAU)>'
 
 = Flag values mutation with .raw_packet_cache
 ~ IP TCP
@@ -7542,9 +7550,11 @@ assert pkt[TCP].raw_packet_cache is not None
 assert pkt.flags.MF
 assert not pkt.flags.DF
 assert not pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 1 (MF)>'
 assert pkt[TCP].flags.S
 assert pkt[TCP].flags.A
 assert not any(getattr(pkt[TCP].flags, f) for f in 'FRPUECN')
+assert repr(pkt[TCP].flags) == '<Flag 18 (SA)>'
 pkt.flags.MF = 0
 pkt.flags.DF = 1
 pkt[TCP].flags.U = True
@@ -7553,6 +7563,29 @@ pkt = IP(str(pkt))
 assert not pkt.flags.MF
 assert pkt.flags.DF
 assert not pkt.flags.evil
+assert repr(pkt.flags) == '<Flag 2 (DF)>'
 assert pkt[TCP].flags.A
 assert pkt[TCP].flags.U
 assert not any(getattr(pkt[TCP].flags, f) for f in 'FSRPECN')
+assert repr(pkt[TCP].flags) == '<Flag 48 (AU)>'
+
+= Operations on flag values
+~ TCP
+
+p1, p2 = TCP(flags="SU"), TCP(flags="AU")
+assert (p1.flags & p2.flags).U
+assert not any(getattr(p1.flags & p2.flags, f) for f in 'FSRPAECN')
+assert all(getattr(p1.flags | p2.flags, f) for f in 'SAU')
+assert not any(getattr(p1.flags | p2.flags, f) for f in 'FRPECN')
+
+assert TCP(flags="SA").flags & TCP(flags="S").flags == TCP(flags="S").flags
+assert TCP(flags="SA").flags | TCP(flags="S").flags == TCP(flags="SA").flags
+
+= Using tuples and lists as flag values
+~ IP TCP
+
+plist = PacketList(list(IP()/TCP(flags=(0, 2**9 - 1))))
+assert [p[TCP].flags for p in plist] == range(512)
+
+plist = PacketList(list(IP()/TCP(flags=["S", "SA", "A"])))
+assert [p[TCP].flags for p in plist] == [2, 18, 16]
-- 
GitLab