From bf1c3ee866407ec8e4e99d2c48e5f2afa9efc91d Mon Sep 17 00:00:00 2001 From: Florian Maury <florian.maury@ssi.gouv.fr> Date: Fri, 29 Jul 2016 14:28:13 +0200 Subject: [PATCH] Add SSLStreamSocket --- scapy/supersocket.py | 36 +++++++++++- test/regression.uts | 136 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 1 deletion(-) diff --git a/scapy/supersocket.py b/scapy/supersocket.py index 3fd0fef5..56f9c12d 100644 --- a/scapy/supersocket.py +++ b/scapy/supersocket.py @@ -12,6 +12,7 @@ import socket,time from scapy.config import conf from scapy.data import * from scapy.error import warning, log_runtime +import scapy.packet class _SuperSocket_metaclass(type): def __repr__(self): @@ -137,8 +138,41 @@ class StreamSocket(SimpleSocket): pad = pad.payload self.ins.recv(x) return pkt - +class SSLStreamSocket(StreamSocket): + desc = "similar usage than StreamSocket but specialized for handling SSL-wrapped sockets" + + def __init__(self, sock, basecls=None): + self._buf = '' + super(SSLStreamSocket, self).__init__(sock, basecls) + + #65535, the default value of x is the maximum length of a TLS record + def recv(self, x=65535): + pkt = None + if self._buf != '': + try: + pkt = self.basecls(self._buf) + except: + # We assume that the exception is generated by a buffer underflow + pass + + if not pkt: + buf = self.ins.recv(x) + if len(buf) == 0: + raise socket.error((100,"Underlying stream socket tore down")) + self._buf += buf + + x = len(self._buf) + pkt = self.basecls(self._buf) + pad = pkt.getlayer(conf.padding_layer) + + if pad is not None and pad.underlayer is not None: + del(pad.underlayer.payload) + while pad is not None and not isinstance(pad, scapy.packet.NoPayload): + x -= len(pad.load) + pad = pad.payload + self._buf = self._buf[x:] + return pkt if conf.L3socket is None: conf.L3socket = L3RawSocket diff --git a/test/regression.uts b/test/regression.uts index f5996ea0..9bbaae07 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -6163,3 +6163,139 @@ p = Ether(str(p)) assert(p[VXLAN].gpid == 42) assert(p[VXLAN].reserved1 is None) assert(p[Ether:2].type == 0x800) + ++ Tests of SSLStreamContext + += Test with recv() calls that return exact packet-length strings +~ sslstreamsocket + +import socket +class MockSocket(object): + def __init__(self): + self.l = [ '\x00\x00\x00\x01', '\x00\x00\x00\x02', '\x00\x00\x00\x03' ] + def recv(self, x): + if len(self.l) == 0: + raise socket.error(100, 'EOF') + return self.l.pop(0) + +class TestPacket(Packet): + name = 'TestPacket' + fields_desc = [ + IntField('data', 0) + ] + def guess_payload_class(self, p): + return conf.padding_layer + +s = MockSocket() +ss = SSLStreamSocket(s, basecls=TestPacket) + +p = ss.recv() +assert(p.data == 1) +p = ss.recv() +assert(p.data == 2) +p = ss.recv() +assert(p.data == 3) +try: + ss.recv() + ret = False +except socket.error: + ret = True + +assert(ret) + += Test with recv() calls that return twice as much data as the exact packet-length +~ sslstreamsocket + +import socket +class MockSocket(object): + def __init__(self): + self.l = [ '\x00\x00\x00\x01\x00\x00\x00\x02', '\x00\x00\x00\x03\x00\x00\x00\x04' ] + def recv(self, x): + if len(self.l) == 0: + raise socket.error(100, 'EOF') + return self.l.pop(0) + +class TestPacket(Packet): + name = 'TestPacket' + fields_desc = [ + IntField('data', 0) + ] + def guess_payload_class(self, p): + return conf.padding_layer + +s = MockSocket() +ss = SSLStreamSocket(s, basecls=TestPacket) + +p = ss.recv() +assert(p.data == 1) +p = ss.recv() +assert(p.data == 2) +p = ss.recv() +assert(p.data == 3) +p = ss.recv() +assert(p.data == 4) +try: + ss.recv() + ret = False +except socket.error: + ret = True + +assert(ret) + += Test with recv() calls that return not enough data +~ sslstreamsocket + +import socket +class MockSocket(object): + def __init__(self): + self.l = [ '\x00\x00', '\x00\x01', '\x00\x00\x00', '\x02', '\x00\x00', '\x00', '\x03' ] + def recv(self, x): + if len(self.l) == 0: + raise socket.error(100, 'EOF') + return self.l.pop(0) + +class TestPacket(Packet): + name = 'TestPacket' + fields_desc = [ + IntField('data', 0) + ] + def guess_payload_class(self, p): + return conf.padding_layer + +s = MockSocket() +ss = SSLStreamSocket(s, basecls=TestPacket) + +try: + p = ss.recv() + ret = False +except: + ret = True + +assert(ret) +p = ss.recv() +assert(p.data == 1) +try: + p = ss.recv() + ret = False +except: + ret = True + +assert(ret) +p = ss.recv() +assert(p.data == 2) +try: + p = ss.recv() + ret = False +except: + ret = True + +assert(ret) +try: + p = ss.recv() + ret = False +except: + ret = True + +assert(ret) +p = ss.recv() +assert(p.data == 3) -- GitLab