From f0bfaf7501a1c235109f38e30fe9af666d58a855 Mon Sep 17 00:00:00 2001 From: Pierre LALET <pierre.lalet@cea.fr> Date: Sat, 21 Jan 2017 15:03:56 +0100 Subject: [PATCH] Cleanup & test logical operations on flags --- scapy/fields.py | 19 ++++++++----------- test/regression.uts | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/scapy/fields.py b/scapy/fields.py index d4e5f4ca..581c28c3 100644 --- a/scapy/fields.py +++ b/scapy/fields.py @@ -962,20 +962,17 @@ class FlagValue(object): return cmp(self.value, other.value) return cmp(self.value, other) def __and__(self, other): - if isinstance(other, self.__class__): - return self.value & other.value - return self.value & other + return self.__class__(self.value & int(other), self.names) __rand__ = __and__ def __or__(self, other): - if isinstance(other, self.__class__): - return self.value | other.value - return self.value | other + return self.__class__(self.value | int(other), self.names) __ror__ = __or__ - def __add__(self, other): - if isinstance(other, self.__class__): - return self.value + other.value - return self.value + other - __radd__ = __add__ + def __lshift__(self, other): + return self.value << int(other) + def __rshift__(self, other): + return self.value >> int(other) + def __nonzero__(self): + return bool(self.value) def flagrepr(self): i = 0 r = [] diff --git a/test/regression.uts b/test/regression.uts index b87cfb8a..200b389f 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -7497,20 +7497,24 @@ pkt = IP(flags="MF") assert pkt.flags.MF assert not pkt.flags.DF assert not pkt.flags.evil +assert repr(pkt.flags) == '<Flag 1 (MF)>' pkt.flags.MF = 0 pkt.flags.DF = 1 assert not pkt.flags.MF assert pkt.flags.DF assert not pkt.flags.evil +assert repr(pkt.flags) == '<Flag 2 (DF)>' pkt = IP(flags=3) assert pkt.flags.MF assert pkt.flags.DF assert not pkt.flags.evil +assert repr(pkt.flags) == '<Flag 3 (MF+DF)>' pkt.flags = 6 assert not pkt.flags.MF assert pkt.flags.DF assert pkt.flags.evil +assert repr(pkt.flags) == '<Flag 6 (DF+evil)>' = TCP flags ~ TCP @@ -7520,18 +7524,22 @@ assert pkt.flags == 18 assert pkt.flags.S assert pkt.flags.A assert not any(getattr(pkt.flags, f) for f in 'FRPUECN') +assert repr(pkt.flags) == '<Flag 18 (SA)>' pkt.flags.U = True pkt.flags.S = False assert pkt.flags.A assert pkt.flags.U assert not any(getattr(pkt.flags, f) for f in 'FSRPECN') +assert repr(pkt.flags) == '<Flag 48 (AU)>' pkt = TCP(flags=56) assert all(getattr(pkt.flags, f) for f in 'PAU') assert not any(getattr(pkt.flags, f) for f in 'FSRECN') +assert repr(pkt.flags) == '<Flag 56 (PAU)>' pkt.flags = 50 assert all(getattr(pkt.flags, f) for f in 'SAU') assert not any(getattr(pkt.flags, f) for f in 'FRPECN') +assert repr(pkt.flags) == '<Flag 50 (SAU)>' = Flag values mutation with .raw_packet_cache ~ IP TCP @@ -7542,9 +7550,11 @@ assert pkt[TCP].raw_packet_cache is not None assert pkt.flags.MF assert not pkt.flags.DF assert not pkt.flags.evil +assert repr(pkt.flags) == '<Flag 1 (MF)>' assert pkt[TCP].flags.S assert pkt[TCP].flags.A assert not any(getattr(pkt[TCP].flags, f) for f in 'FRPUECN') +assert repr(pkt[TCP].flags) == '<Flag 18 (SA)>' pkt.flags.MF = 0 pkt.flags.DF = 1 pkt[TCP].flags.U = True @@ -7553,6 +7563,29 @@ pkt = IP(str(pkt)) assert not pkt.flags.MF assert pkt.flags.DF assert not pkt.flags.evil +assert repr(pkt.flags) == '<Flag 2 (DF)>' assert pkt[TCP].flags.A assert pkt[TCP].flags.U assert not any(getattr(pkt[TCP].flags, f) for f in 'FSRPECN') +assert repr(pkt[TCP].flags) == '<Flag 48 (AU)>' + += Operations on flag values +~ TCP + +p1, p2 = TCP(flags="SU"), TCP(flags="AU") +assert (p1.flags & p2.flags).U +assert not any(getattr(p1.flags & p2.flags, f) for f in 'FSRPAECN') +assert all(getattr(p1.flags | p2.flags, f) for f in 'SAU') +assert not any(getattr(p1.flags | p2.flags, f) for f in 'FRPECN') + +assert TCP(flags="SA").flags & TCP(flags="S").flags == TCP(flags="S").flags +assert TCP(flags="SA").flags | TCP(flags="S").flags == TCP(flags="SA").flags + += Using tuples and lists as flag values +~ IP TCP + +plist = PacketList(list(IP()/TCP(flags=(0, 2**9 - 1)))) +assert [p[TCP].flags for p in plist] == range(512) + +plist = PacketList(list(IP()/TCP(flags=["S", "SA", "A"]))) +assert [p[TCP].flags for p in plist] == [2, 18, 16] -- GitLab