From 6e173e9f6e570d6c3c8c76040b44e270c47308db Mon Sep 17 00:00:00 2001
From: gpotter2 <gabriel@potter.fr>
Date: Fri, 29 Sep 2017 14:55:35 +0200
Subject: [PATCH] [coverage] Add more scapypipes tests + tests fixes (#808)

* Add more scapypipes tests
* Daemonize threads, better UTscapy closing
* Fix TCPConnectPipe
---
 scapy/automaton.py         |  38 +++--
 scapy/pipetool.py          |   3 +-
 scapy/scapypipes.py        |  22 ++-
 scapy/tools/UTscapy.py     |   2 +-
 test/configs/windows.utsc  |   3 +-
 test/configs/windows2.utsc |   3 +-
 test/pipetool.uts          | 288 +++++++++++++++++++++++++++++++++++--
 7 files changed, 323 insertions(+), 36 deletions(-)

diff --git a/scapy/automaton.py b/scapy/automaton.py
index 49eaec06..b06d05c8 100644
--- a/scapy/automaton.py
+++ b/scapy/automaton.py
@@ -74,16 +74,20 @@ else:
 class SelectableObject:
     """DEV: to implement one of those, you need to add 2 things to your object:
     - add "check_recv" function
-    - call "self.call_release" once you are ready to be read"""
-    trigger = threading.Lock()
-    was_ended = False
+    - call "self.call_release" once you are ready to be read
+
+    You can set the __selectable_force_select__ to True in the class, if you want to
+    force the handler to use fileno(). This may only be useable on sockets created using
+    the builtin socket API."""
+    __selectable_force_select__ = False
     def check_recv(self):
         """DEV: will be called only once (at beginning) to check if the object is ready."""
         raise OSError("This method must be overwriten.")
 
     def _wait_non_ressources(self, callback):
         """This get started as a thread, and waits for the data lock to be freed then advertise itself to the SelectableSelector using the callback"""
-        self.call_release()
+        self.trigger = threading.Lock()
+        self.was_ended = False
         self.trigger.acquire()
         self.trigger.acquire()
         if not self.was_ended:
@@ -93,7 +97,9 @@ class SelectableObject:
         """Entry point of SelectableObject: register the callback"""
         if self.check_recv():
             return callback(self)
-        threading.Thread(target=self._wait_non_ressources, args=(callback,)).start()
+        _t = threading.Thread(target=self._wait_non_ressources, args=(callback,))
+        _t.setDaemon(True)
+        _t.start()
         
     def call_release(self, arborted=False):
         """DEV: Must be call when the object becomes ready to read.
@@ -101,7 +107,7 @@ class SelectableObject:
         self.was_ended = arborted
         try:
             self.trigger.release()
-        except THREAD_EXCEPTION as e:
+        except (THREAD_EXCEPTION, AttributeError):
             pass
 
 class SelectableSelector(object):
@@ -112,10 +118,6 @@ class SelectableSelector(object):
     remain: timeout. If 0, return [].
     customTypes: types of the objects that have the check_recv function.
     """
-    results = None
-    inputs = None
-    available_lock = None
-    _ended = False
     def _release_all(self):
         """Releases all locks to kill all threads"""
         for i in self.inputs:
@@ -129,7 +131,7 @@ class SelectableSelector(object):
             self._ended = True
             self._release_all()
 
-    def _exit_door(self,_input):
+    def _exit_door(self, _input):
         """This function is passed to each SelectableObject as a callback
         The SelectableObjects have to call it once there are ready"""
         self.results.append(_input)
@@ -149,13 +151,20 @@ class SelectableSelector(object):
     def process(self):
         """Entry point of SelectableSelector"""
         if WINDOWS:
+            select_inputs = []
             for i in self.inputs:
                 if not isinstance(i, SelectableObject):
-                    warning("Unknown ignored object type: " + type(i))
+                    warning("Unknown ignored object type: %s", type(i))
+                elif i.__selectable_force_select__:
+                    # Then use select.select
+                    select_inputs.append(i)
                 elif not self.remain and i.check_recv():
                     self.results.append(i)
                 else:
                     i.wait_return(self._exit_door)
+            if select_inputs:
+                # Use default select function
+                self.results.extend(select(select_inputs, [], [], self.remain)[0])
             if not self.remain:
                 return self.results
 
@@ -175,7 +184,6 @@ def select_objects(inputs, remain):
     
     inputs: objects to process
     remain: timeout. If 0, return [].
-    customTypes: types of the objects that have the check_recv function.
     """
     handler = SelectableSelector(inputs, remain)
     return handler.process()
@@ -694,7 +702,9 @@ class Automaton(six.with_metaclass(Automaton_metaclass)):
 
     def _do_start(self, *args, **kargs):
         ready = threading.Event()
-        threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs).start()
+        _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs)
+        _t.setDaemon(True)
+        _t.start()
         ready.wait()
 
     def _do_control(self, ready, *args, **kargs):
diff --git a/scapy/pipetool.py b/scapy/pipetool.py
index 889a2f8a..a44c57a4 100644
--- a/scapy/pipetool.py
+++ b/scapy/pipetool.py
@@ -156,6 +156,7 @@ class PipeEngine(SelectableObject):
     def start(self):
         if self.thread_lock.acquire(0):
             _t = Thread(target=self.run)
+            _t.setDaemon(True)
             _t.start()
             self.thread = _t
         else:
@@ -375,6 +376,7 @@ class AutoSource(Source, SelectableObject):
         self._wake_up()
     def _wake_up(self):
         os.write(self.__fdw,"X")
+        self.call_release()
     def deliver(self):
         os.read(self.__fdr,1)
         try:
@@ -382,7 +384,6 @@ class AutoSource(Source, SelectableObject):
         except IndexError: #empty queue. Exhausted source
             pass
         else:
-            self.call_release()
             if high:
                 self._high_send(msg)
             else:
diff --git a/scapy/scapypipes.py b/scapy/scapypipes.py
index 231d10c2..07b84c59 100644
--- a/scapy/scapypipes.py
+++ b/scapy/scapypipes.py
@@ -165,6 +165,7 @@ class TCPConnectPipe(Source):
    >-|-[addr:port]-|->
      +-------------+
 """
+    __selectable_force_select__ = True
     def __init__(self, addr="", port=0, name=None):
         Source.__init__(self, name=name)
         self.addr = addr
@@ -181,7 +182,13 @@ class TCPConnectPipe(Source):
     def fileno(self):
         return self.fd.fileno()
     def deliver(self):
-        self._send(self.fd.recv(65536))
+        try:
+            msg = self.fd.recv(65536)
+        except socket.error:
+            self.stop()
+            raise
+        if msg:
+            self._send(msg)
 
 class TCPListenPipe(TCPConnectPipe):
     """TCP listen on [addr:]port and use first connection as source and sink ; send peer address to high output
@@ -191,6 +198,7 @@ class TCPListenPipe(TCPConnectPipe):
    >-|-[addr:port]-|->
      +-------------+
 """
+    __selectable_force_select__ = True
     def __init__(self, addr="", port=0, name=None):
         TCPConnectPipe.__init__(self, addr, port, name)
         self.connected = False
@@ -208,7 +216,13 @@ class TCPListenPipe(TCPConnectPipe):
             self.q.put(msg)
     def deliver(self):
         if self.connected:
-            self._send(self.fd.recv(65536))
+            try:
+                msg = self.fd.recv(65536)
+            except socket.error:
+                self.stop()
+                raise
+            if msg:
+                self._send(msg)
         else:
             fd,frm = self.fd.accept()
             self._high_send(frm)
@@ -277,7 +291,7 @@ class TriggeredValve(Drain):
             self._send(msg)
     def high_push(self, msg):
         if self.opened:
-            self._send(msg)
+            self._high_send(msg)
     def on_trigger(self, msg):
         self.opened ^= True
         self._trigger(msg)
@@ -305,7 +319,7 @@ class TriggeredQueueingValve(Drain):
         if self.opened:
             self._send(msg)
         else:
-            self.hq.put((False,msg))
+            self.q.put((False,msg))
     def on_trigger(self, msg):
         self.opened ^= True
         self._trigger(msg)
diff --git a/scapy/tools/UTscapy.py b/scapy/tools/UTscapy.py
index 660eb961..6460c04a 100755
--- a/scapy/tools/UTscapy.py
+++ b/scapy/tools/UTscapy.py
@@ -851,4 +851,4 @@ def main(argv):
     return glob_result
 
 if __name__ == "__main__":
-    exit(main(sys.argv[1:]))
+    sys.exit(main(sys.argv[1:]))
diff --git a/test/configs/windows.utsc b/test/configs/windows.utsc
index fb854b06..daa30fac 100644
--- a/test/configs/windows.utsc
+++ b/test/configs/windows.utsc
@@ -14,6 +14,7 @@
   "kw_ko": [
     "crypto_advanced",
     "ipv6",
-    "osx"
+    "osx",
+    "linux"
   ]
 }
diff --git a/test/configs/windows2.utsc b/test/configs/windows2.utsc
index 0f708f8f..d8c8c0e1 100644
--- a/test/configs/windows2.utsc
+++ b/test/configs/windows2.utsc
@@ -14,6 +14,7 @@
   "kw_ko": [
     "crypto_advanced",
     "mock_read_routes6_bsd",
-    "appveyor_only"
+    "appveyor_only",
+    "linux"
   ]
 }
diff --git a/test/pipetool.uts b/test/pipetool.uts
index df106625..bc7644a7 100644
--- a/test/pipetool.uts
+++ b/test/pipetool.uts
@@ -10,7 +10,7 @@ s = PeriodicSource("hello", 1, name="src")
 d1 = Drain(name="d1")
 c = ConsoleSink(name="c")
 tf = TransformDrain(lambda x:"Got %r" % x)
-t = TermSink(name="t", keepterm=False)
+t = TermSink(name="PipeToolsPeriodicTest", keepterm=False)
 s > d1 > c
 d1 > tf > t
 
@@ -18,6 +18,7 @@ p = PipeEngine(s)
 p.graph(type="png",target="> /tmp/pipe.png")
 p.start()
 time.sleep(3)
+s.msg = []
 p.stop()
 
 = Test add_pipe
@@ -46,28 +47,24 @@ p.wait_and_stop()
 
 = Test add_pipe on running instance
 
-test_val = None
-
-class TestSink(Sink):
-    def push(self, msg):
-        global test_val
-        test_val = msg
-
 p = PipeEngine()
 p.start()
 
-s = AutoSource()
-s._gen_data("hello")
-s.is_exhausted = True
+s = CLIFeeder()
 
 d1 = Drain(name="d1")
-c = TestSink(name="c")
+c = QueueSink(name="c")
 s > d1 > c
 
 p.add(s)
 
-p.wait_and_stop()
-assert test_val == "hello"
+s.send("hello")
+s.send("hi")
+
+assert c.q.get(timeout=5) == "hello"
+assert c.q.get(timeout=5) == "hi"
+
+p.stop()
 
 = Test Operators
 
@@ -295,3 +292,266 @@ _inject_sink(False) # InjectSink
 _inject_sink(True) # Inject3Sink
 
 assert msgs == [a,a]
+
+= TriggerDrain and TriggeredValve with CLIFeeder
+
+s = CLIFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredValve()
+c = QueueSink()
+
+s > d1 > d2 > c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("hello")
+s.send("trigger")
+s.send("hello2")
+s.send("trigger")
+s.send("hello3")
+
+assert c.q.get(timeout=5) == "hello"
+assert c.q.get(timeout=5) == "trigger"
+assert c.q.get(timeout=5) == "hello3"
+
+p.stop()
+
+= TriggerDrain and TriggeredValve with CLIHighFeeder
+
+s = CLIHighFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredValve()
+c = QueueSink()
+
+s >> d1
+d1 >> d2
+d2 >> c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("hello")
+s.send("trigger")
+s.send("hello2")
+s.send("trigger")
+s.send("hello3")
+
+assert c.q.get(timeout=5) == "hello"
+assert c.q.get(timeout=5) == "trigger"
+assert c.q.get(timeout=5) == "hello3"
+
+p.stop()
+
+= TriggerDrain and TriggeredQueueingValve with CLIFeeder
+
+s = CLIFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredValve()
+c = QueueSink()
+
+s > d1 > d2 > c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("hello")
+s.send("trigger")
+s.send("hello2")
+s.send("trigger")
+s.send("hello3")
+
+assert c.q.get(timeout=5) == "hello"
+assert c.q.get(timeout=5) == "trigger"
+assert c.q.get(timeout=5) == "hello3"
+
+p.stop()
+
+= TriggerDrain and TriggeredSwitch with CLIFeeder on high channel
+
+s = CLIFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredSwitch()
+c = QueueSink()
+
+s > d1 > d2
+d2 >> c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("hello")
+s.send("trigger")
+s.send("hello2")
+s.send("trigger")
+s.send("hello3")
+
+assert c.q.get(timeout=5) == "trigger"
+assert c.q.get(timeout=5) == "hello2"
+
+p.stop()
+
+= TriggerDrain and TriggeredSwitch with CLIHighFeeder on low channel
+
+s = CLIHighFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredSwitch()
+c = QueueSink()
+
+s >> d1
+d1 >> d2
+d2 > c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("hello")
+s.send("trigger")
+s.send("hello2")
+s.send("trigger")
+s.send("hello3")
+
+assert c.q.get(timeout=5) == "hello"
+assert c.q.get(timeout=5) == "trigger"
+assert c.q.get(timeout=5) == "hello3"
+
+p.stop()
+
+= TriggerDrain and TriggeredMessage
+
+s = CLIFeeder()
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredMessage("hello")
+c = QueueSink()
+
+s > d1 > d2 > c
+d1 ^ d2
+
+p = PipeEngine(s)
+p.start()
+
+s.send("trigger")
+
+r = [c.q.get(timeout=5), c.q.get(timeout=5)]
+assert "hello" in r
+assert "trigger" in r
+
+p.stop()
+
+= TriggerDrain and TriggeredQueueingValve on low channel
+
+p = PipeEngine()
+
+s = CLIFeeder()
+r, w = os.pipe()
+
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredQueueingValve()
+c = QueueSink(name="c")
+s > d1 > d2 > c
+d1 ^ d2
+
+p.add(s)
+p.start()
+
+s.send("trigger")
+s.send("hello")
+s.send("trigger")
+assert c.q.get(timeout=3) == "trigger"
+assert d2.q.qsize() == 0
+assert 'hello' in c.q.queue and 'trigger' in c.q.queue
+
+p.stop()
+
+= TriggerDrain and TriggeredQueueingValve on high channel
+
+p = PipeEngine()
+
+s = CLIHighFeeder()
+r, w = os.pipe()
+
+d1 = TriggerDrain(lambda x:x=="trigger")
+d2 = TriggeredQueueingValve()
+c = QueueSink(name="c")
+s >> d1 >> d2 >> c
+d1 ^ d2
+
+p.add(s)
+p.start()
+
+s.send("trigger")
+s.send("hello")
+s.send("trigger")
+assert c.q.get(timeout=3) == "trigger"
+assert d2.q.qsize() == 0
+assert c.q.queue == deque(['hello'])
+
+p.stop()
+
+= UDPDrain
+
+p = PipeEngine()
+
+s = CLIFeeder()
+s2 = CLIHighFeeder()
+d1 = UDPDrain()
+c = QueueSink()
+
+s > d1 > c
+s2 >> d1 >> c
+
+p.add(s)
+p.add(s2)
+p.start()
+
+s.send(IP(src="127.0.0.1")/UDP()/DNS())
+s2.send(DNS())
+
+res = [c.q.get(timeout=2), c.q.get(timeout=2)]
+assert b'\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00' in res
+res.remove(b'\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00')
+assert DNS in res[0] and res[0][UDP].sport == 1234
+
+p.stop()
+
+= FDSourceSink on a Bunch object
+
+class Bunch:
+    __init__ = lambda self, **kw: setattr(self, '__dict__', kw)
+
+fd = Bunch(write=lambda x: None, read=lambda: "hello", fileno=lambda: None)
+
+s = FDSourceSink(fd)
+d = Drain()
+c = QueueSink()
+s > d > c
+
+assert s.fileno() == None
+s.push("data")
+s.deliver()
+assert c.q.get(timeout=1) == "hello"
+
+= TCPConnectPipe networking test
+~ networking needs_root
+
+p = PipeEngine()
+
+s = CLIFeeder()
+d1 = TCPConnectPipe(addr="www.google.fr", port=80)
+c = QueueSink()
+
+s > d1 > c
+
+p.add(s)
+p.start()
+
+s.send("GET http://www.google.fr/search?q=scapy&start=1&num=1\n")
+result = c.q.get(timeout=10)
+p.stop()
+
+assert result.startswith("HTTP/1.0 200 OK")
-- 
GitLab