Skip to content
Snippets Groups Projects
Commit 7cf15249 authored by Phil's avatar Phil
Browse files

Rewrote Automaton._do_next() as a coroutine to fix bugs with command interactions

parent a0e92141
No related branches found
No related tags found
No related merge requests found
......@@ -469,6 +469,7 @@ class Automaton:
self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
singlestep = True
iterator = self._do_iter()
self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
try:
while True:
......@@ -483,9 +484,8 @@ class Automaton:
elif c.type == _ATMT_Command.STOP:
break
while True:
try:
state = self._do_next()
except self.CommandMessage:
state = iterator.next()
if isinstance(state, self.CommandMessage):
break
if singlestep:
c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
......@@ -501,85 +501,85 @@ class Automaton:
self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
self.threadid = None
def _do_next(self):
try:
self.debug(1, "## state=[%s]" % self.state.state)
# Entering a new state. First, call new state function
if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
self.breakpointed = self.state.state
raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
state = self.state.state)
self.breakpointed = None
state_output = self.state.run()
if self.state.error:
raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output)
if self.state.final:
raise StopIteration(state_output)
if state_output is None:
state_output = ()
elif type(state_output) is not list:
state_output = state_output,
# Then check immediate conditions
for cond in self.conditions[self.state.state]:
self._run_condition(cond, *state_output)
# If still there and no conditions left, we are stuck!
if ( len(self.recv_conditions[self.state.state]) == 0 and
len(self.ioevents[self.state.state]) == 0 and
len(self.timeout[self.state.state]) == 1 ):
raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
# Finally listen and pay attention to timeouts
expirations = iter(self.timeout[self.state.state])
next_timeout,timeout_func = expirations.next()
t0 = time.time()
fds = [self.cmdin]
if len(self.recv_conditions[self.state.state]) > 0:
fds.append(self.listen_sock)
for ioev in self.ioevents[self.state.state]:
fds.append(self.ioin[ioev.atmt_ioname])
while 1:
t = time.time()-t0
if next_timeout is not None:
if next_timeout <= t:
self._run_condition(timeout_func, *state_output)
next_timeout,timeout_func = expirations.next()
if next_timeout is None:
remain = None
else:
remain = next_timeout-t
self.debug(5, "Select on %r" % fds)
r,_,_ = select(fds,[],[],remain)
self.debug(5, "Selected %r" % r)
for fd in r:
self.debug(5, "Looking at %r" % fd)
if fd == self.cmdin:
raise self.CommandMessage()
if fd == self.listen_sock:
pkt = self.listen_sock.recv(MTU)
if pkt is not None:
if self.master_filter(pkt):
self.debug(3, "RECVD: %s" % pkt.summary())
for rcvcond in self.recv_conditions[self.state.state]:
self._run_condition(rcvcond, pkt, *state_output)
else:
self.debug(4, "FILTR: %s" % pkt.summary())
def _do_iter(self):
while True:
try:
self.debug(1, "## state=[%s]" % self.state.state)
# Entering a new state. First, call new state function
if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
self.breakpointed = self.state.state
raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
state = self.state.state)
self.breakpointed = None
state_output = self.state.run()
if self.state.error:
raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), result=state_output)
if self.state.final:
raise StopIteration(state_output)
if state_output is None:
state_output = ()
elif type(state_output) is not list:
state_output = state_output,
# Then check immediate conditions
for cond in self.conditions[self.state.state]:
self._run_condition(cond, *state_output)
# If still there and no conditions left, we are stuck!
if ( len(self.recv_conditions[self.state.state]) == 0 and
len(self.ioevents[self.state.state]) == 0 and
len(self.timeout[self.state.state]) == 1 ):
raise self.Stuck("stuck in [%s]" % self.state.state,result=state_output)
# Finally listen and pay attention to timeouts
expirations = iter(self.timeout[self.state.state])
next_timeout,timeout_func = expirations.next()
t0 = time.time()
fds = [self.cmdin]
if len(self.recv_conditions[self.state.state]) > 0:
fds.append(self.listen_sock)
for ioev in self.ioevents[self.state.state]:
fds.append(self.ioin[ioev.atmt_ioname])
while 1:
t = time.time()-t0
if next_timeout is not None:
if next_timeout <= t:
self._run_condition(timeout_func, *state_output)
next_timeout,timeout_func = expirations.next()
if next_timeout is None:
remain = None
else:
self.debug(3, "IOEVENT on %s" % fd.ioname)
for ioevt in self.ioevents[self.state.state]:
if ioevt.atmt_ioname == fd.ioname:
self._run_condition(ioevt, fd, *state_output)
except ATMT.NewStateRequested,state_req:
self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
self.state = state_req
return state_req
remain = next_timeout-t
self.debug(5, "Select on %r" % fds)
r,_,_ = select(fds,[],[],remain)
self.debug(5, "Selected %r" % r)
for fd in r:
self.debug(5, "Looking at %r" % fd)
if fd == self.cmdin:
yield self.CommandMessage()
elif fd == self.listen_sock:
pkt = self.listen_sock.recv(MTU)
if pkt is not None:
if self.master_filter(pkt):
self.debug(3, "RECVD: %s" % pkt.summary())
for rcvcond in self.recv_conditions[self.state.state]:
self._run_condition(rcvcond, pkt, *state_output)
else:
self.debug(4, "FILTR: %s" % pkt.summary())
else:
self.debug(3, "IOEVENT on %s" % fd.ioname)
for ioevt in self.ioevents[self.state.state]:
if ioevt.atmt_ioname == fd.ioname:
self._run_condition(ioevt, fd, *state_output)
except ATMT.NewStateRequested,state_req:
self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
self.state = state_req
yield state_req
## Public API
def add_interception_points(self, *ipts):
......@@ -636,9 +636,16 @@ class Automaton:
return self.run(resume = Message(type=_ATMT_Command.NEXT))
def stop(self):
if self.running.locked():
self.cmdin.send(Message(type=_ATMT_Command.STOP))
self.cmdin.send(Message(type=_ATMT_Command.STOP))
with self.running:
# Flush command pipes
while True:
r,_,_ = select([self.cmdin, self.cmdout],[],[],0)
if not r:
break
for fd in r:
fd.recv()
def restart(self, *args, **kargs):
self.stop()
self.start(*args, **kargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment