From 730c47bda3db127330ae0926643cf87bf9f5a883 Mon Sep 17 00:00:00 2001
From: gpotter2 <gabriel@potter.fr>
Date: Wed, 27 Sep 2017 15:24:33 +0200
Subject: [PATCH] Fix sndrcvloop functions + share code

---
 scapy/sendrecv.py | 249 +++++++++++++++++++++++++---------------------
 1 file changed, 134 insertions(+), 115 deletions(-)

diff --git a/scapy/sendrecv.py b/scapy/sendrecv.py
index 12cb6654..6a5dd0c0 100644
--- a/scapy/sendrecv.py
+++ b/scapy/sendrecv.py
@@ -70,7 +70,6 @@ def _sndrcv_snd(pks, timeout, inter, verbose, tobesent, stopevent):
         stopevent.wait(timeout)
         stopevent.set()
 
-
 class _BreakException(Exception):
     """A dummy exception used in _get_pkt() to get out of the infinite
 loop
@@ -78,33 +77,15 @@ loop
     """
     pass
 
-
-def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
-           retry=0, multi=False):
-    if not isinstance(pkt, Gen):
-        pkt = SetGen(pkt)
-    if verbose is None:
-        verbose = conf.verb
-    debug.recv = plist.PacketList([],"Unanswered")
-    debug.sent = plist.PacketList([],"Sent")
-    debug.match = plist.SndRcvList([])
-    nbrecv=0
+def _sndrcv_rcv(pks, tobesent, stopevent, nbrecv, notans, verbose, chainCC,
+                multi):
+    """Function used to recieve packets and check their hashret"""
     ans = []
-    # do it here to fix random fields, so that parent and child have the same
-    tobesent = [p for p in pkt]
-    notans = len(tobesent)
-
-    hsent={}
+    hsent = {}
     for i in tobesent:
         h = i.hashret()
         hsent.setdefault(i.hashret(), []).append(i)
 
-    if retry < 0:
-        retry = -retry
-        autostop = retry
-    else:
-        autostop = 0
-
     if WINDOWS:
         def _get_pkt():
             return pks.recv(MTU)
@@ -134,6 +115,69 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
             if stopevent.is_set():
                 raise _BreakException()
 
+    try:
+        try:
+            while True:
+                r = _get_pkt()
+                if r is None:
+                    if stopevent.is_set():
+                        break
+                    continue
+                ok = False
+                h = r.hashret()
+                if h in hsent:
+                    hlst = hsent[h]
+                    for i, sentpkt in enumerate(hlst):
+                        if r.answers(sentpkt):
+                            ans.append((sentpkt, r))
+                            if verbose > 1:
+                                os.write(1, b"*")
+                            ok = True
+                            if not multi:
+                                del hlst[i]
+                                notans -= 1
+                            else:
+                                if not hasattr(sentpkt, '_answered'):
+                                    notans -= 1
+                                sentpkt._answered = 1
+                            break
+                if notans == 0 and not multi:
+                    break
+                if not ok:
+                    if verbose > 1:
+                        os.write(1, b".")
+                    nbrecv += 1
+                    if conf.debug_match:
+                        debug.recv.append(r)
+        except KeyboardInterrupt:
+            if chainCC:
+                raise
+        except _BreakException:
+            pass
+    finally:
+        stopevent.set()
+    return (hsent, ans, nbrecv, notans)
+
+def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
+           retry=0, multi=False):
+    if not isinstance(pkt, Gen):
+        pkt = SetGen(pkt)
+    if verbose is None:
+        verbose = conf.verb
+    debug.recv = plist.PacketList([],"Unanswered")
+    debug.sent = plist.PacketList([],"Sent")
+    debug.match = plist.SndRcvList([])
+    nbrecv=0
+    # do it here to fix random fields, so that parent and child have the same
+    tobesent = [p for p in pkt]
+    notans = len(tobesent)
+
+    if retry < 0:
+        retry = -retry
+        autostop = retry
+    else:
+        autostop = 0
+
     while retry >= 0:
         if timeout < 0:
             timeout = None
@@ -145,49 +189,8 @@ def sndrcv(pks, pkt, timeout=None, inter=0, verbose=None, chainCC=False,
         )
         thread.start()
 
-        try:
-            try:
-                while True:
-                    r = _get_pkt()
-                    if r is None:
-                        if stopevent.is_set():
-                            break
-                        continue
-                    ok = False
-                    h = r.hashret()
-                    if h in hsent:
-                        hlst = hsent[h]
-                        for i, sentpkt in enumerate(hlst):
-                            if r.answers(sentpkt):
-                                ans.append((sentpkt, r))
-                                if verbose > 1:
-                                    os.write(1, b"*")
-                                ok = True
-                                if not multi:
-                                    del hlst[i]
-                                    notans -= 1
-                                else:
-                                    if not hasattr(sentpkt, '_answered'):
-                                        notans -= 1
-                                    sentpkt._answered = 1
-                                break
-                    if notans == 0 and not multi:
-                        break
-                    if not ok:
-                        if verbose > 1:
-                            os.write(1, b".")
-                        nbrecv += 1
-                        if conf.debug_match:
-                            debug.recv.append(r)
-            except KeyboardInterrupt:
-                if chainCC:
-                    raise
-            except _BreakException:
-                pass
-        finally:
-            stopevent.set()
-            thread.join()
-            pks.close()
+        hsent, ans, nbrecv, notans = _sndrcv_rcv(pks, tobesent, stopevent, nbrecv, notans, verbose, chainCC, multi)
+        thread.join()
 
         remain = list(itertools.chain(*six.itervalues(hsent)))
         if multi:
@@ -403,6 +406,8 @@ iface:    work only on the given interface"""
     else:
         return None
 
+# SEND/RECV LOOP METHODS
+
 def __sr_loop(srfunc, pkts, prn=lambda x:x[1].summary(), prnfail=lambda x:x.summary(), inter=1, timeout=None, count=None, verbose=None, store=1, *args, **kargs):
     n = 0
     r = 0
@@ -467,71 +472,55 @@ def srploop(pkts, *args, **kargs):
 srloop(pkts, [prn], [inter], [count], ...) --> None"""
     return __sr_loop(srp, pkts, *args, **kargs)
 
+# SEND/RECV FLOOD METHODS
 
-def sndrcvflood(pks, pkt, prn=lambda s_r:s_r[1].summary(), chainCC=0, store=1, unique=0):
+def sndrcvflood(pks, pkt, inter=0, verbose=None, chainCC=False, prn=lambda x: x):
+    if not verbose:
+        verbose = conf.verb
     if not isinstance(pkt, Gen):
         pkt = SetGen(pkt)
     tobesent = [p for p in pkt]
     received = plist.SndRcvList()
     seen = {}
 
-    hsent={}
-    for i in tobesent:
-        h = i.hashret()
-        if h in hsent:
-            hsent[h].append(i)
-        else:
-            hsent[h] = [i]
+    stopevent = threading.Event()
+    count_packets = six.moves.queue.Queue()
 
-    def send_in_loop(tobesent):
+    def send_in_loop(tobesent, stopevent, count_packets=count_packets):
+        """Infinite generator that produces the same packet until stopevent is triggered."""
         while True:
             for p in tobesent:
+                if stopevent.is_set():
+                    raise StopIteration()
+                count_packets.put(0)
                 yield p
 
-    packets_to_send = send_in_loop(tobesent)
+    infinite_gen = send_in_loop(tobesent, stopevent)
 
-    ssock = rsock = pks.fileno()
+    # We don't use _sndrcv_snd verbose (it messes the logs up as in a thread that ends after recieving)
+    thread = threading.Thread(
+        target=_sndrcv_snd,
+        args=(pks, None, inter, False, infinite_gen, stopevent),
+    )
+    thread.start()
 
-    try:
-        while True:
-            if conf.use_bpf:
-                from scapy.arch.bpf.supersocket import bpf_select
-                readyr = bpf_select([rsock])
-                _, readys, _ = select([], [ssock], [])
-            else:
-                readyr, readys, _ = select([rsock], [ssock], [])
+    hsent, ans, nbrecv, notans = _sndrcv_rcv(pks, tobesent, stopevent, 0, len(tobesent), verbose, chainCC, False)
+    thread.join()
+    remain = list(itertools.chain(*six.itervalues(hsent)))
+    # Apply prn
+    ans = [(x, prn(y)) for (x, y) in ans]
 
-            if ssock in readys:
-                pks.send(packets_to_send.next())
-                
-            if rsock in readyr:
-                p = pks.recv(MTU)
-                if p is None:
-                    continue
-                h = p.hashret()
-                if h in hsent:
-                    hlst = hsent[h]
-                    for i in hlst:
-                        if p.answers(i):
-                            res = prn((i,p))
-                            if unique:
-                                if res in seen:
-                                    continue
-                                seen[res] = None
-                            if res is not None:
-                                print(res)
-                            if store:
-                                received.append((i,p))
-    except KeyboardInterrupt:
-        if chainCC:
-            raise
-    return received
+    if verbose:
+        print("\nReceived %i packets, got %i answers, remaining %i packets. Sent a total of %i packets." % (nbrecv+len(ans), len(ans), notans, count_packets.qsize()))
+    count_packets.empty()
+    del count_packets
+
+    return plist.SndRcvList(ans), plist.PacketList(remain, "Unanswered")
 
 @conf.commands.register
 def srflood(x, promisc=None, filter=None, iface=None, nofilter=None, *args,**kargs):
     """Flood and receive packets at layer 3
-prn:      function applied to packets received. Ret val is printed if not None
-store:    if 1 (default), store answers and return them
+prn:      function applied to packets received
 unique:   only consider packets whose print 
 nofilter: put 1 to avoid use of BPF filters
 filter:   provide a BPF filter
@@ -541,11 +530,26 @@ iface:    listen answers only on the given interface"""
     s.close()
     return r
 
+@conf.commands.register
+def sr1flood(x, promisc=None, filter=None, iface=None, nofilter=0, *args,**kargs):
+    """Flood and receive packets at layer 3 and return only the first answer
+prn:      function applied to packets received
+verbose:  set verbosity level
+nofilter: put 1 to avoid use of BPF filters
+filter:   provide a BPF filter
+iface:    listen answers only on the given interface"""
+    s=conf.L3socket(promisc=promisc, filter=filter, nofilter=nofilter, iface=iface)
+    ans, _ = sndrcvflood(s, x, *args, **kargs)
+    s.close()
+    if len(ans) > 0:
+        return ans[0][1]
+    else:
+        return None
+
 @conf.commands.register
 def srpflood(x, promisc=None, filter=None, iface=None, iface_hint=None, nofilter=None, *args,**kargs):
     """Flood and receive packets at layer 2
-prn:      function applied to packets received. Ret val is printed if not None
-store:    if 1 (default), store answers and return them
+prn:      function applied to packets received
 unique:   only consider packets whose print 
 nofilter: put 1 to avoid use of BPF filters
 filter:   provide a BPF filter
@@ -557,8 +561,23 @@ iface:    listen answers only on the given interface"""
     s.close()
     return r
 
-           
+@conf.commands.register
+def srp1flood(x, promisc=None, filter=None, iface=None, nofilter=0, *args,**kargs):
+    """Flood and receive packets at layer 2 and return only the first answer
+prn:      function applied to packets received
+verbose:  set verbosity level
+nofilter: put 1 to avoid use of BPF filters
+filter:   provide a BPF filter
+iface:    listen answers only on the given interface"""
+    s=conf.L2socket(promisc=promisc, filter=filter, nofilter=nofilter, iface=iface)
+    ans, _ = sndrcvflood(s, x, *args, **kargs)
+    s.close()
+    if len(ans) > 0:
+        return ans[0][1]
+    else:
+        return None
 
+# SNIFF METHODS
 
 @conf.commands.register
 def sniff(count=0, store=True, offline=None, prn=None, lfilter=None,
-- 
GitLab