diff --git a/scapy/base_classes.py b/scapy/base_classes.py index 9e4ee778b899a60d151aae6cc4ad9e1623230f8e..62997889affff2fcd4358032b5c9e583445cc1e2 100644 --- a/scapy/base_classes.py +++ b/scapy/base_classes.py @@ -25,10 +25,11 @@ class SetGen(Gen): if isinstance(values, (list, BasePacketList)): self.values = list(values) elif (isinstance(values, tuple) and (2 <= len(values) <= 3) and \ - all(isinstance(i, int) for i in values)): + all(hasattr(i, "__int__") for i in values)): # We use values[1] + 1 as stop value for xrange to maintain # the behavior of using tuples as field `values` - self.values = [xrange(*((values[0], values[1] + 1) + values[2:]))] + self.values = [xrange(*((int(values[0]), int(values[1]) + 1) + + tuple(int(v) for v in values[2:])))] else: self.values = [values] def transf(self, element): diff --git a/scapy/fields.py b/scapy/fields.py index ca0ebb5947468be432507fd6e5563458ed1968cb..cecd8761ae56e84fd1bb59c23f96386794969cc8 100644 --- a/scapy/fields.py +++ b/scapy/fields.py @@ -934,8 +934,8 @@ class LEFieldLenField(FieldLenField): FieldLenField.__init__(self, name, default, length_of=length_of, fmt=fmt, count_of=count_of, fld=fld, adjust=adjust) -class FlagValue(int): - __slots__ = ["names", "multi"] +class FlagValue(object): + __slots__ = ["value", "names", "multi"] @staticmethod def __fixvalue(value, names): if isinstance(value, basestring): @@ -949,12 +949,32 @@ class FlagValue(int): y |= 1 << names.index(i) value = y return value - def __new__(cls, value, names): - return super(FlagValue, cls).__new__(cls, cls.__fixvalue(value, names)) def __init__(self, value, names): - super(FlagValue, self).__init__(value) + self.value = (value.value if isinstance(value, self.__class__) + else self.__fixvalue(value, names)) self.multi = isinstance(names, list) self.names = names + def __int__(self): + return self.value + def __cmp__(self, other): + if isinstance(other, self.__class__): + return cmp(self.value, other.value) + return cmp(self.value, other) + def __and__(self, other): + if isinstance(other, self.__class__): + return self.value & other.value + return self.value & other + __rand__ = __and__ + def __or__(self, other): + if isinstance(other, self.__class__): + return self.value | other.value + return self.value | other + __ror__ = __or__ + def __add__(self, other): + if isinstance(other, self.__class__): + return self.value + other.value + return self.value + other + __radd__ = __add__ def flagrepr(self): i = 0 r = [] @@ -966,15 +986,28 @@ class FlagValue(int): x >>= 1 return ("+" if self.multi else "").join(r) def __repr__(self): - return "<Flag %r (%s)>" % (int(self), - self.flagrepr()) + return "<Flag %d (%s)>" % (self, self.flagrepr()) def __deepcopy__(self, memo): return self.__class__(int(self), self.names) def __getattr__(self, attr): + if attr in self.__slots__: + return super(FlagValue, self).__getattr__(attr) try: return bool((2 ** self.names.index(attr)) & int(self)) except ValueError: return super(FlagValue, self).__getattr__(attr) + def __setattr__(self, attr, value): + if attr == "value" and not isinstance(value, (int, long)): + raise ValueError(value) + if attr in self.__slots__: + return super(FlagValue, self).__setattr__(attr, value) + if attr in self.names: + if value: + self.value |= (2 ** self.names.index(attr)) + else: + self.value &= ~(2 ** self.names.index(attr)) + else: + return super(FlagValue, self).__setattr__(attr, value) class FlagsField(BitField): diff --git a/test/regression.uts b/test/regression.uts index e08eabc362c61af4c5d862e01ca6dede23e71487..393df949434fcc052efd761c21e9f63aaa34fc82 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -7497,11 +7497,20 @@ pkt = IP(flags="MF") assert pkt.flags.MF assert not pkt.flags.DF assert not pkt.flags.evil +pkt.flags.MF = 0 +pkt.flags.DF = 1 +assert not pkt.flags.MF +assert pkt.flags.DF +assert not pkt.flags.evil pkt = IP(flags=3) assert pkt.flags.MF assert pkt.flags.DF assert not pkt.flags.evil +pkt.flags = 6 +assert not pkt.flags.MF +assert pkt.flags.DF +assert pkt.flags.evil = TCP flags ~ TCP @@ -7511,7 +7520,15 @@ assert pkt.flags == 18 assert pkt.flags.S assert pkt.flags.A assert not any(getattr(pkt.flags, f) for f in 'FRPUECN') +pkt.flags.U = True +pkt.flags.S = False +assert pkt.flags.A +assert pkt.flags.U +assert not any(getattr(pkt.flags, f) for f in 'FSRPECN') pkt = TCP(flags=56) assert all(getattr(pkt.flags, f) for f in 'PAU') assert not any(getattr(pkt.flags, f) for f in 'FSRECN') +pkt.flags = 50 +assert all(getattr(pkt.flags, f) for f in 'SAU') +assert not any(getattr(pkt.flags, f) for f in 'FRPECN')