]> scripts.mit.edu Git - wizard.git/commitdiff
Refactor out common mass operations to command module.
authorEdward Z. Yang <ezyang@mit.edu>
Sun, 23 Aug 2009 06:42:31 +0000 (02:42 -0400)
committerEdward Z. Yang <ezyang@mit.edu>
Sun, 23 Aug 2009 06:42:31 +0000 (02:42 -0400)
Signed-off-by: Edward Z. Yang <ezyang@mit.edu>
bin/wizard
wizard/command/__init__.py
wizard/command/mass_migrate.py
wizard/shell.py
wizard/sset.py

index 85ac65d00f928a13374269d19986b498cc86c2a8..aad6c9258a5050ecbf67f747c97331076a74ab7f 100755 (executable)
@@ -39,14 +39,14 @@ See '%prog help COMMAND' for more information on a specific command."""
     baton.add("--srv-path", dest="srv_path",
         default=getenvpath("WIZARD_SRV_PATH") or "/afs/athena.mit.edu/contrib/scripts/git/autoinstalls",
         help="Location of autoinstall Git repositories, such that $REPO_PATH/$APP.git is a repository (for development work).  Environment variable is WIZARD_SRV_PATH.")
-    baton.add("--log-dir", dest="log_dir",
-        default=getenvpath("WIZARD_LOG_DIR") or None,
-        help="Log files for Wizard children processes are placed here.")
     try:
         command_name = args[0]
     except IndexError:
         parser.print_help()
         raise SystemExit(1)
+    baton.add("--log-dir", dest="log_dir",
+        default=getenvpath("WIZARD_LOG_DIR") or "/tmp/wizard-%s" % command_name,
+        help="Log files for Wizard children processes are placed here.")
     if command_name == "help":
         try:
             help_module = get_command(rest_argv[0])
index eeb5336a236dcaffa999bf3f8867b2a08cfab99b..79c957d1dc54e3c1d1fd3ccb18d242156a66c62b 100644 (file)
@@ -73,6 +73,59 @@ def makeBaseArgs(options, **grab):
             args.append(str(value))
     return args
 
+def security_check_homedir(location):
+    """
+    Performs a check against a directory to determine if current
+    directory's owner has a home directory that is a parent directory.
+    This protects against malicious mountpoints, and is roughly equivalent
+    to the suexec checks.
+    """
+    uid = util.get_dir_uid(location)
+    real = os.path.realpath(location)
+    try:
+        if not real.startswith(pwd.getpwuid(uid).pw_dir + "/"):
+            logging.error("Security check failed, owner of deployment and"
+                    "owner of home directory mismatch for %s" % d.location)
+            return False
+    except KeyError:
+        logging.error("Security check failed, could not look up"
+                "owner of %s (uid %d)" % (location, uid))
+        return False
+    return True
+
+def calculate_log_name(log_dir, i, dir):
+    """
+    Calculates a log entry given a log directory, numeric identifier, and
+    directory under operation.
+    """
+    return os.path.join(log_dir, "%04d" % i + dir.replace('/', '-') + ".log")
+
+def open_logs(log_dir, log_names=('warnings', 'errors')):
+    """
+    Opens a number of log files for auxiliary reporting.  You can override what
+    log files to generate using ``log_names``, which corresponds to the tuple
+    of log files you will receive, i.e. the default returns a tuple
+    ``(warnings.log file object, errors.log file object)``.
+
+    .. note::
+
+        The log directory is chmod'ed 777 after creation, to enable
+        de-priviledged processes to create files.
+    """
+    # must not be on AFS, since subprocesses won't be
+    # able to write to the logfiles do the to the AFS patch.
+    try:
+        os.mkdir(log_dir)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+        #if create_subdirs:
+        #    log_dir = os.path.join(log_dir, str(int(time.time())))
+        #    os.mkdir(log_dir) # if fails, be fatal
+        #    # XXX: update last symlink
+    os.chmod(log_dir, 0o777)
+    return (open(os.path.join(os.path.join(log_dir, "%s.log" % x)), "a") for x in log_names)
+
 class NullLogHandler(logging.Handler):
     """Log handler that doesn't do anything"""
     def emit(self, record):
index 15fde37ef6ca0b85e4cffde9174a3cb4dfe18669..00dd006f81cd1db21b289bbb2d9629392ce57d07 100644 (file)
@@ -16,25 +16,26 @@ def main(argv, baton):
     app = args[0]
     base_args = calculate_base_args(options)
     sh = shell.ParallelShell.make(options.no_parallelize, options.max)
-    seen = sset.make(options)
+    seen = sset.make(options.seen)
     is_root = not os.getuid()
-    warnings_log, errors_log = open_aggregate_logs(options)
+    warnings_log, errors_log = command.open_logs(options.log_dir)
     # loop stuff
     errors = {}
     i = 0
     # [] needed to workaround subtle behavior of frozenset("")
     deploys = deploy.parse_install_lines([app], options.versions_path)
     requested_deploys = itertools.islice(deploys, options.limit)
-    for i, d in enumerate(requested_deploys, 1)
+    for i, d in enumerate(requested_deploys, 1):
         # check if we want to punt due to --limit
         if d.location in seen:
             continue
-        if is_root and not security_check_homedir(d):
+        if is_root and not command.security_check_homedir(d):
             continue
+        logging.info("Processing %s" % d.location)
         child_args = list(base_args)
         # calculate the log file, if a log dir was specified
         if options.log_dir:
-            log_file = os.path.join(options.log_dir, calculate_log_name(i, d.location))
+            log_file = command.calculate_log_name(options.log_dir, i, d.location)
             child_args.append("--log-file=" + log_file)
         # actual meat
         def make_on_pair(d, i):
@@ -93,7 +94,7 @@ untrusted repositories."""
     parser.add_option("--force", dest="force", action="store_true",
             default=False, help="Force migrations to occur even if .scripts or .git exists.")
     parser.add_option("--limit", dest="limit", type="int",
-            default=0, help="Limit the number of autoinstalls to look at.")
+            default=None, help="Limit the number of autoinstalls to look at.")
     baton.push(parser, "versions_path")
     baton.push(parser, "srv_path")
     options, args, = parser.parse_all(argv)
@@ -105,45 +106,7 @@ untrusted repositories."""
         options.no_parallelize = True
     return options, args
 
-def open_aggregate_logs(options):
-    warnings_logname = "/tmp/wizard-migrate-warnings.log"
-    errors_logname = "/tmp/wizard-migrate-errors.log"
-    if options.log_dir:
-        # must not be on AFS, since subprocesses won't be
-        # able to write to the logfiles do the to the AFS patch.
-        try:
-            os.mkdir(options.log_dir)
-        except OSError as e:
-            if e.errno != errno.EEXIST:
-                raise
-            if options.force:
-                options.log_dir = os.path.join(options.log_dir, str(int(time.time())))
-                os.mkdir(options.log_dir) # if fails, be fatal
-        os.chmod(options.log_dir, 0o777)
-        warnings_logname = os.path.join(options.log_dir, "warnings.log")
-        errors_logname = os.path.join(options.log_dir, "errors.log")
-    warnings_log = open(warnings_logname, "a")
-    errors_log = open(errors_logname, "a")
-    return warnings_log, errors_log
-
-def security_check_homedir(d):
-    uid = util.get_dir_uid(d.location)
-    real = os.path.realpath(d.location)
-    try:
-        if not real.startswith(pwd.getpwuid(uid).pw_dir + "/"):
-            logging.error("Security check failed, owner of deployment and"
-                    "owner of home directory mismatch for %s" % d.location)
-            return False
-    except KeyError:
-        logging.error("Security check failed, could not look up"
-                "owner of %s (uid %d)" % (d.location, uid))
-        return False
-    return True
-
 def calculate_base_args(options):
     return command.makeBaseArgs(options, dry_run="--dry-run", srv_path="--srv-path",
             force="--force")
 
-def calculate_log_name(i, dir):
-    return "%04d" % i + dir.replace('/', '-') + ".log"
-
index 79b2a71fa3e93b49def7d7c8d703fc07dc91005c..ad0c9f622b14382e6f7410200b2616103a67eff6 100644 (file)
@@ -245,7 +245,7 @@ class ParallelShell(Shell):
         self.running = {}
         self.max = max # maximum of commands to run in parallel
     @staticmethod
-    def make(self, no_parallelize, max):
+    def make(no_parallelize, max):
         """Convenience method oriented towards command modules."""
         if no_parallelize:
             return DummyParallelShell()
index 22cc971858c54124c38d60804db0749d4dd2b1cd..e6a4527dbfacbbe1378f4e17f6af4fdd5f43294d 100644 (file)
@@ -4,7 +4,7 @@ def make(seen_file):
     if seen_file:
         return SerializedSet(seen_file)
     else:
-        return DummySerializedSet(seen_file)
+        return DummySerializedSet()
 
 class ISerializedSet(object):
     def put(self, name):