diff --git a/scapy/layers/ntp.py b/scapy/layers/ntp.py index 5a72a9de90efa910c5055eb2074673da2603eb9a..f8dbe16b45c95e9f6fe0c42152bfc2b25489b95e 100644 --- a/scapy/layers/ntp.py +++ b/scapy/layers/ntp.py @@ -95,14 +95,6 @@ class TimeStampField(FixedPointField): return FixedPointField.i2m(self, pkt, val) -def get_cls(name, fallback_cls=conf.raw_layer): - """ - Returns class named "name" if it exists, fallback_cls otherwise. - """ - - return globals().get(name, fallback_cls) - - ############################################################################# ##### NTP ############################################################################# @@ -172,42 +164,22 @@ _kiss_codes = { # Used by _ntp_dispatcher to instantiate the appropriate class -_ntp_cls_by_mode = { - 0: "NTPHeader", - 1: "NTPHeader", - 2: "NTPHeader", - 3: "NTPHeader", - 4: "NTPHeader", - 5: "NTPHeader", - 6: "NTPControl", - 7: "NTPPrivate" -} - - def _ntp_dispatcher(payload): """ Returns the right class for a given NTP packet. """ - - cls = conf.raw_layer - # By default, calling NTP() will build a NTP packet as defined in RFC 5905 # (see the code of NTPHeader). Use NTPHeader for extension fields and MAC. if payload is None: - cls = get_cls("NTPHeader") - + return NTPHeader else: length = len(payload) if length >= _NTP_PACKET_MIN_SIZE: first_byte = orb(payload[0]) - # Extract NTP mode - mode_mask = 0x07 - mode = first_byte & mode_mask - - cls = get_cls(_ntp_cls_by_mode.get(mode)) - - return cls + mode = first_byte & 7 + return {6: NTPControl, 7: NTPPrivate}.get(mode, NTPHeader) + return conf.raw_layer class NTP(Packet): @@ -238,37 +210,18 @@ class NTP(Packet): # NTPHeader, NTPControl and NTPPrivate are NTP packets. # This might help, for example when reading a pcap file. def haslayer(self, cls): - ntp_classes = [ - get_cls("NTPHeader"), - get_cls("NTPControl"), - get_cls("NTPPrivate") - ] - ret = 0 - if cls == NTP: - # If cls is NTP (the parent class), check that the object is an - # instance of a NTP packet - for ntp_class in ntp_classes: - if isinstance(self, ntp_class): - ret = 1 - break - elif cls in ntp_classes and isinstance(self, cls): - ret = 1 - return ret - - def getlayer(self, cls, nb=1, _track=None, **flt): - ntp_classes = [ - get_cls("NTPHeader"), - get_cls("NTPControl"), - get_cls("NTPPrivate") - ] - if cls == NTP: - for ntp_class in ntp_classes: - if isinstance(self, ntp_class) and \ - all(self.getfieldval(fldname) == fldvalue - for fldname, fldvalue in flt.iteritems()): - return self - else: - return super(NTP, self).getlayer(cls, nb=nb, _track=_track, **flt) + """Specific: NTPHeader().haslayer(NTP) should return True.""" + if cls == "NTP": + if isinstance(self, NTP): + return True + elif issubclass(cls, NTP): + if isinstance(self, cls): + return True + return super(NTP, self).haslayer(cls) + + def getlayer(self, cls, nb=1, _track=None, _subclass=True, **flt): + return super(NTP, self).getlayer(cls, nb=nb, _track=_track, + _subclass=True, **flt) def mysummary(self): return self.sprintf("NTP v%ir,NTP.version%, %NTP.mode%")