]> scripts.mit.edu Git - wizard.git/blobdiff - wizard/shell.py
Fix kernel buffer overflow by avoiding passing --debug to subprocesses.
[wizard.git] / wizard / shell.py
index 6c08dc35c5082681812a477544256a0f98fdd020..2197cdecc3cf219beebe5d5afe10744917682869 100644 (file)
+"""
+Wrappers around subprocess functionality that simulate an actual shell.
+
+.. testsetup:: *
+
+    from wizard.shell import *
+"""
+
 import subprocess
+import logging
 import sys
 import os
-import Queue
-import threading
+import errno
 
-import wizard as _wizard
+import wizard
 from wizard import util
 
-wizard = sys.argv[0]
-
-class CallError(_wizard.Error):
-    def __init__(self, code, args, stdout, stderr):
-        self.code = code
-        self.args = args
-        self.stdout = stdout
-        self.stderr = stderr
-    def __str__(self):
-        return "CallError [%d]" % self.code
-
-class PythonCallError(CallError):
-    def __init__(self, code, args, stdout, stderr):
-        self.name = util.get_exception_name(stderr)
-        CallError.__init__(self, code, args, stdout, stderr)
-    def __str__(self):
-        return "PythonCallError [%s]" % self.name
+wizard_bin = sys.argv[0]
+"""
+This is the path to the wizard executable as specified
+by the caller; it lets us recursively invoke wizard.
+"""
 
 def is_python(args):
-    return args[0] == "python" or args[0] == wizard
+    """Detects whether or not an argument list invokes a Python program."""
+    return args[0] == "python" or args[0] == "wizard"
+
+def drop_priviledges(dir, log_file):
+    """
+    Checks if we are running as root.  If we are, attempt to drop
+    priviledges to the user who owns ``dir``, by re-calling
+    itself using sudo with exec, such that the new process subsumes our
+    current one.  If ``log_file`` is passed, the file is chown'ed
+    to the user we are dropping priviledges to, so the subprocess
+    can write to it.
+    """
+    if os.getuid():
+        return
+    uid = util.get_dir_uid(dir)
+    if not uid:
+        return
+    args = []
+    for k,v in os.environ.items():
+        if k.startswith('WIZARD_') or k == "SSH_GSSAPI_NAME":
+            args.append("%s=%s" % (k,v))
+    args += sys.argv
+    logging.debug("Dropping priviledges")
+    if log_file: os.chown(log_file, uid, -1)
+    os.execlp('sudo', 'sudo', '-u', '#' + str(uid), *args)
 
 class Shell(object):
-    """An advanced shell, with the ability to do dry-run and log commands"""
-    def __init__(self, logger = False, dry = False):
-        """ `logger`    The logger
-            `dry`       Don't run any commands, just print them"""
-        self.logger = logger
+    """
+    An advanced shell that performs logging.  If ``dry`` is ``True``,
+    no commands are actually run.
+    """
+    def __init__(self, dry = False):
         self.dry = dry
+        self.cwd = None
     def call(self, *args, **kwargs):
+        """
+        Performs a system call.  The actual executable and options should
+        be passed as arguments to this function.  It will magically
+        ensure that 'wizard' as a command works. Several keyword arguments
+        are also supported:
+
+        :param python: explicitly marks the subprocess as Python or not Python
+            for improved error reporting.  By default, we use
+            :func:`is_python` to autodetect this.
+        :param input: input to feed the subprocess on standard input.
+        :param interactive: whether or not directly hook up all pipes
+            to the controlling terminal, to allow interaction with subprocess.
+        :param strip: if ``True``, instead of returning a tuple,
+            return the string stdout output of the command with trailing newlines
+            removed.  This emulates the behavior of backticks and ``$()`` in Bash.
+            Prefer to use :meth:`eval` instead (you should only need to explicitly
+            specify this if you are using another wrapper around this function).
+        :param log: if True, we log the call as INFO, if False, we log the call
+            as DEBUG, otherwise, we detect based on ``strip``.
+        :param stdout:
+        :param stderr:
+        :param stdin: a file-type object that will be written to or read from as a pipe.
+        :returns: a tuple of strings ``(stdout, stderr)``, or a string ``stdout``
+            if ``strip`` is specified.
+
+        >>> sh = Shell()
+        >>> sh.call("echo", "Foobar")
+        ('Foobar\\n', '')
+        >>> sh.call("cat", input='Foobar')
+        ('Foobar', '')
+        """
+        self._wait()
+        kwargs.setdefault("interactive", False)
+        kwargs.setdefault("strip", False)
         kwargs.setdefault("python", None)
-        if self.dry or self.logger:
-            self.logger.info("$ " + ' '.join(args))
+        kwargs.setdefault("log", None)
+        kwargs.setdefault("stdout", subprocess.PIPE)
+        kwargs.setdefault("stdin", subprocess.PIPE)
+        kwargs.setdefault("stderr", subprocess.PIPE)
+        msg = "Running `" + ' '.join(args) + "`"
+        if kwargs["strip"] and not kwargs["log"] is True or kwargs["log"] is False:
+            logging.debug(msg)
+        else:
+            logging.info(msg)
         if self.dry:
-            return
+            if kwargs["strip"]:
+                return ''
+            return None, None
         if kwargs["python"] is None and is_python(args):
             kwargs["python"] = True
-        proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-        if hasattr(self, "async"):
-            self.async(proc, args, **kwargs)
+        if args[0] == "wizard":
+            args = list(args)
+            args[0] = wizard_bin
+        kwargs.setdefault("input", None)
+        if kwargs["interactive"]:
+            stdout=sys.stdout
+            stdin=sys.stdin
+            stderr=sys.stderr
+        else:
+            stdout=kwargs["stdout"]
+            stdin=kwargs["stdin"]
+            stderr=kwargs["stderr"]
+        # XXX: There is a possible problem here where we can fill up
+        # the kernel buffer if we have 64KB of data.  This shouldn't
+        # normally be a problem, and the fix for such case would be to write to
+        # temporary files instead of a pipe.
+        #
+        # However, it *is* a problem when you do something silly, like
+        # pass --debug to mass-upgrade.
+        #
+        # Another possible way of fixing this is converting from a
+        # waitpid() pump to a select() pump, creating a pipe to
+        # ourself, and then setting up a SIGCHILD handler to write a single
+        # byte to the pipe to get us out of select() when a subprocess exits.
+        proc = subprocess.Popen(args, stdout=stdout, stderr=stderr, stdin=stdin, cwd=self.cwd, )
+        if self._async(proc, args, **kwargs):
             return proc
-        stdout, stderr = proc.communicate()
-        self.log(stdout, stderr)
+        stdout, stderr = proc.communicate(kwargs["input"])
+        # can occur if we were doing interactive communication; i.e.
+        # we didn't pass in PIPE.
+        if stdout is None:
+            stdout = ""
+        if stderr is None:
+            stderr = ""
+        if not kwargs["interactive"]:
+            if kwargs["strip"]:
+                self._log(None, stderr)
+            else:
+                self._log(stdout, stderr)
         if proc.returncode:
             if kwargs["python"]: eclass = PythonCallError
             else: eclass = CallError
             raise eclass(proc.returncode, args, stdout, stderr)
+        if kwargs["strip"]:
+            return str(stdout).rstrip("\n")
         return (stdout, stderr)
-    def log(self, stdout, stderr):
-        if self.logger and stdout: self.logger.info(stdout)
-        if self.logger and stderr: self.logger.info("STDERR: " + stderr)
+    def _log(self, stdout, stderr):
+        """Logs the standard output and standard input from a command."""
+        if stdout:
+            logging.debug("STDOUT:\n" + stdout)
+        if stderr:
+            logging.debug("STDERR:\n" + stderr)
+    def _wait(self):
+        pass
+    def _async(self, *args, **kwargs):
+        return False
     def callAsUser(self, *args, **kwargs):
+        """
+        Performs a system call as a different user.  This is only possible
+        if you are running as root.  Keyword arguments
+        are the same as :meth:`call` with the following additions:
+
+        :param user: name of the user to run command as.
+        :param uid: uid of the user to run command as.
+
+        .. note::
+
+            The resulting system call internally uses :command:`sudo`,
+            and as such environment variables will get scrubbed.  We
+            manually preserve :envvar:`SSH_GSSAPI_NAME`.
+        """
         user = kwargs.pop("user", None)
+        uid = kwargs.pop("uid", None)
         kwargs.setdefault("python", is_python(args))
-        if not user: return self.call(*args, **kwargs)
-        return self.call("sudo", "-u", user, *args, **kwargs)
+        if not user and not uid: return self.call(*args, **kwargs)
+        if os.getenv("SSH_GSSAPI_NAME"):
+            # This might be generalized as "preserve some environment"
+            args = list(args)
+            args.insert(0, "SSH_GSSAPI_NAME=" + os.getenv("SSH_GSSAPI_NAME"))
+        if uid: return self.call("sudo", "-u", "#" + str(uid), *args, **kwargs)
+        if user: return self.call("sudo", "-u", user, *args, **kwargs)
+    def safeCall(self, *args, **kwargs):
+        """
+        Checks if the owner of the current working directory is the same
+        as the current user, and if it isn't, attempts to sudo to be
+        that user.  The intended use case is for calling Git commands
+        when running as root, but this method should be used when
+        interfacing with any moderately complex program that depends
+        on working directory context.  Keyword arguments are the
+        same as :meth:`call`.
+        """
+        if os.getuid():
+            return self.call(*args, **kwargs)
+        uid = os.stat(os.getcwd()).st_uid
+        # consider also checking ruid?
+        if uid != os.geteuid():
+            kwargs['uid'] = uid
+            return self.callAsUser(*args, **kwargs)
+        else:
+            return self.call(*args, **kwargs)
+    def eval(self, *args, **kwargs):
+        """
+        Evaluates a command and returns its output, with trailing newlines
+        stripped (like backticks in Bash).  This is a convenience method for
+        calling :meth:`call` with ``strip``.
+
+            >>> sh = Shell()
+            >>> sh.eval("echo", "Foobar") 
+            'Foobar'
+        """
+        kwargs["strip"] = True
+        return self.call(*args, **kwargs)
+    def setcwd(self, cwd):
+        """
+        Sets the directory processes are executed in. This sets a value
+        to be passed as the ``cwd`` argument to ``subprocess.Popen``.
+        """
+        self.cwd = cwd
 
 class ParallelShell(Shell):
-    """Commands are queued here, and executed in parallel (with
-    threading) in accordance with the maximum number of allowed
-    subprocesses, and result in callback execution when they finish."""
-    def __init__(self, logger = False, dry = False, max = 10):
-        super(ParallelShell, self).__init__(logger=logger,dry=dry)
+    """
+    Modifies the semantics of :class:`Shell` so that
+    commands are queued here, and executed in parallel using waitpid
+    with ``max`` subprocesses, and result in callback execution
+    when they finish.
+
+    .. method:: call(*args, **kwargs)
+
+        Enqueues a system call for parallel processing.  If there are
+        no openings in the queue, this will block.  Keyword arguments
+        are the same as :meth:`Shell.call` with the following additions:
+
+        :param on_success: Callback function for success (zero exit status).
+            The callback function should accept two arguments,
+            ``stdout`` and ``stderr``.
+        :param on_error: Callback function for failure (nonzero exit status).
+            The callback function should accept one argument, the
+            exception that would have been thrown by the synchronous
+            version.
+        :return: The :class:`subprocess.Proc` object that was opened.
+
+    .. method:: callAsUser(*args, **kwargs)
+
+        Enqueues a system call under a different user for parallel
+        processing.  Keyword arguments are the same as
+        :meth:`Shell.callAsUser` with the additions of keyword
+        arguments from :meth:`call`.
+
+    .. method:: safeCall(*args, **kwargs)
+
+        Enqueues a "safe" call for parallel processing.  Keyword
+        arguments are the same as :meth:`Shell.safeCall` with the
+        additions of keyword arguments from :meth:`call`.
+
+    .. method:: eval(*args, **kwargs)
+
+        No difference from :meth:`call`.  Consider having a
+        non-parallel shell if the program you are shelling out
+        to is fast.
+
+    """
+    def __init__(self, dry = False, max = 10):
+        super(ParallelShell, self).__init__(dry=dry)
         self.running = {}
         self.max = max # maximum of commands to run in parallel
-    def async(self, proc, args, python, on_success, on_error):
-        """Gets handed a subprocess.Proc object from our deferred
-        execution"""
+    @staticmethod
+    def make(no_parallelize, max):
+        """Convenience method oriented towards command modules."""
+        if no_parallelize:
+            return DummyParallelShell()
+        else:
+            return ParallelShell(max=max)
+    def _async(self, proc, args, python, on_success, on_error, **kwargs):
+        """
+        Gets handed a :class:`subprocess.Proc` object from our deferred
+        execution.  See :meth:`Shell.call` source code for details.
+        """
         self.running[proc.pid] = (proc, args, python, on_success, on_error)
-    def wait(self):
+        return True # so that the parent function returns
+    def _wait(self):
+        """
+        Blocking call that waits for an open subprocess slot.  This is
+        automatically called by :meth:`Shell.call`.
+        """
+        # XXX: This API sucks; the actual call/callAsUser call should
+        # probably block automatically (unless I have a good reason not to)
         # bail out immediately on initial ramp up
         if len(self.running) < self.max: return
         # now, wait for open pids.
         try:
-            pid, status = os.waitpid(-1, 0)
+            self.reap(*os.waitpid(-1, 0))
+        except OSError as e:
+            if e.errno == errno.ECHILD: return
+            raise
+    def join(self):
+        """Waits for all of our subprocesses to terminate."""
+        try:
+            while True:
+                self.reap(*os.waitpid(-1, 0))
         except OSError as e:
             if e.errno == errno.ECHILD: return
-            raise e
+            raise
+    def reap(self, pid, status):
+        """Reaps a process."""
         # ooh, zombie process. reap it
         proc, args, python, on_success, on_error = self.running.pop(pid)
         # XXX: this is slightly dangerous; should actually use
         # temporary files
         stdout = proc.stdout.read()
         stderr = proc.stderr.read()
-        self.log(stdout, stderr)
+        self._log(stdout, stderr)
         if status:
             if python: eclass = PythonCallError
             else: eclass = CallError
             on_error(eclass(proc.returncode, args, stdout, stderr))
             return
         on_success(stdout, stderr)
-    def join(self):
-        """Waits for all of our subprocesses to terminate."""
-        try:
-            while os.waitpid(-1, 0):
-                pass
-        except OSError as e:
-            if e.errno == errno.ECHILD: return
-            raise e
+
+# Setup a convenience global instance
+shell = Shell()
+call = shell.call
+callAsUser = shell.callAsUser
+safeCall = shell.safeCall
+eval = shell.eval
 
 class DummyParallelShell(ParallelShell):
-    """Same API as ParallelShell, but doesn't actually parallelize (by
-    using only one thread)"""
-    def __init__(self, logger = False, dry = False):
-        super(DummyParallelShell, self).__init__(logger, dry, max=1)
+    """Same API as :class:`ParallelShell`, but doesn't actually
+    parallelize (i.e. all calls to :meth:`wait` block.)"""
+    def __init__(self, dry = False):
+        super(DummyParallelShell, self).__init__(dry=dry, max=1)
+
+class Error(wizard.Error):
+    """Base exception for this module"""
+    pass
+
+class CallError(Error):
+    """Indicates that a subprocess call returned a nonzero exit status."""
+    #: The exit code of the failed subprocess.
+    code = None
+    #: List of the program and arguments that failed.
+    args = None
+    #: The stdout of the program.
+    stdout = None
+    #: The stderr of the program.
+    stderr = None
+    def __init__(self, code, args, stdout, stderr):
+        self.code = code
+        self.args = args
+        self.stdout = stdout
+        self.stderr = stderr
+    def __str__(self):
+        compact = self.stderr.rstrip().split("\n")[-1]
+        return "%s (exited with %d)\n%s" % (compact, self.code, self.stderr)
+
+class PythonCallError(CallError):
+    """
+    Indicates that a Python subprocess call had an uncaught exception.
+    This exception also contains the attributes of :class:`CallError`.
+    """
+    #: Name of the uncaught exception.
+    name = None
+    def __init__(self, code, args, stdout, stderr):
+        if stderr: self.name = util.get_exception_name(stderr)
+        CallError.__init__(self, code, args, stdout, stderr)
+    def __str__(self):
+        if self.name:
+            return "PythonCallError [%s]\n%s" % (self.name, self.stderr)
+        else:
+            return "PythonCallError\n%s" % self.stderr