scripts.mit.edu Git - wizard.git/blobdiff - lib/wizard/shell.py
Rewrite ParallelShell to use waitpid() instead of threads.
[wizard.git] / lib / wizard / shell.py
index 6c138ba5d76056f84e6eac591032d6db6464c863..aee8995bfbfde822c663087205d2d7a494745601 100644 (file)
@@ -1,5 +1,6 @@
 import subprocess
 import sys
+import os
 import Queue
 import threading
 
@@ -35,97 +36,77 @@ class Shell(object):
         self.logger = logger
         self.dry = dry
     def call(self, *args, **kwargs):
-        (python,) = ("python" in kwargs) and [kwargs["python"]] or [None]
+        kwargs.setdefault("python", None)
         if self.dry or self.logger:
             self.logger.info("$ " + ' '.join(args))
-        if self.dry: return
-        if python is None and is_python(args):
-            python = True
+        if self.dry:
+            return
+        if kwargs["python"] is None and is_python(args):
+            kwargs["python"] = True
         proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        if hasattr(self, "async"):
+            self.async(proc, args, **kwargs)
+            return proc
         stdout, stderr = proc.communicate()
-        if self.logger and stdout: self.logger.info(stdout)
+        self.log(stdout, stderr)
         if proc.returncode:
-            if python: eclass = PythonCallError
+            if kwargs["python"]: eclass = PythonCallError
             else: eclass = CallError
             raise eclass(proc.returncode, args, stdout, stderr)
         return (stdout, stderr)
+    def log(self, stdout, stderr):
+        if self.logger and stdout: self.logger.info(stdout)
+        if self.logger and stderr: self.logger.info("STDERR: " + stderr)
     def callAsUser(self, *args, **kwargs):
-        user = ("user" in kwargs) and kwargs["user"] or None
-        if not user: return self.call(*args)
-        return self.call("sudo", "-u", user, *args, python=is_python(args))
-
-class ShellThread(threading.Thread):
-    """Little thread that does the dispatching"""
-    def __init__(self, queue, logger=False, dry=False):
-        self.queue = queue
-        self.logger = logger
-        self.shell = Shell(logger=logger, dry=dry)
-        super(ShellThread, self).__init__()
-    def run(self):
-        while True:
-            call = self.queue.get()
-            # thread termination mechanism
-            if not call: return
-            name, args, kwargs, on_success, on_error = call
-            try:
-                result = getattr(self.shell, name)(*args, **kwargs)
-                on_success(result)
-            except CallError as e:
-                if isinstance(e, PythonCallError):
-                    # check if the subprocess got a KeyboardInterrupt
-                    if e.name == "KeyboardInterrupt":
-                        raise KeyboardInterrupt
-                on_error(e)
-            except:
-                # XXX: This is really scary
-                self.logger.error("Uncaught exception in thread")
-                raise
+        user = kwargs.pop("user", None)
+        kwargs.setdefault("python", is_python(args))
+        if not user: return self.call(*args, **kwargs)
+        return self.call("sudo", "-u", user, *args, **kwargs)
 
-class ParallelShell(object):
+class ParallelShell(Shell):
     """Commands are queued here, and executed in parallel (with
     threading) in accordance with the maximum number of allowed
     subprocesses, and result in callback execution when they finish."""
-    def __init__(self, logger = False, dry = False, max_threads = 10):
-        self.logger = logger
-        self.dry = dry
-        self.threads = []
-        # queue of call tuples (method name, args and kwargs) to
-        # be threaded
-        self.queue = Queue.Queue()
-        # build our threads
-        for n in range(max_threads):
-            self.threads.append(ShellThread(self.queue, logger=logger, dry=dry))
-    def __getattr__(self, name):
-        # override call and callAsUser (and possibly others)
-        def thunk(*args, **kwargs):
-            on_success = kwargs.pop("on_success")
-            on_error   = kwargs.pop("on_error")
-            self.queue.put((name, args, kwargs, on_success, on_error))
-        return thunk
-    def start(self):
-        for thread in self.threads:
-            thread.start()
+    def __init__(self, logger = False, dry = False, max = 10):
+        super(ParallelShell, self).__init__(logger=logger,dry=dry)
+        self.running = {}
+        self.max = max # maximum of commands to run in parallel
+    def async(self, proc, args, python, on_success, on_error):
+        """Gets handed a subprocess.Proc object from our deferred
+        execution"""
+        self.running[proc.pid] = (proc, args, python, on_success, on_error)
+    def wait(self):
+        # bail out immediately on initial ramp up
+        if len(self.running) < self.max: return
+        # now, wait for open pids.
+        try:
+            pid, status = os.waitpid(-1, 0)
+        except OSError as e:
+            return
+        # ooh, zombie process. reap it
+        proc, args, python, on_success, on_error = self.running.pop(pid)
+        # XXX: this is slightly dangerous; should actually use
+        # temporary files
+        stdout = proc.stdout.read()
+        stderr = proc.stderr.read()
+        self.log(stdout, stderr)
+        if status:
+            if python: eclass = PythonCallError
+            else: eclass = CallError
+            on_error(eclass(proc.returncode, args, stdout, stderr))
+            return
+        on_success(stdout, stderr)
     def join(self):
-        """Waits for all of our subthreads to terminate."""
-        for thread in self.threads:
-            # generate as many nops as we have threads to
-            # terminate them
-            self.queue.put(False)
-        # wait for the queue to empty
-        self.queue.join()
-        # defense in depth: make sure all the threads
-        # terminate too
-        for thread in self.threads:
-            thread.join()
-    def terminate(self):
-        # empty the queue
-        while not self.queue.empty():
-            self.queue.get_nowait()
-        self.join()
+        """Waits for all of our subprocesses to terminate."""
+        try:
+            while os.waitpid(-1, 0):
+                pass
+        except OSError as e:
+            return
 
 class DummyParallelShell(ParallelShell):
     """Same API as ParallelShell, but doesn't actually parallelize (by
     using only one thread)"""
     def __init__(self, logger = False, dry = False):
-        super(DummyParallelShell, self).__init__(logger, dry, max_threads=1)
+        super(DummyParallelShell, self).__init__(logger, dry, max=1)