diff --git a/scapy/automaton.py b/scapy/automaton.py
index ce1a2b0e4eae812fa391c1a0c14c8f8819fdd15c..7d0aae48ab6a6f943a07c5d524d762b615727c22 100644
--- a/scapy/automaton.py
+++ b/scapy/automaton.py
@@ -469,6 +469,7 @@ class Automaton:
             self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
 
             singlestep = True
+            iterator = self._do_iter()
             self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
             try:
                 while True:
@@ -483,9 +484,8 @@ class Automaton:
                     elif c.type == _ATMT_Command.STOP:
                         break
                     while True:
-                        try:
-                            state = self._do_next()
-                        except self.CommandMessage:
+                        state = iterator.next()
+                        if isinstance(state, self.CommandMessage):
                             break
                         if singlestep:
                             c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
@@ -501,85 +501,85 @@ class Automaton:
             self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
             self.threadid = None
     
-    def _do_next(self):
-        try:
-            self.debug(1, "## state=[%s]" % self.state.state)
-
-            # Entering a new state. First, call new state function
-            if self.state.state in self.breakpoints and self.state.state != self.breakpointed: 
-                self.breakpointed = self.state.state
-                raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
-                                      state = self.state.state)
-            self.breakpointed = None
-            state_output = self.state.run()
-            if self.state.error:
-                raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output)
-            if self.state.final:
-                raise StopIteration(state_output)
-
-            if state_output is None:
-                state_output = ()
-            elif type(state_output) is not list:
-                state_output = state_output,
-            
-            # Then check immediate conditions
-            for cond in self.conditions[self.state.state]:
-                self._run_condition(cond, *state_output)
-
-            # If still there and no conditions left, we are stuck!
-            if ( len(self.recv_conditions[self.state.state]) == 0 and
-                 len(self.ioevents[self.state.state]) == 0 and
-                 len(self.timeout[self.state.state]) == 1 ):
-                raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
-
-            # Finally listen and pay attention to timeouts
-            expirations = iter(self.timeout[self.state.state])
-            next_timeout,timeout_func = expirations.next()
-            t0 = time.time()
-            
-            fds = [self.cmdin]
-            if len(self.recv_conditions[self.state.state]) > 0:
-                fds.append(self.listen_sock)
-            for ioev in self.ioevents[self.state.state]:
-                fds.append(self.ioin[ioev.atmt_ioname])
-            while 1:
-                t = time.time()-t0
-                if next_timeout is not None:
-                    if next_timeout <= t:
-                        self._run_condition(timeout_func, *state_output)
-                        next_timeout,timeout_func = expirations.next()
-                if next_timeout is None:
-                    remain = None
-                else:
-                    remain = next_timeout-t
-
-                self.debug(5, "Select on %r" % fds)
-                r,_,_ = select(fds,[],[],remain)
-                self.debug(5, "Selected %r" % r)
-                for fd in r:
-                    self.debug(5, "Looking at %r" % fd)
-                    if fd == self.cmdin:
-                        raise self.CommandMessage()
-                    if fd == self.listen_sock:
-                        pkt = self.listen_sock.recv(MTU)
-                        if pkt is not None:
-                            if self.master_filter(pkt):
-                                self.debug(3, "RECVD: %s" % pkt.summary())
-                                for rcvcond in self.recv_conditions[self.state.state]:
-                                    self._run_condition(rcvcond, pkt, *state_output)
-                            else:
-                                self.debug(4, "FILTR: %s" % pkt.summary())
+    def _do_iter(self):
+        while True:
+            try:
+                self.debug(1, "## state=[%s]" % self.state.state)
+    
+                # Entering a new state. First, call new state function
+                if self.state.state in self.breakpoints and self.state.state != self.breakpointed: 
+                    self.breakpointed = self.state.state
+                    raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
+                                          state = self.state.state)
+                self.breakpointed = None
+                state_output = self.state.run()
+                if self.state.error:
+                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output)
+                if self.state.final:
+                    raise StopIteration(state_output)
+    
+                if state_output is None:
+                    state_output = ()
+                elif type(state_output) is not list:
+                    state_output = state_output,
+                
+                # Then check immediate conditions
+                for cond in self.conditions[self.state.state]:
+                    self._run_condition(cond, *state_output)
+    
+                # If still there and no conditions left, we are stuck!
+                if ( len(self.recv_conditions[self.state.state]) == 0 and
+                     len(self.ioevents[self.state.state]) == 0 and
+                     len(self.timeout[self.state.state]) == 1 ):
+                    raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
+    
+                # Finally listen and pay attention to timeouts
+                expirations = iter(self.timeout[self.state.state])
+                next_timeout,timeout_func = expirations.next()
+                t0 = time.time()
+                
+                fds = [self.cmdin]
+                if len(self.recv_conditions[self.state.state]) > 0:
+                    fds.append(self.listen_sock)
+                for ioev in self.ioevents[self.state.state]:
+                    fds.append(self.ioin[ioev.atmt_ioname])
+                while 1:
+                    t = time.time()-t0
+                    if next_timeout is not None:
+                        if next_timeout <= t:
+                            self._run_condition(timeout_func, *state_output)
+                            next_timeout,timeout_func = expirations.next()
+                    if next_timeout is None:
+                        remain = None
                     else:
-                        self.debug(3, "IOEVENT on %s" % fd.ioname)
-                        for ioevt in self.ioevents[self.state.state]:
-                            if ioevt.atmt_ioname == fd.ioname:
-                                self._run_condition(ioevt, fd, *state_output)
-
-        except ATMT.NewStateRequested,state_req:
-            self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
-            self.state = state_req
-            return state_req
-        
+                        remain = next_timeout-t
+    
+                    self.debug(5, "Select on %r" % fds)
+                    r,_,_ = select(fds,[],[],remain)
+                    self.debug(5, "Selected %r" % r)
+                    for fd in r:
+                        self.debug(5, "Looking at %r" % fd)
+                        if fd == self.cmdin:
+                            yield self.CommandMessage()
+                        elif fd == self.listen_sock:
+                            pkt = self.listen_sock.recv(MTU)
+                            if pkt is not None:
+                                if self.master_filter(pkt):
+                                    self.debug(3, "RECVD: %s" % pkt.summary())
+                                    for rcvcond in self.recv_conditions[self.state.state]:
+                                        self._run_condition(rcvcond, pkt, *state_output)
+                                else:
+                                    self.debug(4, "FILTR: %s" % pkt.summary())
+                        else:
+                            self.debug(3, "IOEVENT on %s" % fd.ioname)
+                            for ioevt in self.ioevents[self.state.state]:
+                                if ioevt.atmt_ioname == fd.ioname:
+                                    self._run_condition(ioevt, fd, *state_output)
+    
+            except ATMT.NewStateRequested,state_req:
+                self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
+                self.state = state_req
+                yield state_req
 
     ## Public API
     def add_interception_points(self, *ipts):
@@ -636,9 +636,16 @@ class Automaton:
         return self.run(resume = Message(type=_ATMT_Command.NEXT))
 
     def stop(self):
-        if self.running.locked():
-            self.cmdin.send(Message(type=_ATMT_Command.STOP))
-
+        self.cmdin.send(Message(type=_ATMT_Command.STOP))
+        with self.running:
+            # Flush command pipes
+            while True:
+                r,_,_ = select([self.cmdin, self.cmdout],[],[],0)
+                if not r:
+                    break
+                for fd in r:
+                    fd.recv()
+                
     def restart(self, *args, **kargs):
         self.stop()
         self.start(*args, **kargs)