diff --git a/scapy.py b/scapy.py index 7379deef762089d52b2b66f0a1bf26c6fcb0dc37..3b243ddb7c7e3046b6669585d660d14809765373 100755 --- a/scapy.py +++ b/scapy.py @@ -11435,74 +11435,121 @@ class ATMT: return deco -class Automaton: - - def __init__(self, *args, **kargs): - self.init_states() - self.debug_level=0 - self.init_args=args - self.init_kargs=kargs - self.parse_args(*args, **kargs) +class Automaton_metaclass(type): + def __new__(cls, name, bases, dct): + cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct) + cls.states={} + cls.state = None + cls.recv_conditions={} + cls.conditions={} + cls.timeout={} + cls.actions={} + cls.initial_states=[] - def debug(self, lvl, msg): - if self.debug_level >= lvl: - log_interactive.debug(msg) - - def init_states(self): - self.states={} - self.state = None - self.recv_conditions={} - self.conditions={} - self.timeout={} - self.actions={} - self.initial_states=[] - decorated = dict((k,v) for (k,v) in self.get_members().iteritems() - if type(v) is types.FunctionType and hasattr(v, "atmt_type")) - - for m in decorated.itervalues(): + members = {} + classes = [cls] + while classes: + c = classes.pop(0) # order is important to avoid breaking method overloading + classes += list(c.__bases__) + for k,v in c.__dict__.iteritems(): + if k not in members: + members[k] = v + + decorated = [v for v in members.itervalues() + if type(v) is types.FunctionType and hasattr(v, "atmt_type")] + + for m in decorated: if m.atmt_type == ATMT.STATE: s = m.atmt_state - self.states[s] = m - self.recv_conditions[s]=[] - self.conditions[s]=[] - self.timeout[s]=[] + cls.states[s] = m + cls.recv_conditions[s]=[] + cls.conditions[s]=[] + cls.timeout[s]=[] if m.atmt_initial: - self.initial_states.append(m) + cls.initial_states.append(m) elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT]: - self.actions[m.atmt_condname] = [] + cls.actions[m.atmt_condname] = [] - for m in decorated.itervalues(): + for m in decorated: if m.atmt_type == ATMT.CONDITION: - self.conditions[m.atmt_state].append(m) + cls.conditions[m.atmt_state].append(m) elif m.atmt_type == ATMT.RECV: - self.recv_conditions[m.atmt_state].append(m) + cls.recv_conditions[m.atmt_state].append(m) elif m.atmt_type == ATMT.TIMEOUT: - self.timeout[m.atmt_state].append((m.atmt_timeout, m)) + cls.timeout[m.atmt_state].append((m.atmt_timeout, m)) elif m.atmt_type == ATMT.ACTION: for c in m.atmt_cond: - self.actions[c].append(m) + cls.actions[c].append(m) - for v in self.timeout.itervalues(): + for v in cls.timeout.itervalues(): v.sort(lambda (t1,f1),(t2,f2): cmp(t1,t2)) v.append((None, None)) - for v in itertools.chain(self.conditions.itervalues(), - self.recv_conditions.itervalues()): + for v in itertools.chain(cls.conditions.itervalues(), + cls.recv_conditions.itervalues()): v.sort(lambda c1,c2: cmp(c1.atmt_prio,c2.atmt_prio)) - for condname,actlst in self.actions.iteritems(): + for condname,actlst in cls.actions.iteritems(): actlst.sort(lambda c1,c2: cmp(c1.atmt_cond[condname], c2.atmt_cond[condname])) - def get_members(self): - members = {} - classes = [self.__class__] - while classes: - c = classes.pop(0) # order is important to avoid breaking method overloading - classes += list(c.__bases__) - for k,v in c.__dict__.iteritems(): - if k not in members: - members[k] = v - return members + return cls + + def graph(self, **kargs): + s = 'digraph "%s" {\n' % self.__class__.__name__ + + se = "" # Keep initial nodes at the begining for better rendering + for st in self.states.itervalues(): + if st.atmt_initial: + se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se + elif st.atmt_final: + se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state + elif st.atmt_error: + se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state + s += se + + for st in self.states.values(): + for n in st.atmt_origfunc.func_code.co_names+st.atmt_origfunc.func_code.co_consts: + if n in self.states: + s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n) + + + for c,k,v in [("purple",k,v) for k,v in self.conditions.items()]+[("red",k,v) for k,v in self.recv_conditions.items()]: + for f in v: + for n in f.func_code.co_names+f.func_code.co_consts: + if n in self.states: + l = f.atmt_condname + for x in self.actions[f.atmt_condname]: + l += "\\l>[%s]" % x.func_name + s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c) + for k,v in self.timeout.iteritems(): + for t,f in v: + if f is None: + continue + for n in f.func_code.co_names+f.func_code.co_consts: + if n in self.states: + l = "%s/%.1fs" % (f.atmt_condname,t) + for x in self.actions[f.atmt_condname]: + l += "\\l>[%s]" % x.func_name + s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l) + s += "}\n" + return do_graph(s, **kargs) + + + +class Automaton: + __metaclass__ = Automaton_metaclass + + def __init__(self, *args, **kargs): + self.debug_level=0 + self.init_args=args + self.init_kargs=kargs + self.parse_args(*args, **kargs) + + def debug(self, lvl, msg): + if self.debug_level >= lvl: + log_interactive.debug(msg) + + class ErrorState(Exception): @@ -11615,46 +11662,6 @@ class Automaton: self.packets.append(pkt.copy()) - def graph(self, **kargs): - s = 'digraph "%s" {\n' % self.__class__.__name__ - - se = "" # Keep initial nodes at the begining for better rendering - for st in self.states.itervalues(): - if st.atmt_initial: - se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se - elif st.atmt_final: - se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state - elif st.atmt_error: - se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state - s += se - - for st in self.states.values(): - for n in st.atmt_origfunc.func_code.co_names+st.atmt_origfunc.func_code.co_consts: - if n in self.states: - s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n) - - - for c,k,v in [("purple",k,v) for k,v in self.conditions.items()]+[("red",k,v) for k,v in self.recv_conditions.items()]: - for f in v: - for n in f.func_code.co_names+f.func_code.co_consts: - if n in self.states: - l = f.atmt_condname - for x in self.actions[f.atmt_condname]: - l += "\\l>[%s]" % x.func_name - s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c) - for k,v in self.timeout.iteritems(): - for t,f in v: - if f is None: - continue - for n in f.func_code.co_names+f.func_code.co_consts: - if n in self.states: - l = "%s/%.1fs" % (f.atmt_condname,t) - for x in self.actions[f.atmt_condname]: - l += "\\l>[%s]" % x.func_name - s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l) - s += "}\n" - return do_graph(s, **kargs) -