From 17ec20bdcaac18ed21801c172c75e8026bf38e1b Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Tue, 24 Jan 2017 08:11:08 +0100
Subject: [PATCH] FlagValue: support multiple flags check at once

This PR makes it possible to check if several flags are all set, by
writting, for example, `if pkt[TCP].flags.SA:`.

It only works for single letter flags (e.g., will work for TCP().flags
but not for IP().flags), and only when getting a value.

Thanks @guedou for suggesting that.
---
 scapy/fields.py     | 5 ++++-
 test/regression.uts | 7 +++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/scapy/fields.py b/scapy/fields.py
index 581c28c3..0a92c3f5 100644
--- a/scapy/fields.py
+++ b/scapy/fields.py
@@ -991,7 +991,10 @@ class FlagValue(object):
         if attr in self.__slots__:
             return super(FlagValue, self).__getattr__(attr)
         try:
-            return bool((2 ** self.names.index(attr)) & int(self))
+            if self.multi:
+                return bool((2 ** self.names.index(attr)) & int(self))
+            return all(bool((2 ** self.names.index(flag)) & int(self))
+                       for flag in attr)
         except ValueError:
             return super(FlagValue, self).__getattr__(attr)
     def __setattr__(self, attr, value):
diff --git a/test/regression.uts b/test/regression.uts
index 200b389f..6e0ba09c 100644
--- a/test/regression.uts
+++ b/test/regression.uts
@@ -7523,21 +7523,25 @@ pkt = TCP(flags="SA")
 assert pkt.flags == 18
 assert pkt.flags.S
 assert pkt.flags.A
+assert pkt.flags.SA
 assert not any(getattr(pkt.flags, f) for f in 'FRPUECN')
 assert repr(pkt.flags) == '<Flag 18 (SA)>'
 pkt.flags.U = True
 pkt.flags.S = False
 assert pkt.flags.A
 assert pkt.flags.U
+assert pkt.flags.AU
 assert not any(getattr(pkt.flags, f) for f in 'FSRPECN')
 assert repr(pkt.flags) == '<Flag 48 (AU)>'
 
 pkt = TCP(flags=56)
 assert all(getattr(pkt.flags, f) for f in 'PAU')
+assert pkt.flags.PAU
 assert not any(getattr(pkt.flags, f) for f in 'FSRECN')
 assert repr(pkt.flags) == '<Flag 56 (PAU)>'
 pkt.flags = 50
 assert all(getattr(pkt.flags, f) for f in 'SAU')
+assert pkt.flags.SAU
 assert not any(getattr(pkt.flags, f) for f in 'FRPECN')
 assert repr(pkt.flags) == '<Flag 50 (SAU)>'
 
@@ -7553,6 +7557,7 @@ assert not pkt.flags.evil
 assert repr(pkt.flags) == '<Flag 1 (MF)>'
 assert pkt[TCP].flags.S
 assert pkt[TCP].flags.A
+assert pkt[TCP].flags.SA
 assert not any(getattr(pkt[TCP].flags, f) for f in 'FRPUECN')
 assert repr(pkt[TCP].flags) == '<Flag 18 (SA)>'
 pkt.flags.MF = 0
@@ -7566,6 +7571,7 @@ assert not pkt.flags.evil
 assert repr(pkt.flags) == '<Flag 2 (DF)>'
 assert pkt[TCP].flags.A
 assert pkt[TCP].flags.U
+assert pkt[TCP].flags.AU
 assert not any(getattr(pkt[TCP].flags, f) for f in 'FSRPECN')
 assert repr(pkt[TCP].flags) == '<Flag 48 (AU)>'
 
@@ -7576,6 +7582,7 @@ p1, p2 = TCP(flags="SU"), TCP(flags="AU")
 assert (p1.flags & p2.flags).U
 assert not any(getattr(p1.flags & p2.flags, f) for f in 'FSRPAECN')
 assert all(getattr(p1.flags | p2.flags, f) for f in 'SAU')
+assert (p1.flags | p2.flags).SAU
 assert not any(getattr(p1.flags | p2.flags, f) for f in 'FRPECN')
 
 assert TCP(flags="SA").flags & TCP(flags="S").flags == TCP(flags="S").flags
-- 
GitLab