Skip to content
Snippets Groups Projects
Commit 65aded24 authored by Guillaume Valadon's avatar Guillaume Valadon Committed by GitHub
Browse files

Merge pull request #471 from p-l-/enh-flags

Introduce FlagValue() objects to represent FlagsField() values
parents 0c9b9089 17ec20bd
No related branches found
No related tags found
No related merge requests found
...@@ -24,11 +24,12 @@ class SetGen(Gen): ...@@ -24,11 +24,12 @@ class SetGen(Gen):
self._iterpacket=_iterpacket self._iterpacket=_iterpacket
if isinstance(values, (list, BasePacketList)): if isinstance(values, (list, BasePacketList)):
self.values = list(values) self.values = list(values)
elif (type(values) is tuple) and (2 <= len(values) <= 3) and \ elif (isinstance(values, tuple) and (2 <= len(values) <= 3) and \
all(type(i) is int for i in values): all(hasattr(i, "__int__") for i in values)):
# We use values[1] + 1 as stop value for xrange to maintain # We use values[1] + 1 as stop value for xrange to maintain
# the behavior of using tuples as field `values` # the behavior of using tuples as field `values`
self.values = [xrange(*((values[0], values[1] + 1) + values[2:]))] self.values = [xrange(*((int(values[0]), int(values[1]) + 1)
+ tuple(int(v) for v in values[2:])))]
else: else:
self.values = [values] self.values = [values]
def transf(self, element): def transf(self, element):
......
...@@ -26,6 +26,7 @@ class Field(object): ...@@ -26,6 +26,7 @@ class Field(object):
__slots__ = ["name", "fmt", "default", "sz", "owners"] __slots__ = ["name", "fmt", "default", "sz", "owners"]
__metaclass__ = Field_metaclass __metaclass__ = Field_metaclass
islist = 0 islist = 0
ismutable = False
holds_packets = 0 holds_packets = 0
def __init__(self, name, default, fmt="H"): def __init__(self, name, default, fmt="H"):
self.name = name self.name = name
...@@ -934,6 +935,84 @@ class LEFieldLenField(FieldLenField): ...@@ -934,6 +935,84 @@ class LEFieldLenField(FieldLenField):
FieldLenField.__init__(self, name, default, length_of=length_of, fmt=fmt, count_of=count_of, fld=fld, adjust=adjust) FieldLenField.__init__(self, name, default, length_of=length_of, fmt=fmt, count_of=count_of, fld=fld, adjust=adjust)
class FlagValue(object):
__slots__ = ["value", "names", "multi"]
@staticmethod
def __fixvalue(value, names):
if isinstance(value, basestring):
if isinstance(names, list):
value = value.split('+')
else:
value = list(value)
if isinstance(value, list):
y = 0
for i in value:
y |= 1 << names.index(i)
value = y
return value
def __init__(self, value, names):
self.value = (value.value if isinstance(value, self.__class__)
else self.__fixvalue(value, names))
self.multi = isinstance(names, list)
self.names = names
def __int__(self):
return self.value
def __cmp__(self, other):
if isinstance(other, self.__class__):
return cmp(self.value, other.value)
return cmp(self.value, other)
def __and__(self, other):
return self.__class__(self.value & int(other), self.names)
__rand__ = __and__
def __or__(self, other):
return self.__class__(self.value | int(other), self.names)
__ror__ = __or__
def __lshift__(self, other):
return self.value << int(other)
def __rshift__(self, other):
return self.value >> int(other)
def __nonzero__(self):
return bool(self.value)
def flagrepr(self):
i = 0
r = []
x = int(self)
while x:
if x & 1:
r.append(self.names[i])
i += 1
x >>= 1
return ("+" if self.multi else "").join(r)
def __repr__(self):
return "<Flag %d (%s)>" % (self, self.flagrepr())
def __deepcopy__(self, memo):
return self.__class__(int(self), self.names)
def __getattr__(self, attr):
if attr in self.__slots__:
return super(FlagValue, self).__getattr__(attr)
try:
if self.multi:
return bool((2 ** self.names.index(attr)) & int(self))
return all(bool((2 ** self.names.index(flag)) & int(self))
for flag in attr)
except ValueError:
return super(FlagValue, self).__getattr__(attr)
def __setattr__(self, attr, value):
if attr == "value" and not isinstance(value, (int, long)):
raise ValueError(value)
if attr in self.__slots__:
return super(FlagValue, self).__setattr__(attr, value)
if attr in self.names:
if value:
self.value |= (2 ** self.names.index(attr))
else:
self.value &= ~(2 ** self.names.index(attr))
else:
return super(FlagValue, self).__setattr__(attr, value)
def copy(self):
return self.__class__(self.value, self.names)
class FlagsField(BitField): class FlagsField(BitField):
""" Handle Flag type field """ Handle Flag type field
...@@ -955,39 +1034,28 @@ class FlagsField(BitField): ...@@ -955,39 +1034,28 @@ class FlagsField(BitField):
:param size: number of bits in the field :param size: number of bits in the field
:param names: (list or dict) label for each flag, Least Significant Bit tag's name is written first :param names: (list or dict) label for each flag, Least Significant Bit tag's name is written first
""" """
ismutable = True
__slots__ = ["multi", "names"] __slots__ = ["multi", "names"]
def __init__(self, name, default, size, names): def __init__(self, name, default, size, names):
self.multi = type(names) is list self.multi = isinstance(names, list)
if self.multi: self.names = names
self.names = map(lambda x:[x], names)
else:
self.names = names
BitField.__init__(self, name, default, size) BitField.__init__(self, name, default, size)
def any2i(self, pkt, x): def any2i(self, pkt, x):
if type(x) is str: if isinstance(x, (list, tuple)):
if self.multi: return type(x)(None if v is None else FlagValue(v, self.names)
x = map(lambda y:[y], x.split("+")) for v in x)
y = 0 return None if x is None else FlagValue(x, self.names)
for i in x: def m2i(self, pkt, x):
y |= 1 << self.names.index(i) if isinstance(x, (list, tuple)):
x = y return type(x)(None if v is None else FlagValue(v, self.names)
return x for v in x)
return None if x is None else FlagValue(x, self.names)
def i2repr(self, pkt, x): def i2repr(self, pkt, x):
if type(x) is list or type(x) is tuple: if isinstance(x, (list, tuple)):
return repr(x) return repr(type(x)(
if self.multi: None if v is None else FlagValue(v, self.names).flagrepr()
r = [] for v in x))
else: return None if x is None else FlagValue(x, self.names).flagrepr()
r = ""
i=0
while x:
if x & 1:
r += self.names[i]
i += 1
x >>= 1
if self.multi:
r = "+".join(r)
return r
MultiFlagsEntry = collections.namedtuple('MultiFlagEntry', ['short', 'long']) MultiFlagsEntry = collections.namedtuple('MultiFlagEntry', ['short', 'long'])
......
...@@ -589,7 +589,7 @@ Creates an EPS file describing a packet. If filename is not provided a temporary ...@@ -589,7 +589,7 @@ Creates an EPS file describing a packet. If filename is not provided a temporary
s, fval = f.getfield(self, s) s, fval = f.getfield(self, s)
# We need to track fields with mutable values to discard # We need to track fields with mutable values to discard
# .raw_packet_cache when needed. # .raw_packet_cache when needed.
if f.islist or f.holds_packets: if f.islist or f.holds_packets or f.ismutable:
self.raw_packet_cache_fields[f.name] = f.do_copy(fval) self.raw_packet_cache_fields[f.name] = f.do_copy(fval)
self.fields[f.name] = fval self.fields[f.name] = fval
assert(raw.endswith(s)) assert(raw.endswith(s))
......
...@@ -7143,7 +7143,9 @@ assert(re.match(r'^.*Star \(\*\).*$', x) is not None) ...@@ -7143,7 +7143,9 @@ assert(re.match(r'^.*Star \(\*\).*$', x) is not None)
assert(re.match(r'^.*Plus \(\+\).*$', x) is not None) assert(re.match(r'^.*Plus \(\+\).*$', x) is not None)
assert(re.match(r'^.*bit 2.*$', x) is not None) assert(re.match(r'^.*bit 2.*$', x) is not None)
###########################################################################################################
############
############
+ Test correct conversion from binary to string of IPv6 addresses + Test correct conversion from binary to string of IPv6 addresses
= IPv6 bin to string conversion - all zero bytes = IPv6 bin to string conversion - all zero bytes
...@@ -7222,7 +7224,6 @@ assert(compressed1 == compressed2 == '1000:200:30:4:5:60:700:8000') ...@@ -7222,7 +7224,6 @@ assert(compressed1 == compressed2 == '1000:200:30:4:5:60:700:8000')
############ ############
############ ############
+ VRRP tests + VRRP tests
= VRRP - build = VRRP - build
...@@ -7236,7 +7237,6 @@ VRRP in p and p[VRRP].chksum == 0x7afd ...@@ -7236,7 +7237,6 @@ VRRP in p and p[VRRP].chksum == 0x7afd
############ ############
############ ############
+ L2TP tests + L2TP tests
= L2TP - build = L2TP - build
...@@ -7250,7 +7250,6 @@ L2TP in p and p[L2TP].len == 14 and p.tunnel_id == 0 and p[UDP].chksum == 0xf465 ...@@ -7250,7 +7250,6 @@ L2TP in p and p[L2TP].len == 14 and p.tunnel_id == 0 and p[UDP].chksum == 0xf465
############ ############
############ ############
+ HSRP tests + HSRP tests
= HSRP - build & dissection = HSRP - build & dissection
...@@ -7263,7 +7262,6 @@ assert pkt[HSRPmd5].type == 4 and pkt[HSRPmd5].sourceip == defaddr ...@@ -7263,7 +7262,6 @@ assert pkt[HSRPmd5].type == 4 and pkt[HSRPmd5].sourceip == defaddr
############ ############
############ ############
+ RIP tests + RIP tests
= RIP - build = RIP - build
...@@ -7277,7 +7275,6 @@ RIPEntry in p and RIPAuth in p and p[RIPAuth].password.startswith("scapy") ...@@ -7277,7 +7275,6 @@ RIPEntry in p and RIPAuth in p and p[RIPAuth].password.startswith("scapy")
############ ############
############ ############
+ Radius tests + Radius tests
= Radius - build = Radius - build
...@@ -7487,3 +7484,115 @@ rek == 'b' ...@@ -7487,3 +7484,115 @@ rek == 'b'
random.seed(0x2807) random.seed(0x2807)
rts = RandTermString(4, "scapy") rts = RandTermString(4, "scapy")
sane(str(rts)) == "...[scapy" sane(str(rts)) == "...[scapy"
############
############
+ Flags
= IP flags
~ IP
pkt = IP(flags="MF")
assert pkt.flags.MF
assert not pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 1 (MF)>'
pkt.flags.MF = 0
pkt.flags.DF = 1
assert not pkt.flags.MF
assert pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 2 (DF)>'
pkt = IP(flags=3)
assert pkt.flags.MF
assert pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 3 (MF+DF)>'
pkt.flags = 6
assert not pkt.flags.MF
assert pkt.flags.DF
assert pkt.flags.evil
assert repr(pkt.flags) == '<Flag 6 (DF+evil)>'
= TCP flags
~ TCP
pkt = TCP(flags="SA")
assert pkt.flags == 18
assert pkt.flags.S
assert pkt.flags.A
assert pkt.flags.SA
assert not any(getattr(pkt.flags, f) for f in 'FRPUECN')
assert repr(pkt.flags) == '<Flag 18 (SA)>'
pkt.flags.U = True
pkt.flags.S = False
assert pkt.flags.A
assert pkt.flags.U
assert pkt.flags.AU
assert not any(getattr(pkt.flags, f) for f in 'FSRPECN')
assert repr(pkt.flags) == '<Flag 48 (AU)>'
pkt = TCP(flags=56)
assert all(getattr(pkt.flags, f) for f in 'PAU')
assert pkt.flags.PAU
assert not any(getattr(pkt.flags, f) for f in 'FSRECN')
assert repr(pkt.flags) == '<Flag 56 (PAU)>'
pkt.flags = 50
assert all(getattr(pkt.flags, f) for f in 'SAU')
assert pkt.flags.SAU
assert not any(getattr(pkt.flags, f) for f in 'FRPECN')
assert repr(pkt.flags) == '<Flag 50 (SAU)>'
= Flag values mutation with .raw_packet_cache
~ IP TCP
pkt = IP(str(IP(flags="MF")/TCP(flags="SA")))
assert pkt.raw_packet_cache is not None
assert pkt[TCP].raw_packet_cache is not None
assert pkt.flags.MF
assert not pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 1 (MF)>'
assert pkt[TCP].flags.S
assert pkt[TCP].flags.A
assert pkt[TCP].flags.SA
assert not any(getattr(pkt[TCP].flags, f) for f in 'FRPUECN')
assert repr(pkt[TCP].flags) == '<Flag 18 (SA)>'
pkt.flags.MF = 0
pkt.flags.DF = 1
pkt[TCP].flags.U = True
pkt[TCP].flags.S = False
pkt = IP(str(pkt))
assert not pkt.flags.MF
assert pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 2 (DF)>'
assert pkt[TCP].flags.A
assert pkt[TCP].flags.U
assert pkt[TCP].flags.AU
assert not any(getattr(pkt[TCP].flags, f) for f in 'FSRPECN')
assert repr(pkt[TCP].flags) == '<Flag 48 (AU)>'
= Operations on flag values
~ TCP
p1, p2 = TCP(flags="SU"), TCP(flags="AU")
assert (p1.flags & p2.flags).U
assert not any(getattr(p1.flags & p2.flags, f) for f in 'FSRPAECN')
assert all(getattr(p1.flags | p2.flags, f) for f in 'SAU')
assert (p1.flags | p2.flags).SAU
assert not any(getattr(p1.flags | p2.flags, f) for f in 'FRPECN')
assert TCP(flags="SA").flags & TCP(flags="S").flags == TCP(flags="S").flags
assert TCP(flags="SA").flags | TCP(flags="S").flags == TCP(flags="SA").flags
= Using tuples and lists as flag values
~ IP TCP
plist = PacketList(list(IP()/TCP(flags=(0, 2**9 - 1))))
assert [p[TCP].flags for p in plist] == range(512)
plist = PacketList(list(IP()/TCP(flags=["S", "SA", "A"])))
assert [p[TCP].flags for p in plist] == [2, 18, 16]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment