Rewrite ParallelShell to use waitpid() instead of threads.
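
Instead of farming each call out to a pool of ShellThread workers that
block in communicate(), Shell.call() now hands its Popen object to an
async() hook when the subclass defines one, and ParallelShell reaps
finished children directly with os.waitpid(-1, 0), firing the
on_success or on_error callback as each child exits.
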
diff --git a/lib/wizard/shell.py b/lib/wizard/shell.py
index 6c138ba5d76056f84e6eac591032d6db6464c863..aee8995bfbfde822c663087205d2d7a494745601 100644
--- a/lib/wizard/shell.py
+++ b/lib/wizard/shell.py
@@ -1,5 +1,6 @@
 import subprocess
 import sys
+import os
 import Queue
 import threading
 
@@ -35,97 +36,77 @@ class Shell(object):
         self.logger = logger
         self.dry = dry
     def call(self, *args, **kwargs):
-        (python,) = ("python" in kwargs) and [kwargs["python"]] or [None]
+        kwargs.setdefault("python", None)
         if self.dry or self.logger:
             self.logger.info("$ " + ' '.join(args))
-        if self.dry: return
-        if python is None and is_python(args):
-            python = True
+        if self.dry:
+            return
+        if kwargs["python"] is None and is_python(args):
+            kwargs["python"] = True
         proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        if hasattr(self, "async"):
+            self.async(proc, args, **kwargs)
+            return proc
         stdout, stderr = proc.communicate()
-        if self.logger and stdout: self.logger.info(stdout)
+        self.log(stdout, stderr)
         if proc.returncode:
-            if python: eclass = PythonCallError
+            if kwargs["python"]: eclass = PythonCallError
             else: eclass = CallError
             raise eclass(proc.returncode, args, stdout, stderr)
         return (stdout, stderr)
+    def log(self, stdout, stderr):
+        if self.logger and stdout: self.logger.info(stdout)
+        if self.logger and stderr: self.logger.info("STDERR: " + stderr)
     def callAsUser(self, *args, **kwargs):
-        user = ("user" in kwargs) and kwargs["user"] or None
-        if not user: return self.call(*args)
-        return self.call("sudo", "-u", user, *args, python=is_python(args))
-
-class ShellThread(threading.Thread):
-    """Little thread that does the dispatching"""
-    def __init__(self, queue, logger=False, dry=False):
-        self.queue = queue
-        self.logger = logger
-        self.shell = Shell(logger=logger, dry=dry)
-        super(ShellThread, self).__init__()
-    def run(self):
-        while True:
-            call = self.queue.get()
-            # thread termination mechanism
-            if not call: return
-            name, args, kwargs, on_success, on_error = call
-            try:
-                result = getattr(self.shell, name)(*args, **kwargs)
-                on_success(result)
-            except CallError as e:
-                if isinstance(e, PythonCallError):
-                    # check if the subprocess got a KeyboardInterrupt
-                    if e.name == "KeyboardInterrupt":
-                        raise KeyboardInterrupt
-                on_error(e)
-            except:
-                # XXX: This is really scary
-                self.logger.error("Uncaught exception in thread")
-                raise
+        user = kwargs.pop("user", None)
+        kwargs.setdefault("python", is_python(args))
+        if not user: return self.call(*args, **kwargs)
+        return self.call("sudo", "-u", user, *args, **kwargs)
 
-class ParallelShell(object):
+class ParallelShell(Shell):
     """Commands are queued here, and executed in parallel (with
     threading) in accordance with the maximum number of allowed
     subprocesses, and result in callback execution when they finish."""
-    def __init__(self, logger = False, dry = False, max_threads = 10):
-        self.logger = logger
-        self.dry = dry
-        self.threads = []
-        # queue of call tuples (method name, args and kwargs) to
-        # be threaded
-        self.queue = Queue.Queue()
-        # build our threads
-        for n in range(max_threads):
-            self.threads.append(ShellThread(self.queue, logger=logger, dry=dry))
-    def __getattr__(self, name):
-        # override call and callAsUser (and possibly others)
-        def thunk(*args, **kwargs):
-            on_success = kwargs.pop("on_success")
-            on_error   = kwargs.pop("on_error")
-            self.queue.put((name, args, kwargs, on_success, on_error))
-        return thunk
-    def start(self):
-        for thread in self.threads:
-            thread.start()
+    def __init__(self, logger = False, dry = False, max = 10):
+        super(ParallelShell, self).__init__(logger=logger, dry=dry)
+        self.running = {}
+        self.max = max # maximum number of commands to run in parallel
+    def async(self, proc, args, python, on_success, on_error):
+        """Called by Shell.call() with a freshly spawned subprocess.Popen
+        object; registers it to be reaped later by wait()."""
+        self.running[proc.pid] = (proc, args, python, on_success, on_error)
+    def wait(self):
+        # bail out immediately on initial ramp up
+        if len(self.running) < self.max: return
+        # now, wait for open pids.
+        try:
+            pid, status = os.waitpid(-1, 0)
+        except OSError: # no children left to reap (ECHILD)
+            return
+        # ooh, zombie process. reap it
+        proc, args, python, on_success, on_error = self.running.pop(pid)
+        # XXX: slightly dangerous; a chatty child can fill the pipe
+        # buffer and block forever. Should actually use temporary files.
+        stdout = proc.stdout.read()
+        stderr = proc.stderr.read()
+        self.log(stdout, stderr)
+        if status:
+            if python: eclass = PythonCallError
+            else: eclass = CallError
+            on_error(eclass(os.WEXITSTATUS(status), args, stdout, stderr))
+            return
+        on_success(stdout, stderr)
     def join(self):
-        """Waits for all of our subthreads to terminate."""
-        for thread in self.threads:
-            # generate as many nops as we have threads to
-            # terminate them
-            self.queue.put(False)
-        # wait for the queue to empty
-        self.queue.join()
-        # defense in depth: make sure all the threads
-        # terminate too
-        for thread in self.threads:
-            thread.join()
-    def terminate(self):
-        # empty the queue
-        while not self.queue.empty():
-            self.queue.get_nowait()
-        self.join()
+        """Waits for all of our subprocesses to terminate."""
+        try:
+            while os.waitpid(-1, 0):
+                pass
+        except OSError: # every child has been reaped (ECHILD)
+            return
 
 class DummyParallelShell(ParallelShell):
     """Same API as ParallelShell, but doesn't actually parallelize (by
     using only one thread)"""
     def __init__(self, logger = False, dry = False):
-        super(DummyParallelShell, self).__init__(logger, dry, max_threads=1)
+        super(DummyParallelShell, self).__init__(logger, dry, max=1)
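
Below is a minimal usage sketch of the new API (Python 2, matching the
module's Queue import). The checkout names and the git invocation are
hypothetical, chosen only to illustrate the callback-driven flow:
call() defers to async() because ParallelShell defines it, and wait()
returns immediately until max children are in flight, after which it
blocks and fires one callback per reaped child.

    from wizard.shell import ParallelShell

    def on_success(stdout, stderr):
        print "fetched:", stdout.strip()

    def on_error(e):
        print "fetch failed:", e

    sh = ParallelShell(max=2)
    for d in ("app1", "app2", "app3", "app4"):  # hypothetical checkouts
        sh.call("git", "--git-dir=%s/.git" % d, "fetch",
                on_success=on_success, on_error=on_error)
        sh.wait()
    sh.join()

One caveat with this revision: join() simply calls waitpid() until it
raises ECHILD, so any children still running when the loop ends are
reaped without their callbacks firing, and their stdout/stderr is
dropped; a caller that needs every result has to drain self.running
itself.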