From a2e649debf169a8221abb7e17935f885c7f23b54 Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Mon, 5 Jan 2009 19:30:59 +0100
Subject: [PATCH] Fixed defragmentation issues in defrag() and defragment()
 (ticket #152)

- do not defragment if last packets are missing
  (problem was in both defrag() and defragment())
- defragment() now puts a defragmented packet at the place of
  the latest fragment in the original capture instead of
  putting it in place of the last fragment of the datagram.
---
 scapy/layers/inet.py | 67 ++++++++++++++++++--------------------------
 1 file changed, 27 insertions(+), 40 deletions(-)

diff --git a/scapy/layers/inet.py b/scapy/layers/inet.py
index 94b75d7d..d9032dab 100644
--- a/scapy/layers/inet.py
+++ b/scapy/layers/inet.py
@@ -5,6 +5,7 @@
 
 import os,time,struct,re,socket,new
 from select import select
+from collections import defaultdict
 from scapy.utils import checksum
 from scapy.layers.l2 import *
 from scapy.config import conf
@@ -561,7 +562,7 @@ def overlap_frag(p, overlap, fragsize=8, overlap_fragsize=None):
 def defrag(plist):
     """defrag(plist) -> ([not fragmented], [defragmented],
                   [ [bad fragments], [bad fragments], ... ])"""
-    frags = {}
+    frags = defaultdict(PacketList)
     nofrag = PacketList()
     for p in plist:
         ip = p[IP]
@@ -572,16 +573,14 @@ def defrag(plist):
             nofrag.append(p)
             continue
         uniq = (ip.id,ip.src,ip.dst,ip.proto)
-        if uniq in frags:
-            frags[uniq].append(p)
-        else:
-            frags[uniq] = PacketList([p])
+        frags[uniq].append(p)
     defrag = []
     missfrag = []
     for lst in frags.itervalues():
-        lst.sort(lambda x,y:cmp(x.frag, y.frag))
+        lst.sort(key=lambda x: x.frag)
         p = lst[0]
-        if p.frag > 0:
+        lastp = lst[-1]
+        if p.frag > 0 or lastp.flags & 1 != 0: # first or last fragment missing
             missfrag.append(lst)
             continue
         p = p.copy()
@@ -594,11 +593,10 @@ def defrag(plist):
             clen = ip.len - (ip.ihl<<2)
         txt = Raw()
         for q in lst[1:]:
-            if clen != q.frag<<3:
+            if clen != q.frag<<3: # Wrong fragmentation offset
                 if clen > q.frag<<3:
                     warning("Fragment overlap (%i > %i) %r || %r ||  %r" % (clen, q.frag<<3, p,txt,q))
                 missfrag.append(lst)
-                txt = None
                 break
             if q[IP].len is None or q[IP].ihl is None:
                 clen += len(q[IP].payload)
@@ -607,15 +605,12 @@ def defrag(plist):
             if Padding in q:
                 del(q[Padding].underlayer.payload)
             txt.add_payload(q[IP].payload.copy())
-            
-        if txt is None:
-            continue
-
-        ip.flags &= ~1 # !MF
-        del(ip.chksum)
-        del(ip.len)
-        p = p/txt
-        defrag.append(p)
+        else:
+            ip.flags &= ~1 # !MF
+            del(ip.chksum)
+            del(ip.len)
+            p = p/txt
+            defrag.append(p)
     defrag2=PacketList()
     for p in defrag:
         defrag2.append(p.__class__(str(p)))
@@ -624,7 +619,7 @@ def defrag(plist):
 @conf.commands.register
 def defragment(plist):
     """defrag(plist) -> plist defragmented as much as possible """
-    frags = {}
+    frags = defaultdict(lambda:[])
     final = []
 
     pos = 0
@@ -636,19 +631,17 @@ def defragment(plist):
             if ip.frag != 0 or ip.flags & 1:
                 ip = p[IP]
                 uniq = (ip.id,ip.src,ip.dst,ip.proto)
-                if uniq in frags:
-                    frags[uniq].append(p)
-                else:
-                    frags[uniq] = [p]
+                frags[uniq].append(p)
                 continue
         final.append(p)
 
     defrag = []
     missfrag = []
     for lst in frags.itervalues():
-        lst.sort(lambda x,y:cmp(x.frag, y.frag))
+        lst.sort(key=lambda x: x.frag)
         p = lst[0]
-        if p.frag > 0:
+        lastp = lst[-1]
+        if p.frag > 0 or lastp.flags & 1 != 0: # first or last fragment missing
             missfrag += lst
             continue
         p = p.copy()
@@ -661,11 +654,10 @@ def defragment(plist):
             clen = ip.len - (ip.ihl<<2)
         txt = Raw()
         for q in lst[1:]:
-            if clen != q.frag<<3:
+            if clen != q.frag<<3: # Wrong fragmentation offset
                 if clen > q.frag<<3:
                     warning("Fragment overlap (%i > %i) %r || %r ||  %r" % (clen, q.frag<<3, p,txt,q))
                 missfrag += lst
-                txt = None
                 break
             if q[IP].len is None or q[IP].ihl is None:
                 clen += len(q[IP].payload)
@@ -674,16 +666,13 @@ def defragment(plist):
             if Padding in q:
                 del(q[Padding].underlayer.payload)
             txt.add_payload(q[IP].payload.copy())
-            
-        if txt is None:
-            continue
-
-        ip.flags &= ~1 # !MF
-        del(ip.chksum)
-        del(ip.len)
-        p = p/txt
-        p._defrag_pos = lst[-1]._defrag_pos
-        defrag.append(p)
+        else:
+            ip.flags &= ~1 # !MF
+            del(ip.chksum)
+            del(ip.len)
+            p = p/txt
+            p._defrag_pos = max(x._defrag_pos for x in lst)
+            defrag.append(p)
     defrag2=[]
     for p in defrag:
         q = p.__class__(str(p))
@@ -691,7 +680,7 @@ def defragment(plist):
         defrag2.append(q)
     final += defrag2
     final += missfrag
-    final.sort(lambda x,y: cmp(x._defrag_pos, y._defrag_pos))
+    final.sort(key=lambda x: x._defrag_pos)
     for p in final:
         del(p._defrag_pos)
 
@@ -699,11 +688,9 @@ def defragment(plist):
         name = "Defragmented %s" % plist.listname
     else:
         name = "Defragmented"
-        
     
     return PacketList(final, name=name)
             
-            
         
 
 ### Add timeskew_graph() method to PacketList
-- 
GitLab