diff --git a/scapy/automaton.py b/scapy/automaton.py index ce1a2b0e4eae812fa391c1a0c14c8f8819fdd15c..7d0aae48ab6a6f943a07c5d524d762b615727c22 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -469,6 +469,7 @@ class Automaton: self.packets = PacketList(name="session[%s]"%self.__class__.__name__) singlestep = True + iterator = self._do_iter() self.debug(3, "Starting control thread [tid=%i]" % self.threadid) try: while True: @@ -483,9 +484,8 @@ class Automaton: elif c.type == _ATMT_Command.STOP: break while True: - try: - state = self._do_next() - except self.CommandMessage: + state = iterator.next() + if isinstance(state, self.CommandMessage): break if singlestep: c = Message(type=_ATMT_Command.SINGLESTEP,state=state) @@ -501,85 +501,85 @@ class Automaton: self.debug(3, "Stopping control thread (tid=%i)"%self.threadid) self.threadid = None - def _do_next(self): - try: - self.debug(1, "## state=[%s]" % self.state.state) - - # Entering a new state. First, call new state function - if self.state.state in self.breakpoints and self.state.state != self.breakpointed: - self.breakpointed = self.state.state - raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state, - state = self.state.state) - self.breakpointed = None - state_output = self.state.run() - if self.state.error: - raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output) - if self.state.final: - raise StopIteration(state_output) - - if state_output is None: - state_output = () - elif type(state_output) is not list: - state_output = state_output, - - # Then check immediate conditions - for cond in self.conditions[self.state.state]: - self._run_condition(cond, *state_output) - - # If still there and no conditions left, we are stuck! - if ( len(self.recv_conditions[self.state.state]) == 0 and - len(self.ioevents[self.state.state]) == 0 and - len(self.timeout[self.state.state]) == 1 ): - raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output) - - # Finally listen and pay attention to timeouts - expirations = iter(self.timeout[self.state.state]) - next_timeout,timeout_func = expirations.next() - t0 = time.time() - - fds = [self.cmdin] - if len(self.recv_conditions[self.state.state]) > 0: - fds.append(self.listen_sock) - for ioev in self.ioevents[self.state.state]: - fds.append(self.ioin[ioev.atmt_ioname]) - while 1: - t = time.time()-t0 - if next_timeout is not None: - if next_timeout <= t: - self._run_condition(timeout_func, *state_output) - next_timeout,timeout_func = expirations.next() - if next_timeout is None: - remain = None - else: - remain = next_timeout-t - - self.debug(5, "Select on %r" % fds) - r,_,_ = select(fds,[],[],remain) - self.debug(5, "Selected %r" % r) - for fd in r: - self.debug(5, "Looking at %r" % fd) - if fd == self.cmdin: - raise self.CommandMessage() - if fd == self.listen_sock: - pkt = self.listen_sock.recv(MTU) - if pkt is not None: - if self.master_filter(pkt): - self.debug(3, "RECVD: %s" % pkt.summary()) - for rcvcond in self.recv_conditions[self.state.state]: - self._run_condition(rcvcond, pkt, *state_output) - else: - self.debug(4, "FILTR: %s" % pkt.summary()) + def _do_iter(self): + while True: + try: + self.debug(1, "## state=[%s]" % self.state.state) + + # Entering a new state. First, call new state function + if self.state.state in self.breakpoints and self.state.state != self.breakpointed: + self.breakpointed = self.state.state + raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state, + state = self.state.state) + self.breakpointed = None + state_output = self.state.run() + if self.state.error: + raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output) + if self.state.final: + raise StopIteration(state_output) + + if state_output is None: + state_output = () + elif type(state_output) is not list: + state_output = state_output, + + # Then check immediate conditions + for cond in self.conditions[self.state.state]: + self._run_condition(cond, *state_output) + + # If still there and no conditions left, we are stuck! + if ( len(self.recv_conditions[self.state.state]) == 0 and + len(self.ioevents[self.state.state]) == 0 and + len(self.timeout[self.state.state]) == 1 ): + raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output) + + # Finally listen and pay attention to timeouts + expirations = iter(self.timeout[self.state.state]) + next_timeout,timeout_func = expirations.next() + t0 = time.time() + + fds = [self.cmdin] + if len(self.recv_conditions[self.state.state]) > 0: + fds.append(self.listen_sock) + for ioev in self.ioevents[self.state.state]: + fds.append(self.ioin[ioev.atmt_ioname]) + while 1: + t = time.time()-t0 + if next_timeout is not None: + if next_timeout <= t: + self._run_condition(timeout_func, *state_output) + next_timeout,timeout_func = expirations.next() + if next_timeout is None: + remain = None else: - self.debug(3, "IOEVENT on %s" % fd.ioname) - for ioevt in self.ioevents[self.state.state]: - if ioevt.atmt_ioname == fd.ioname: - self._run_condition(ioevt, fd, *state_output) - - except ATMT.NewStateRequested,state_req: - self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) - self.state = state_req - return state_req - + remain = next_timeout-t + + self.debug(5, "Select on %r" % fds) + r,_,_ = select(fds,[],[],remain) + self.debug(5, "Selected %r" % r) + for fd in r: + self.debug(5, "Looking at %r" % fd) + if fd == self.cmdin: + yield self.CommandMessage() + elif fd == self.listen_sock: + pkt = self.listen_sock.recv(MTU) + if pkt is not None: + if self.master_filter(pkt): + self.debug(3, "RECVD: %s" % pkt.summary()) + for rcvcond in self.recv_conditions[self.state.state]: + self._run_condition(rcvcond, pkt, *state_output) + else: + self.debug(4, "FILTR: %s" % pkt.summary()) + else: + self.debug(3, "IOEVENT on %s" % fd.ioname) + for ioevt in self.ioevents[self.state.state]: + if ioevt.atmt_ioname == fd.ioname: + self._run_condition(ioevt, fd, *state_output) + + except ATMT.NewStateRequested,state_req: + self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) + self.state = state_req + yield state_req ## Public API def add_interception_points(self, *ipts): @@ -636,9 +636,16 @@ class Automaton: return self.run(resume = Message(type=_ATMT_Command.NEXT)) def stop(self): - if self.running.locked(): - self.cmdin.send(Message(type=_ATMT_Command.STOP)) - + self.cmdin.send(Message(type=_ATMT_Command.STOP)) + with self.running: + # Flush command pipes + while True: + r,_,_ = select([self.cmdin, self.cmdout],[],[],0) + if not r: + break + for fd in r: + fd.recv() + def restart(self, *args, **kargs): self.stop() self.start(*args, **kargs)