Skip to content
Snippets Groups Projects
Commit bf1c3ee8 authored by Florian Maury's avatar Florian Maury
Browse files

Add SSLStreamSocket

parent 620f195c
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ import socket,time ...@@ -12,6 +12,7 @@ import socket,time
from scapy.config import conf from scapy.config import conf
from scapy.data import * from scapy.data import *
from scapy.error import warning, log_runtime from scapy.error import warning, log_runtime
import scapy.packet
class _SuperSocket_metaclass(type): class _SuperSocket_metaclass(type):
def __repr__(self): def __repr__(self):
...@@ -137,8 +138,41 @@ class StreamSocket(SimpleSocket): ...@@ -137,8 +138,41 @@ class StreamSocket(SimpleSocket):
pad = pad.payload pad = pad.payload
self.ins.recv(x) self.ins.recv(x)
return pkt return pkt
class SSLStreamSocket(StreamSocket):
desc = "similar usage than StreamSocket but specialized for handling SSL-wrapped sockets"
def __init__(self, sock, basecls=None):
self._buf = ''
super(SSLStreamSocket, self).__init__(sock, basecls)
#65535, the default value of x is the maximum length of a TLS record
def recv(self, x=65535):
pkt = None
if self._buf != '':
try:
pkt = self.basecls(self._buf)
except:
# We assume that the exception is generated by a buffer underflow
pass
if not pkt:
buf = self.ins.recv(x)
if len(buf) == 0:
raise socket.error((100,"Underlying stream socket tore down"))
self._buf += buf
x = len(self._buf)
pkt = self.basecls(self._buf)
pad = pkt.getlayer(conf.padding_layer)
if pad is not None and pad.underlayer is not None:
del(pad.underlayer.payload)
while pad is not None and not isinstance(pad, scapy.packet.NoPayload):
x -= len(pad.load)
pad = pad.payload
self._buf = self._buf[x:]
return pkt
if conf.L3socket is None: if conf.L3socket is None:
conf.L3socket = L3RawSocket conf.L3socket = L3RawSocket
...@@ -6163,3 +6163,139 @@ p = Ether(str(p)) ...@@ -6163,3 +6163,139 @@ p = Ether(str(p))
assert(p[VXLAN].gpid == 42) assert(p[VXLAN].gpid == 42)
assert(p[VXLAN].reserved1 is None) assert(p[VXLAN].reserved1 is None)
assert(p[Ether:2].type == 0x800) assert(p[Ether:2].type == 0x800)
+ Tests of SSLStreamContext
= Test with recv() calls that return exact packet-length strings
~ sslstreamsocket
import socket
class MockSocket(object):
def __init__(self):
self.l = [ '\x00\x00\x00\x01', '\x00\x00\x00\x02', '\x00\x00\x00\x03' ]
def recv(self, x):
if len(self.l) == 0:
raise socket.error(100, 'EOF')
return self.l.pop(0)
class TestPacket(Packet):
name = 'TestPacket'
fields_desc = [
IntField('data', 0)
]
def guess_payload_class(self, p):
return conf.padding_layer
s = MockSocket()
ss = SSLStreamSocket(s, basecls=TestPacket)
p = ss.recv()
assert(p.data == 1)
p = ss.recv()
assert(p.data == 2)
p = ss.recv()
assert(p.data == 3)
try:
ss.recv()
ret = False
except socket.error:
ret = True
assert(ret)
= Test with recv() calls that return twice as much data as the exact packet-length
~ sslstreamsocket
import socket
class MockSocket(object):
def __init__(self):
self.l = [ '\x00\x00\x00\x01\x00\x00\x00\x02', '\x00\x00\x00\x03\x00\x00\x00\x04' ]
def recv(self, x):
if len(self.l) == 0:
raise socket.error(100, 'EOF')
return self.l.pop(0)
class TestPacket(Packet):
name = 'TestPacket'
fields_desc = [
IntField('data', 0)
]
def guess_payload_class(self, p):
return conf.padding_layer
s = MockSocket()
ss = SSLStreamSocket(s, basecls=TestPacket)
p = ss.recv()
assert(p.data == 1)
p = ss.recv()
assert(p.data == 2)
p = ss.recv()
assert(p.data == 3)
p = ss.recv()
assert(p.data == 4)
try:
ss.recv()
ret = False
except socket.error:
ret = True
assert(ret)
= Test with recv() calls that return not enough data
~ sslstreamsocket
import socket
class MockSocket(object):
def __init__(self):
self.l = [ '\x00\x00', '\x00\x01', '\x00\x00\x00', '\x02', '\x00\x00', '\x00', '\x03' ]
def recv(self, x):
if len(self.l) == 0:
raise socket.error(100, 'EOF')
return self.l.pop(0)
class TestPacket(Packet):
name = 'TestPacket'
fields_desc = [
IntField('data', 0)
]
def guess_payload_class(self, p):
return conf.padding_layer
s = MockSocket()
ss = SSLStreamSocket(s, basecls=TestPacket)
try:
p = ss.recv()
ret = False
except:
ret = True
assert(ret)
p = ss.recv()
assert(p.data == 1)
try:
p = ss.recv()
ret = False
except:
ret = True
assert(ret)
p = ss.recv()
assert(p.data == 2)
try:
p = ss.recv()
ret = False
except:
ret = True
assert(ret)
try:
p = ss.recv()
ret = False
except:
ret = True
assert(ret)
p = ss.recv()
assert(p.data == 3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment