From d351405798bd7c644bcf8dfbaa14334a710ca3b5 Mon Sep 17 00:00:00 2001
From: Phil <phil@secdev.org>
Date: Sun, 10 Aug 2008 02:40:18 +0200
Subject: [PATCH] Reworked main.py, added ~/scapy_prestart.py loading, added
 load_module/layer()

---
 scapy/main.py | 141 ++++++++++++++++++++++----------------------------
 1 file changed, 63 insertions(+), 78 deletions(-)

diff --git a/scapy/main.py b/scapy/main.py
index bf2b504e..cd83cf31 100644
--- a/scapy/main.py
+++ b/scapy/main.py
@@ -7,17 +7,35 @@
 from __future__ import generators
 import os,sys
 from error import *
+import __builtin__
+    
 
-DEFAULT_CONFIG_FILE = os.path.join(os.environ["HOME"], ".scapy_startup.py")
+def _probe_config_file(cf):
+    cf_path = os.path.join(os.environ["HOME"], cf)
+    try:
+        os.stat(cf_path)
+    except OSError:
+        return None
+    else:
+        return cf_path
 
-try:
-    os.stat(DEFAULT_CONFIG_FILE)
-except OSError:
-    DEFAULT_CONFIG_FILE = None
+def _read_config_file(cf):
+    log_loading.debug("Loading config file [%s]" % cf)
+    try:
+        execfile(cf)
+    except IOError,e:
+        log_loading.warning("Cannot read config file [%s] [%s]" % (cf,e))
+    except Exception,e:
+        log_loading.exception("Error during evaluation of config file [%s]" % cf)
+        
+
+DEFAULT_PRESTART_FILE = _probe_config_file(".scapy_prestart.py")
+DEFAULT_STARTUP_FILE = _probe_config_file(".scapy_startup.py")
 
-def usage():
-    print """Usage: scapy.py [-s sessionfile] [-c new_startup_file] [-C]
-    -C: do not read startup file"""
+def _usage():
+    print """Usage: scapy.py [-s sessionfile] [-c new_startup_file] [-p new_prestart_file] [-C] [-P]
+    -C: do not read startup file
+    -P: do not read pre-startup file"""
     sys.exit(0)
 
 
@@ -30,38 +48,19 @@ from themes import ColorPrompt
 ######################
 
 
-def load_extension(filename):
-    import imp
-    paths = conf.extensions_paths
-    if type(paths) is not list:
-        paths = [paths]
+def _load(module):
+    try:
+        mod = __import__(module,globals(),locals(),".")
+        __builtin__.__dict__.update(mod.__dict__)
+    except Exception,e:
+        log_interactive.error(e)
+        
+def load_module(name):
+    _load("scapy.modules."+name)
 
-    name = os.path.realpath(os.path.expanduser(filename))
-    thepath = os.path.dirname(name)
-    thename = os.path.basename(name)
-    if thename.endswith(".py"):
-        thename = thename[:-3]
+def load_layer(name):
+    _load("scapy.layers."+name)
 
-    paths.insert(0, thepath)
-    cwd=syspath=None
-    try:
-        cwd = os.getcwd()
-        os.chdir(thepath)
-        syspath = sys.path[:]
-        sys.path += paths
-        try:
-            extf = imp.find_module(thename, paths)
-        except ImportError:
-            log_runtime.error("Module [%s] not found. Check conf.extensions_paths ?" % filename)
-        else:
-            ext = imp.load_module(thename, *extf)
-            import __builtin__
-            __builtin__.__dict__.update(ext.__dict__)
-    finally:
-        if syspath:
-            sys.path=syspath
-        if cwd:
-            os.chdir(cwd)
     
 
 ################
@@ -96,27 +95,6 @@ def interact(mydict=None,argv=None,mybanner=None,loglevel=1):
     if argv is None:
         argv = sys.argv
 
-#    scapy_module = argv[0][argv[0].rfind("/")+1:]
-#    if not scapy_module:
-#        scapy_module = "scapy"
-#    else:
-#        if scapy_module.endswith(".py"):
-#            scapy_module = scapy_module[:-3]
-#
-#    scapy=imp.load_module("scapy",*imp.find_module(scapy_module))
-    
-    
-#    __builtin__.__dict__.update(scapy.__dict__)
-    import __builtin__
-    scapy_builtins = __import__("all",globals(),locals(),".").__dict__
-    __builtin__.__dict__.update(scapy_builtins)
-    globkeys = scapy_builtins.keys()
-    globkeys.append("scapy_session")
-    scapy_builtins=None # XXX replace with "with" statement
-    if mydict is not None:
-        __builtin__.__dict__.update(mydict)
-        globkeys += mydict.keys()
-    
     import atexit
     try:
         import rlcompleter,readline
@@ -166,20 +144,25 @@ def interact(mydict=None,argv=None,mybanner=None,loglevel=1):
     
     session=None
     session_name=""
-    CONFIG_FILE = DEFAULT_CONFIG_FILE
+    STARTUP_FILE = DEFAULT_STARTUP_FILE
+    PRESTART_FILE = DEFAULT_PRESTART_FILE
 
     iface = None
     try:
-        opts=getopt.getopt(argv[1:], "hs:Cc:")
+        opts=getopt.getopt(argv[1:], "hs:Cc:Pp:")
         for opt, parm in opts[0]:
             if opt == "-h":
-                usage()
+                _usage()
             elif opt == "-s":
                 session_name = parm
             elif opt == "-c":
-                CONFIG_FILE = parm
+                STARTUP_FILE = parm
             elif opt == "-C":
-                CONFIG_FILE = None
+                STARTUP_FILE = None
+            elif opt == "-p":
+                PRESTART_FILE = parm
+            elif opt == "-P":
+                PRESTART_FILE = None
         
         if len(opts[1]) > 0:
             raise getopt.GetoptError("Too many parameters : [%s]" % " ".join(opts[1]))
@@ -190,8 +173,22 @@ def interact(mydict=None,argv=None,mybanner=None,loglevel=1):
         sys.exit(1)
 
 
-    if CONFIG_FILE:
-        read_config_file(CONFIG_FILE)
+    from config import conf
+    if PRESTART_FILE:
+        _read_config_file(PRESTART_FILE)
+
+    scapy_builtins = __import__("all",globals(),locals(),".").__dict__
+    __builtin__.__dict__.update(scapy_builtins)
+    globkeys = scapy_builtins.keys()
+    globkeys.append("scapy_session")
+    scapy_builtins=None # XXX replace with "with" statement
+    if mydict is not None:
+        __builtin__.__dict__.update(mydict)
+        globkeys += mydict.keys()
+    
+
+    if STARTUP_FILE:
+        _read_config_file(STARTUP_FILE)
         
     if session_name:
         try:
@@ -245,17 +242,5 @@ def interact(mydict=None,argv=None,mybanner=None,loglevel=1):
         except:
             pass
 
-def read_config_file(configfile):
-    try:
-        execfile(configfile)
-    except IOError,e:
-        log_loading.warning("Cannot read config file [%s] [%s]" % (configfile,e))
-    except Exception,e:
-        log_loading.exception("Error during evaluation of config file [%s]" % configfile)
-        
-
 if __name__ == "__main__":
     interact()
-else:
-    if DEFAULT_CONFIG_FILE:
-        read_config_file(DEFAULT_CONFIG_FILE)
-- 
GitLab