Skip to content
Snippets Groups Projects
Commit bbcd63e5 authored by gpotter2's avatar gpotter2
Browse files

New sys.stdout contextmanager + dadict tests

parent da00861c
No related branches found
No related tags found
No related merge requests found
......@@ -428,6 +428,36 @@ class ContextManagerSubprocess(object):
log_scapy.error(msg, self.name, conf.prog.wireshark, exc_info=1)
return True # Suppress the exception
class ContextManagerCaptureOutput(object):
"""
Context manager that intercept the console's output.
Example:
>>> with ContextManagerCaptureOutput() as cmco:
... print("hey")
... assert cmco.get_output() == "hey"
"""
def __init__(self):
self.result_export_object = ""
try:
import mock
except:
raise ImportError("The mock module needs to be installed !")
def __enter__(self):
import mock
def write(s, decorator=self):
decorator.result_export_object += s
mock_stdout = mock.Mock()
mock_stdout.write = write
self.bck_stdout = sys.stdout
sys.stdout = mock_stdout
return self
def __exit__(self, *exc):
sys.stdout = self.bck_stdout
return False
def get_output(self):
return self.result_export_object
def do_graph(graph,prog=None,format=None,target=None,type=None,string=None,options=None):
"""do_graph(graph, prog=conf.prog.dot, format="svg",
target="| conf.prog.display", options=None, [string=1]):
......
......@@ -22,18 +22,10 @@ ls()
lsc()
= List contribs
import mock
result_list_contrib = ""
def test_list_contrib():
def write(s):
global result_list_contrib
result_list_contrib += s
mock_stdout = mock.Mock()
mock_stdout.write = write
bck_stdout = sys.stdout
sys.stdout = mock_stdout
list_contrib()
sys.stdout = bck_stdout
with ContextManagerCaptureOutput() as cmco:
list_contrib()
result_list_contrib = cmco.get_output()
assert("http2 : HTTP/2 (RFC 7540, RFC 7541) status=loads" in result_list_contrib)
assert(result_list_contrib.split('\n') > 40)
......@@ -250,20 +242,12 @@ assert(fletcher16_checkbytes(b"\x28\x07", 1) == "\xaf(")
= Test hexdiff function
~ not_pypy
import mock
result_hexdiff = ""
def test_hexdiff():
def write(s):
global result_hexdiff
result_hexdiff += s
conf_color_theme = conf.color_theme
conf.color_theme = BlackAndWhite()
mock_stdout = mock.Mock()
mock_stdout.write = write
bck_stdout = sys.stdout
sys.stdout = mock_stdout
hexdiff("abcde", "abCde")
sys.stdout = bck_stdout
with ContextManagerCaptureOutput() as cmco:
hexdiff("abcde", "abCde")
result_hexdiff = cmco.get_output()
conf.interactive = True
conf.color_theme = conf_color_theme
expected = "0000 61 62 63 64 65 abcde\n"
......@@ -287,17 +271,10 @@ zerofree_randstring(4) == "\xd2\x12\xe4\x5b"
= Test export_object and import_object functions
import mock
result_export_object = ""
def test_export_import_object():
def write(s):
global result_export_object
result_export_object += s
mock_stdout = mock.Mock()
mock_stdout.write = write
bck_stdout = sys.stdout
sys.stdout = mock_stdout
export_object(2807)
sys.stdout = bck_stdout
with ContextManagerCaptureOutput() as cmco:
export_object(2807)
result_export_object = cmco.get_output()
assert(result_export_object.endswith("eNprYPL9zqUHAAdrAf8=\n\n"))
assert(import_object(result_export_object) == 2807)
......@@ -3158,18 +3135,11 @@ tr6 = TracerouteResult6(tr6_packets)
tr6.get_trace() == {'2001:db8::1': {1: ('2001:db8::1', False), 2: ('2001:db8::2', False), 3: ('2001:db8::3', False), 4: ('2001:db8::4', False), 5: ('2001:db8::5', False), 6: ('2001:db8::6', False), 7: ('2001:db8::7', False), 8: ('2001:db8::8', False), 9: ('2001:db8::9', False)}}
= show()
result = ""
def test_show():
def write(s):
global result
result += s
mock_stdout = mock.Mock()
mock_stdout.write = write
bck_stdout = sys.stdout
sys.stdout = mock_stdout
tr6 = TracerouteResult6(tr6_packets)
tr6.show()
sys.stdout = bck_stdout
with ContextManagerCaptureOutput() as cmco:
tr6 = TracerouteResult6(tr6_packets)
tr6.show()
result = cmco.get_output()
expected = " 2001:db8::1 :udpdomain \n"
expected += "1 2001:db8::1 3 \n"
expected += "2 2001:db8::2 3 \n"
......@@ -8696,6 +8666,34 @@ assert(len(conf.mib._find("MIB", "keyUsage")))
assert(len(conf.mib._recurs_find_all((), "MIB", "keyUsage")))
= DADict tests
a = DADict("test")
a.test_value = "scapy"
with ContextManagerCaptureOutput() as cmco:
a._show()
assert(cmco.get_output() == "test_value = 'scapy'\n")
b = DADict("test2")
b.test_value_2 = "hello_world"
a._branch(b, 1)
try:
a._branch(b, 1)
assert False
except DADict_Exception:
pass
assert(len(a._find("test2")))
assert(len(a._find(test_value_2="hello_world")))
assert(len(a._find_all("test2")))
assert(not a._recurs_find((a,)))
assert(not a._recurs_find_all((a,)))
= BER tests
BER_id_enc(42) == '*'
......@@ -8761,18 +8759,11 @@ tr_packets = [ (IP(dst="192.168.0.1", src="192.168.0.254", ttl=ttl)/TCP(options=
tr = TracerouteResult(tr_packets)
assert(tr.get_trace() == {'192.168.0.1': {1: ('192.168.0.1', False), 2: ('192.168.0.2', False), 3: ('192.168.0.3', False), 4: ('192.168.0.4', False), 5: ('192.168.0.5', False), 6: ('192.168.0.6', False), 7: ('192.168.0.7', False), 8: ('192.168.0.8', False), 9: ('192.168.0.9', False)}})
result_show = ""
def test_show():
def write(s):
global result_show
result_show += s
mock_stdout = mock.Mock()
mock_stdout.write = write
saved_stdout = sys.stdout
sys.stdout = mock_stdout
tr = TracerouteResult(tr_packets)
tr.show()
sys.stdout = saved_stdout
with ContextManagerCaptureOutput() as cmco:
tr = TracerouteResult(tr_packets)
tr.show()
result_show = cmco.get_output()
expected = " 192.168.0.1:tcp80 \n"
expected += "1 192.168.0.1 11 \n"
expected += "2 192.168.0.2 11 \n"
......@@ -8789,19 +8780,11 @@ def test_show():
test_show()
import mock
result_summary = ""
def test_summary():
def write_summary(s):
global result_summary
result_summary += s
mock_stdout = mock.Mock()
mock_stdout.write = write_summary
bck_stdout = sys.stdout
sys.stdout = mock_stdout
tr = TracerouteResult(tr_packets)
tr.summary()
sys.stdout = bck_stdout
with ContextManagerCaptureOutput() as cmco:
tr = TracerouteResult(tr_packets)
tr.summary()
result_summary = cmco.get_output()
assert(len(result_summary.split('\n')) == 10)
assert("IP / TCP 192.168.0.254:ftp_data > 192.168.0.1:http S / Raw ==> IP / ICMP 192.168.0.9 > 192.168.0.254 time-exceeded ttl-zero-during-transit / IPerror / TCPerror / Raw" in result_summary)
......@@ -8854,18 +8837,11 @@ def test_report_ports(mock_sr):
test_report_ports()
result_IPID_count = ""
def test_IPID_count():
def write(s):
global result_IPID_count
result_IPID_count += s
mock_stdout = mock.Mock()
mock_stdout.write = write
saved_stdout = sys.stdout
sys.stdout = mock_stdout
random.seed(0x2807)
IPID_count([(IP()/UDP(), IP(id=random.randint(0, 65535))/UDP()) for i in range(3)])
sys.stdout = saved_stdout
with ContextManagerCaptureOutput() as cmco:
random.seed(0x2807)
IPID_count([(IP()/UDP(), IP(id=random.randint(0, 65535))/UDP()) for i in range(3)])
result_IPID_count = cmco.get_output()
lines = result_IPID_count.split("\n")
assert(len(lines) == 5)
assert(lines[0].endswith("Probably 3 classes: [4613, 53881, 58437]"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment