Skip to content
Snippets Groups Projects
Commit 6aec0858 authored by Phil's avatar Phil
Browse files

Better Automaton exception hierarchy, and some cleanups

parent 135d5a3e
No related branches found
No related tags found
No related merge requests found
...@@ -155,6 +155,7 @@ class ATMT: ...@@ -155,6 +155,7 @@ class ATMT:
class _ATMT_Command: class _ATMT_Command:
RUN = "RUN" RUN = "RUN"
NEXT = "NEXT" NEXT = "NEXT"
FREEZE = "FREEZE"
STOP = "STOP" STOP = "STOP"
END = "END" END = "END"
EXCEPTION = "EXCEPTION" EXCEPTION = "EXCEPTION"
...@@ -327,19 +328,30 @@ class Automaton: ...@@ -327,19 +328,30 @@ class Automaton:
def write(self, msg): def write(self, msg):
return self.wr.send(msg) return self.wr.send(msg)
class ErrorState(Exception):
class AutomatonException(Exception):
pass
class ErrorState(AutomatonException):
def __init__(self, msg, result=None): def __init__(self, msg, result=None):
Exception.__init__(self, msg) Exception.__init__(self, msg)
self.result = result self.result = result
class Stuck(ErrorState): class Stuck(ErrorState):
pass pass
class Breakpoint(Exception): class AutomatonStopped(AutomatonException):
def __init__(self, msg, breakpoint): def __init__(self, msg, state=None):
Exception.__init__(self, msg) Exception.__init__(self, msg)
self.breakpoint = breakpoint self.state = state
class Breakpoint(AutomatonStopped):
pass
class InterceptionPoint(AutomatonStopped):
def __init__(self, msg, state, packet):
Automaton.AutomatonStopped.__init__(self, msg, state)
self.packet = packet
class CommandMessage(Exception): class CommandMessage(AutomatonException):
pass pass
...@@ -351,7 +363,7 @@ class Automaton: ...@@ -351,7 +363,7 @@ class Automaton:
def send(self, pkt): def send(self, pkt):
if self.state.state in self.interception_points: if self.state.state in self.interception_points:
self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary()) self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
cmd = Message(type = _ATMT_Command.INTERCEPT, pkt = pkt) cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
self.cmdout.send(cmd) self.cmdout.send(cmd)
cmd = self.cmdin.recv() cmd = self.cmdin.recv()
if cmd.type == _ATMT_Command.REJECT: if cmd.type == _ATMT_Command.REJECT:
...@@ -431,13 +443,18 @@ class Automaton: ...@@ -431,13 +443,18 @@ class Automaton:
else: else:
self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
def _do_start(self, *args, **kargs):
thread.start_new_thread(self._do_control, args, kargs)
def _do_control(self, *args, **kargs): def _do_control(self, *args, **kargs):
with self.running: with self.running:
self.threadid = thread.get_ident() self.threadid = thread.get_ident()
# Update default parameters # Update default parameters
a = args+self.init_args[len(args):] a = args+self.init_args[len(args):]
k = self.init_kargs k = self.init_kargs.copy()
k.update(kargs) k.update(kargs)
self.parse_args(*a,**k) self.parse_args(*a,**k)
...@@ -457,6 +474,8 @@ class Automaton: ...@@ -457,6 +474,8 @@ class Automaton:
singlestep = False singlestep = False
elif c.type == _ATMT_Command.NEXT: elif c.type == _ATMT_Command.NEXT:
singlestep = True singlestep = True
elif c.type == _ATMT_Command.FREEZE:
continue
elif c.type == _ATMT_Command.STOP: elif c.type == _ATMT_Command.STOP:
break break
while True: while True:
...@@ -486,7 +505,7 @@ class Automaton: ...@@ -486,7 +505,7 @@ class Automaton:
if self.state.state in self.breakpoints and self.state.state != self.breakpointed: if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
self.breakpointed = self.state.state self.breakpointed = self.state.state
raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state, raise self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
breakpoint = self.state.state) state = self.state.state)
self.breakpointed = None self.breakpointed = None
state_output = self.state.run() state_output = self.state.run()
if self.state.error: if self.state.error:
...@@ -567,8 +586,7 @@ class Automaton: ...@@ -567,8 +586,7 @@ class Automaton:
for ipt in ipts: for ipt in ipts:
if hasattr(ipt,"atmt_state"): if hasattr(ipt,"atmt_state"):
ipt = ipt.atmt_state ipt = ipt.atmt_state
if ipt in self.interception_points: self.interception_points.discard(ipt)
self.interception_points.remove(ipt)
def add_breakpoints(self, *bps): def add_breakpoints(self, *bps):
for bp in bps: for bp in bps:
...@@ -580,11 +598,11 @@ class Automaton: ...@@ -580,11 +598,11 @@ class Automaton:
for bp in bps: for bp in bps:
if hasattr(bp,"atmt_state"): if hasattr(bp,"atmt_state"):
bp = bp.atmt_state bp = bp.atmt_state
if bp in self.breakpoints: self.breakpoints.discard(pb)
self.breakpoints.remove(pb)
def start(self, *args, **kargs): def start(self, *args, **kargs):
thread.start_new_thread(self._do_control, args, kargs) if not self.running.locked():
self._do_start(*args, **kargs)
def run(self, resume=None, wait=True): def run(self, resume=None, wait=True):
if resume is None: if resume is None:
...@@ -594,14 +612,14 @@ class Automaton: ...@@ -594,14 +612,14 @@ class Automaton:
try: try:
c = self.cmdout.recv() c = self.cmdout.recv()
except KeyboardInterrupt: except KeyboardInterrupt:
self.cmdin.send(Message(type = _ATMT_Command.FREEZE))
return return
if c.type == _ATMT_Command.END: if c.type == _ATMT_Command.END:
return c.result return c.result
elif c.type == _ATMT_Command.INTERCEPT: elif c.type == _ATMT_Command.INTERCEPT:
print "Packet intercepted" raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)
return c.pkt
elif c.type == _ATMT_Command.SINGLESTEP: elif c.type == _ATMT_Command.SINGLESTEP:
return c.state raise self.Breakpoint("singlestep", state=c.state.state)
elif c.type == _ATMT_Command.EXCEPTION: elif c.type == _ATMT_Command.EXCEPTION:
raise c.exception raise c.exception
...@@ -612,24 +630,25 @@ class Automaton: ...@@ -612,24 +630,25 @@ class Automaton:
return self.run(resume = Message(type=_ATMT_Command.NEXT)) return self.run(resume = Message(type=_ATMT_Command.NEXT))
def stop(self): def stop(self):
self.cmdin.send(Message(type=_ATMT_Command.STOP)) if self.running.locked():
self.cmdin.send(Message(type=_ATMT_Command.STOP))
def restart(self): def restart(self, *args, **kargs):
self.stop() self.stop()
self.start() self.start(*args, **kargs)
def accept_packet(self, pkt=None): def accept_packet(self, pkt=None, wait=True):
rsm = Message() rsm = Message()
if pkt is None: if pkt is None:
rsm.type = _ATMT_Command.ACCEPT rsm.type = _ATMT_Command.ACCEPT
else: else:
rsm.type = _ATMT_Command.REPLACE rsm.type = _ATMT_Command.REPLACE
rsm.pkt = pkt rsm.pkt = pkt
return self.run(resume=rsm) return self.run(resume=rsm, wait=wait)
def reject_packet(self): def reject_packet(self, wait=True):
rsm = Message(type = _ATMT_Command.REJECT) rsm = Message(type = _ATMT_Command.REJECT)
return self.run(resume=rsm) return self.run(resume=rsm, wait=wait)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment