diff --git a/scapy/contrib/bgp.py b/scapy/contrib/bgp.py index 1f8a9f2164d7309b7c47ab3fd7fc4ce9112473cf..87ea38f71e252d8da63cd914e92257cfe00ab293 100644 --- a/scapy/contrib/bgp.py +++ b/scapy/contrib/bgp.py @@ -607,29 +607,18 @@ class BGPCapability(six.with_metaclass(_BGPCapability_metaclass, Packet)): # Every BGP capability object inherits from BGPCapability. def haslayer(self, cls): - ret = 0 - if cls == BGPCapability: - # If cls is BGPCap (the parent class), check that the object is an - # instance of an existing BGP capability class. - for cap_class in _capabilities_registry: - if isinstance(self, _capabilities_registry[cap_class]): - ret = 1 - break - elif cls in _capabilities_registry and isinstance(self, cls): - ret = 1 - return ret - - def getlayer(self, cls, nb=1, _track=None, **flt): - if cls == BGPCapability: - for cap_class in _capabilities_registry: - if isinstance(self, _capabilities_registry[cap_class]) and \ - all(self.getfieldval(fldname) == fldvalue - for fldname, fldvalue in flt.iteritems()): - return self - else: - return super(BGPCapability, self).getlayer( - cls, nb=nb, _track=_track, **flt - ) + if cls == "BGPCapability": + if isinstance(self, BGPCapability): + return True + if issubclass(cls, BGPCapability): + if isinstance(self, cls): + return True + return super(BGPCapability, self).haslayer(cls) + + def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt): + return super(BGPCapability, self).getlayer( + cls, nb=nb, _track=_track, _subclass=True, **flt + ) def post_build(self, p, pay): length = 0 diff --git a/scapy/contrib/bgp.uts b/scapy/contrib/bgp.uts index 521abe07c42f71309d5fe2a4dcead1f7e3ddc905..55eb506ba5e442a44fc82032c2eb980a822f6c8e 100644 --- a/scapy/contrib/bgp.uts +++ b/scapy/contrib/bgp.uts @@ -119,10 +119,12 @@ except _BGPInvalidDataException: True = BGPCapability - Test haslayer() -BGPCapFourBytesASN().haslayer(BGPCapability) == True +assert BGPCapFourBytesASN().haslayer(BGPCapability) +assert BGPCapability in BGPCapFourBytesASN() = BGPCapability - Test getlayer() -isinstance(BGPCapFourBytesASN().getlayer(BGPCapability), BGPCapFourBytesASN) +assert isinstance(BGPCapFourBytesASN().getlayer(BGPCapability), BGPCapFourBytesASN) +assert isinstance(BGPCapFourBytesASN()[BGPCapability], BGPCapFourBytesASN) ############################ BGPCapMultiprotocol ############################## diff --git a/scapy/layers/eap.py b/scapy/layers/eap.py index 0213d2ffc903f5d8a7ddeb1e856a6cf0897d0622..f5b49764910528915b4f77a381470f78d517b003 100644 --- a/scapy/layers/eap.py +++ b/scapy/layers/eap.py @@ -238,25 +238,17 @@ class EAP(Packet): return cls def haslayer(self, cls): - ret = 0 - if cls == EAP: - for eap_class in EAP.registered_methods.values(): - if isinstance(self, eap_class): - ret = 1 - break - elif cls in list(EAP.registered_methods.values()) and isinstance(self, cls): - ret = 1 - return ret - - def getlayer(self, cls, nb=1, _track=None, **flt): - if cls == EAP: - for eap_class in EAP.registered_methods.values(): - if isinstance(self, eap_class) and \ - all(self.getfieldval(fldname) == fldvalue - for fldname, fldvalue in flt.iteritems()): - return self - else: - return super(EAP, self).getlayer(cls, nb=nb, _track=_track, **flt) + if cls == "EAP": + if isinstance(self, EAP): + return True + elif issubclass(cls, EAP): + if isinstance(self, cls): + return True + return super(EAP, self).haslayer(cls) + + def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt): + return super(EAP, self).getlayer(cls, nb=nb, _track=_track, + _subclass=True, **flt) def answers(self, other): if isinstance(other, EAP): diff --git a/scapy/layers/radius.py b/scapy/layers/radius.py index 7fa569d551b9e32a90a7f38eefbc3c2ff8012e1b..ada1cb4243857b5f0d5857e96ac2c3e134beac27 100644 --- a/scapy/layers/radius.py +++ b/scapy/layers/radius.py @@ -270,25 +270,17 @@ class RadiusAttribute(Packet): return cls def haslayer(self, cls): - if cls == RadiusAttribute: - for attr_class in RadiusAttribute.registered_attributes.values(): - if isinstance(self, attr_class): - return True - elif cls in RadiusAttribute.registered_attributes.values() and isinstance(self, cls): - return True - return False - - def getlayer(self, cls, nb=1, _track=None, **flt): - if cls == RadiusAttribute: - for attr_class in RadiusAttribute.registered_attributes.values(): - if isinstance(self, attr_class) and \ - all(self.getfieldval(fldname) == fldvalue - for fldname, fldvalue in flt.iteritems()): - return self - else: - return super(RadiusAttribute, self).getlayer( - cls, nb=nb, _track=_track, **flt - ) + if cls == "RadiusAttribute": + if isinstance(self, RadiusAttribute): + return True + elif issubclass(cls, RadiusAttribute): + if isinstance(self, cls): + return True + return super(RadiusAttribute, self).haslayer(cls) + + def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt): + return super(RadiusAttribute, self).getlayer(cls, nb=nb, _track=_track, + _subclass=True, **flt) def post_build(self, p, pay): length = self.len diff --git a/scapy/packet.py b/scapy/packet.py index 50fe67a44b069c839761bba0d8f8500d8311a03e..89184d1724565d819f243f21a0b67b8a89e2f585 100644 --- a/scapy/packet.py +++ b/scapy/packet.py @@ -889,11 +889,15 @@ class Packet(six.with_metaclass(Packet_metaclass, BasePacket)): return ret return self.payload.haslayer(cls) - def getlayer(self, cls, nb=1, _track=None, **flt): + def getlayer(self, cls, nb=1, _track=None, _subclass=False, **flt): """Return the nb^th layer that is an instance of cls, matching flt values. """ + if _subclass: + match = lambda cls1, cls2: issubclass(cls1, cls2) + else: + match = lambda cls1, cls2: cls1 == cls2 if isinstance(cls, int): nb = cls+1 cls = None @@ -901,7 +905,7 @@ values. ccls,fld = cls.split(".",1) else: ccls,fld = cls,None - if cls is None or self.__class__ == cls or self.__class__.__name__ == ccls: + if cls is None or match(self.__class__, cls) or self.__class__.__name__ == ccls: if all(self.getfieldval(fldname) == fldvalue for fldname, fldvalue in flt.iteritems()): if nb == 1: @@ -920,11 +924,13 @@ values. for fvalue in fvalue_gen: if isinstance(fvalue, Packet): track=[] - ret = fvalue.getlayer(cls, nb, _track=track) + ret = fvalue.getlayer(cls, nb=nb, _track=track, + _subclass=_subclass) if ret is not None: return ret nb = track[0] - return self.payload.getlayer(cls, nb=nb, _track=_track, **flt) + return self.payload.getlayer(cls, nb=nb, _track=_track, + _subclass=_subclass, **flt) def firstlayer(self): q = self