diff --git a/scapy/automaton.py b/scapy/automaton.py index b5be2b1430919ee2c71dd7714e124cc31ba530e6..95f638fe5cb013cdaa37ebc38c435972ebfd1b55 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -204,6 +204,7 @@ class Automaton: __metaclass__ = Automaton_metaclass def __init__(self, *args, **kargs): + self.running = False self.debug_level=0 self.init_args=args self.init_kargs=kargs @@ -247,7 +248,10 @@ class Automaton: self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) - def run(self, *args, **kargs): + + def start(self, *args, **kargs): + self.running = True + # Update default parameters a = args+self.init_args[len(args):] k = self.init_kargs @@ -257,66 +261,91 @@ class Automaton: # Start the automaton self.state=self.initial_states[0](self) self.send_sock = conf.L3socket() - l = conf.L2listen(**self.socket_kargs) + self.listen_sock = conf.L2listen(**self.socket_kargs) self.packets = PacketList(name="session[%s]"%self.__class__.__name__) + + def next(self): + if not self.running: + self.start() + try: + self.debug(1, "## state=[%s]" % self.state.state) + + # Entering a new state. First, call new state function + state_output = self.state.run() + if self.state.error: + self.running = False + raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output) + if self.state.final: + self.running = False + raise StopIteration(state_output) + + if state_output is None: + state_output = () + elif type(state_output) is not list: + state_output = state_output, + + # Then check immediate conditions + for cond in self.conditions[self.state.state]: + self.run_condition(cond, *state_output) + + # If still there and no conditions left, we are stuck! + if ( len(self.recv_conditions[self.state.state]) == 0 + and len(self.timeout[self.state.state]) == 1 ): + raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output) + + # Finally listen and pay attention to timeouts + expirations = iter(self.timeout[self.state.state]) + next_timeout,timeout_func = expirations.next() + t0 = time.time() + + while 1: + t = time.time()-t0 + if next_timeout is not None: + if next_timeout <= t: + self.run_condition(timeout_func, *state_output) + next_timeout,timeout_func = expirations.next() + if next_timeout is None: + remain = None + else: + remain = next_timeout-t + + r,_,_ = select([self.listen_sock],[],[],remain) + if self.listen_sock in r: + pkt = self.listen_sock.recv(MTU) + if pkt is not None: + if self.master_filter(pkt): + self.debug(3, "RECVD: %s" % pkt.summary()) + for rcvcond in self.recv_conditions[self.state.state]: + self.run_condition(rcvcond, pkt, *state_output) + else: + self.debug(4, "FILTR: %s" % pkt.summary()) + + except ATMT.NewStateRequested,state_req: + self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) + self.state = state_req + return state_req + + + + def run(self, *args, **kargs): + if not self.running: + self.start(*args, **kargs) + while 1: try: - self.debug(1, "## state=[%s]" % self.state.state) - - # Entering a new state. First, call new state function - state_output = self.state.run() - if self.state.error: - raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output) - if self.state.final: - return state_output - - if state_output is None: - state_output = () - elif type(state_output) is not list: - state_output = state_output, - - # Then check immediate conditions - for cond in self.conditions[self.state.state]: - self.run_condition(cond, *state_output) - - # If still there and no conditions left, we are stuck! - if ( len(self.recv_conditions[self.state.state]) == 0 - and len(self.timeout[self.state.state]) == 1 ): - raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output) - - # Finally listen and pay attention to timeouts - expirations = iter(self.timeout[self.state.state]) - next_timeout,timeout_func = expirations.next() - t0 = time.time() - - while 1: - t = time.time()-t0 - if next_timeout is not None: - if next_timeout <= t: - self.run_condition(timeout_func, *state_output) - next_timeout,timeout_func = expirations.next() - if next_timeout is None: - remain = None - else: - remain = next_timeout-t - - r,_,_ = select([l],[],[],remain) - if l in r: - pkt = l.recv(MTU) - if pkt is not None: - if self.master_filter(pkt): - self.debug(3, "RECVD: %s" % pkt.summary()) - for rcvcond in self.recv_conditions[self.state.state]: - self.run_condition(rcvcond, pkt, *state_output) - else: - self.debug(4, "FILTR: %s" % pkt.summary()) - - except ATMT.NewStateRequested,state_req: - self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) - self.state = state_req + self.next() except KeyboardInterrupt: self.debug(1,"Interrupted by user") break + except StopIteration,e: + return e.args[0] + + cont = run + + def __iter__(self): + if not self.running: + self.start() + return self def my_send(self, pkt): self.send_sock.send(pkt) @@ -326,5 +355,3 @@ class Automaton: self.debug(3,"SENT : %s" % pkt.summary()) self.packets.append(pkt.copy()) - -