diff --git a/scapy/utils.py b/scapy/utils.py index bf890ec1744227a1457e0f1ef1d1e2b9380c9325..b41dd7c2a14cddb6a4bbf80af22121bdc599da30 100644 --- a/scapy/utils.py +++ b/scapy/utils.py @@ -428,6 +428,36 @@ class ContextManagerSubprocess(object): log_scapy.error(msg, self.name, conf.prog.wireshark, exc_info=1) return True # Suppress the exception +class ContextManagerCaptureOutput(object): + """ + Context manager that intercept the console's output. + + Example: + >>> with ContextManagerCaptureOutput() as cmco: + ... print("hey") + ... assert cmco.get_output() == "hey" + """ + def __init__(self): + self.result_export_object = "" + try: + import mock + except: + raise ImportError("The mock module needs to be installed !") + def __enter__(self): + import mock + def write(s, decorator=self): + decorator.result_export_object += s + mock_stdout = mock.Mock() + mock_stdout.write = write + self.bck_stdout = sys.stdout + sys.stdout = mock_stdout + return self + def __exit__(self, *exc): + sys.stdout = self.bck_stdout + return False + def get_output(self): + return self.result_export_object + def do_graph(graph,prog=None,format=None,target=None,type=None,string=None,options=None): """do_graph(graph, prog=conf.prog.dot, format="svg", target="| conf.prog.display", options=None, [string=1]): diff --git a/test/regression.uts b/test/regression.uts index c34f21b4c9719cd32c4c0fe7a1d61a63b6ec70a5..c0ed1b311cd5d04a4d532e7bb7793d6935cbafc4 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -22,18 +22,10 @@ ls() lsc() = List contribs -import mock -result_list_contrib = "" def test_list_contrib(): - def write(s): - global result_list_contrib - result_list_contrib += s - mock_stdout = mock.Mock() - mock_stdout.write = write - bck_stdout = sys.stdout - sys.stdout = mock_stdout - list_contrib() - sys.stdout = bck_stdout + with ContextManagerCaptureOutput() as cmco: + list_contrib() + result_list_contrib = cmco.get_output() assert("http2 : HTTP/2 (RFC 7540, RFC 7541) status=loads" in result_list_contrib) assert(result_list_contrib.split('\n') > 40) @@ -250,20 +242,12 @@ assert(fletcher16_checkbytes(b"\x28\x07", 1) == "\xaf(") = Test hexdiff function ~ not_pypy -import mock -result_hexdiff = "" def test_hexdiff(): - def write(s): - global result_hexdiff - result_hexdiff += s conf_color_theme = conf.color_theme conf.color_theme = BlackAndWhite() - mock_stdout = mock.Mock() - mock_stdout.write = write - bck_stdout = sys.stdout - sys.stdout = mock_stdout - hexdiff("abcde", "abCde") - sys.stdout = bck_stdout + with ContextManagerCaptureOutput() as cmco: + hexdiff("abcde", "abCde") + result_hexdiff = cmco.get_output() conf.interactive = True conf.color_theme = conf_color_theme expected = "0000 61 62 63 64 65 abcde\n" @@ -287,17 +271,10 @@ zerofree_randstring(4) == "\xd2\x12\xe4\x5b" = Test export_object and import_object functions import mock -result_export_object = "" def test_export_import_object(): - def write(s): - global result_export_object - result_export_object += s - mock_stdout = mock.Mock() - mock_stdout.write = write - bck_stdout = sys.stdout - sys.stdout = mock_stdout - export_object(2807) - sys.stdout = bck_stdout + with ContextManagerCaptureOutput() as cmco: + export_object(2807) + result_export_object = cmco.get_output() assert(result_export_object.endswith("eNprYPL9zqUHAAdrAf8=\n\n")) assert(import_object(result_export_object) == 2807) @@ -3158,18 +3135,11 @@ tr6 = TracerouteResult6(tr6_packets) tr6.get_trace() == {'2001:db8::1': {1: ('2001:db8::1', False), 2: ('2001:db8::2', False), 3: ('2001:db8::3', False), 4: ('2001:db8::4', False), 5: ('2001:db8::5', False), 6: ('2001:db8::6', False), 7: ('2001:db8::7', False), 8: ('2001:db8::8', False), 9: ('2001:db8::9', False)}} = show() -result = "" def test_show(): - def write(s): - global result - result += s - mock_stdout = mock.Mock() - mock_stdout.write = write - bck_stdout = sys.stdout - sys.stdout = mock_stdout - tr6 = TracerouteResult6(tr6_packets) - tr6.show() - sys.stdout = bck_stdout + with ContextManagerCaptureOutput() as cmco: + tr6 = TracerouteResult6(tr6_packets) + tr6.show() + result = cmco.get_output() expected = " 2001:db8::1 :udpdomain \n" expected += "1 2001:db8::1 3 \n" expected += "2 2001:db8::2 3 \n" @@ -8696,6 +8666,34 @@ assert(len(conf.mib._find("MIB", "keyUsage"))) assert(len(conf.mib._recurs_find_all((), "MIB", "keyUsage"))) += DADict tests + +a = DADict("test") +a.test_value = "scapy" +with ContextManagerCaptureOutput() as cmco: + a._show() + assert(cmco.get_output() == "test_value = 'scapy'\n") + +b = DADict("test2") +b.test_value_2 = "hello_world" + +a._branch(b, 1) +try: + a._branch(b, 1) + assert False +except DADict_Exception: + pass + +assert(len(a._find("test2"))) + +assert(len(a._find(test_value_2="hello_world"))) + +assert(len(a._find_all("test2"))) + +assert(not a._recurs_find((a,))) + +assert(not a._recurs_find_all((a,))) + = BER tests BER_id_enc(42) == '*' @@ -8761,18 +8759,11 @@ tr_packets = [ (IP(dst="192.168.0.1", src="192.168.0.254", ttl=ttl)/TCP(options= tr = TracerouteResult(tr_packets) assert(tr.get_trace() == {'192.168.0.1': {1: ('192.168.0.1', False), 2: ('192.168.0.2', False), 3: ('192.168.0.3', False), 4: ('192.168.0.4', False), 5: ('192.168.0.5', False), 6: ('192.168.0.6', False), 7: ('192.168.0.7', False), 8: ('192.168.0.8', False), 9: ('192.168.0.9', False)}}) -result_show = "" def test_show(): - def write(s): - global result_show - result_show += s - mock_stdout = mock.Mock() - mock_stdout.write = write - saved_stdout = sys.stdout - sys.stdout = mock_stdout - tr = TracerouteResult(tr_packets) - tr.show() - sys.stdout = saved_stdout + with ContextManagerCaptureOutput() as cmco: + tr = TracerouteResult(tr_packets) + tr.show() + result_show = cmco.get_output() expected = " 192.168.0.1:tcp80 \n" expected += "1 192.168.0.1 11 \n" expected += "2 192.168.0.2 11 \n" @@ -8789,19 +8780,11 @@ def test_show(): test_show() -import mock -result_summary = "" def test_summary(): - def write_summary(s): - global result_summary - result_summary += s - mock_stdout = mock.Mock() - mock_stdout.write = write_summary - bck_stdout = sys.stdout - sys.stdout = mock_stdout - tr = TracerouteResult(tr_packets) - tr.summary() - sys.stdout = bck_stdout + with ContextManagerCaptureOutput() as cmco: + tr = TracerouteResult(tr_packets) + tr.summary() + result_summary = cmco.get_output() assert(len(result_summary.split('\n')) == 10) assert("IP / TCP 192.168.0.254:ftp_data > 192.168.0.1:http S / Raw ==> IP / ICMP 192.168.0.9 > 192.168.0.254 time-exceeded ttl-zero-during-transit / IPerror / TCPerror / Raw" in result_summary) @@ -8854,18 +8837,11 @@ def test_report_ports(mock_sr): test_report_ports() -result_IPID_count = "" def test_IPID_count(): - def write(s): - global result_IPID_count - result_IPID_count += s - mock_stdout = mock.Mock() - mock_stdout.write = write - saved_stdout = sys.stdout - sys.stdout = mock_stdout - random.seed(0x2807) - IPID_count([(IP()/UDP(), IP(id=random.randint(0, 65535))/UDP()) for i in range(3)]) - sys.stdout = saved_stdout + with ContextManagerCaptureOutput() as cmco: + random.seed(0x2807) + IPID_count([(IP()/UDP(), IP(id=random.randint(0, 65535))/UDP()) for i in range(3)]) + result_IPID_count = cmco.get_output() lines = result_IPID_count.split("\n") assert(len(lines) == 5) assert(lines[0].endswith("Probably 3 classes: [4613, 53881, 58437]"))