From 3a0bc3a94204838aeaea7940810bf5db2f9edc49 Mon Sep 17 00:00:00 2001
From: Pierre LALET <pierre.lalet@cea.fr>
Date: Wed, 23 Aug 2017 13:36:28 +0200
Subject: [PATCH] sndrcv(): replace multiprocessing with threading

---
 scapy/sendrecv.py | 114 +++++++++++++++-------------------------------
 1 file changed, 37 insertions(+), 77 deletions(-)

diff --git a/scapy/sendrecv.py b/scapy/sendrecv.py
index f8ae3135..60421936 100644
--- a/scapy/sendrecv.py
+++ b/scapy/sendrecv.py
@@ -10,7 +10,7 @@ Functions to send and receive packets.
 from __future__ import absolute_import, print_function
 import errno
 import itertools
-import multiprocessing
+import threading
 import os
 from select import select, error as select_error
 import subprocess
@@ -47,46 +47,27 @@ class debug:
 ####################
 
 
-def _sndrcv_snd(pks, inter, verbose, tobesent, all_stimuli, rdpipe, wrpipe):
+def _sndrcv_snd(pks, timeout, inter, verbose, tobesent, all_stimuli, stopevent):
     try:
-        if not WINDOWS:
-            sys.stdin.close()
-            rdpipe.close()
-        try:
-            i = 0
-            if verbose:
-                print("Begin emission:")
-            for p in tobesent:
-                pks.send(p)
-                i += 1
-                time.sleep(inter)
-            if verbose:
-                print("Finished to send %i packets." % i)
-        except SystemExit:
-            pass
-        except KeyboardInterrupt:
-            pass
-        except:
-            if WINDOWS:
-                log_runtime.exception("--- Error sending packets")
-                log_runtime.info("--- Error sending packets")
-            else:
-                log_runtime.exception("--- Error in child "
-                                      "%i" % os.getpid())
-                log_runtime.info("--- Error in child "
-                                 "%i" % os.getpid())
-    finally:
-        try:
-            sent_times = [p.sent_time for p in all_stimuli if p.sent_time]
-            if not WINDOWS:
-                # Change process group to avoid ctrl-C
-                os.setpgrp()
-            netcache = {cache.name: (cache.items(), cache._timetable.items())
-                        for cache in conf.netcache._caches_list}
-            wrpipe.send((netcache, sent_times))
-            wrpipe.close()
-        except:
-            pass
+        i = 0
+        if verbose:
+            print("Begin emission:")
+        for p in tobesent:
+            pks.send(p)
+            i += 1
+            time.sleep(inter)
+        if verbose:
+            print("Finished to send %i packets." % i)
+    except SystemExit:
+        pass
+    except KeyboardInterrupt:
+        pass
+    except:
+        log_runtime.exception("--- Error sending packets")
+        log_runtime.info("--- Error sending packets")
+    if timeout is not None:
+        stopevent.wait(timeout)
+        pks.close()
 
 
 def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
@@ -119,17 +100,14 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
     while retry >= 0:
         if timeout < 0:
             timeout = None
+        stopevent = threading.Event()
 
-        rdpipe, wrpipe = multiprocessing.Pipe(False)
-
-        proc = multiprocessing.Process(
+        thread = threading.Thread(
             target=_sndrcv_snd,
-            args=(pks, inter, verbose, tobesent, all_stimuli, rdpipe,
-                  wrpipe),
+            args=(pks, timeout, inter, verbose, tobesent, all_stimuli,
+                  stopevent),
         )
-        proc.start()
-        wrpipe.close()
-        inmask = [rdpipe, pks]
+        thread.start()
         stoptime = 0
         remaintime = None
         try:
@@ -144,19 +122,19 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
                         r = pks.recv(MTU)
                     elif conf.use_bpf:
                         from scapy.arch.bpf.supersocket import bpf_select
-                        inp = bpf_select(inmask)
+                        inp = bpf_select([pks])
                         if pks in inp:
                             r = pks.recv()
                     elif not isinstance(pks, StreamSocket) and (
                             FREEBSD or DARWIN or OPENBSD
                     ):
-                        inp, _, _ = select(inmask, [], [], 0.05)
+                        inp, _, _ = select([pks], [], [], 0.05)
                         if len(inp) == 0 or pks in inp:
                             r = pks.nonblock_recv()
                     else:
                         inp = []
                         try:
-                            inp, _, _ = select(inmask, [], [], remaintime)
+                            inp, _, _ = select([pks], [], [], remaintime)
                         except (IOError, select_error) as exc:
                             # select.error has no .errno attribute
                             if exc.args[0] != errno.EINTR:
@@ -165,11 +143,9 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
                             break
                         if pks in inp:
                             r = pks.recv(MTU)
-                    if rdpipe in inp:
-                        if timeout:
-                            stoptime = time.time() + timeout
-                        del(inmask[inmask.index(rdpipe)])
                     if r is None:
+                        if pks.closed:
+                            break
                         continue
                     ok = 0
                     h = r.hashret()
@@ -201,24 +177,8 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
                 if chainCC:
                     raise
         finally:
-            try:
-                netcache, sent_times = rdpipe.recv()
-            except EOFError:
-                warning("Child died unexpectedly. "
-                        "Packets may have not been sent.")
-            else:
-                for cache_name, (cache, cache_times) in netcache.iteritems():
-                    cache_obj = CacheInstance(name=cache_name)
-                    for key, value in cache:
-                        cache_obj[key] = value
-                    cache_obj._timetable.update(cache_times)
-                    try:
-                        getattr(conf.netcache, cache_name).update(cache_obj)
-                    except AttributeError:
-                        conf.netcache.add_cache(cache_obj)
-                for pkt, tstamp in zip(all_stimuli, sent_times):
-                    pkt.sent_time = tstamp
-            proc.join()
+            stopevent.set()
+            thread.join()
 
         remain = list(itertools.chain(*six.itervalues(hsent)))
         if multi:
@@ -236,11 +196,11 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
         debug.sent=plist.PacketList(remain[:],"Sent")
         debug.match=plist.SndRcvList(ans[:])
 
-    #clean the ans list to delete the field _answered
+    # Clean the ans list to delete the field _answered
     if multi:
-        for s,r in ans:
-            if hasattr(s, '_answered'):
-                del(s._answered)
+        for snd, _ in ans:
+            if hasattr(snd, '_answered'):
+                del snd._answered
 
     if verbose:
         print("\nReceived %i packets, got %i answers, remaining %i packets" % (nbrecv+len(ans), len(ans), notans))
-- 
GitLab