scripts.mit.edu Git - wizard.git/blobdiff - lib/wizard/shell.py
Parallelize mass migration.
[wizard.git] / lib / wizard / shell.py
index 9df5583275d01b46c9d49ac6f8f310bc11780c01..6c138ba5d76056f84e6eac591032d6db6464c863 100644 (file)
@@ -1,5 +1,7 @@
 import subprocess
 import sys
+import Queue
+import threading
 
 import wizard as _wizard
 from wizard import util
@@ -52,3 +54,78 @@ class Shell(object):
         if not user: return self.call(*args)
         return self.call("sudo", "-u", user, *args, python=is_python(args))
 
+class ShellThread(threading.Thread):
+    """Worker thread that pulls queued shell calls and dispatches them.
+
+    Each item taken from ``queue`` is either a falsy sentinel, which
+    makes the thread exit, or a tuple
+    ``(name, args, kwargs, on_success, on_error)``.  ``name`` is looked
+    up as a method on this thread's own ``Shell`` instance and invoked
+    with ``args``/``kwargs``.  Both callbacks are executed on this
+    worker thread, not on the thread that queued the call.
+    """
+    def __init__(self, queue, logger=False, dry=False):
+        """Creates the worker.
+
+        queue  -- queue object this thread reads call tuples from
+        logger -- passed to Shell and used to report unexpected
+                  exceptions; a falsy value disables that report
+        dry    -- passed through to Shell unchanged
+        """
+        self.queue = queue
+        self.logger = logger
+        self.shell = Shell(logger=logger, dry=dry)
+        super(ShellThread, self).__init__()
+    def run(self):
+        """Processes queue items until a falsy sentinel is received."""
+        while True:
+            call = self.queue.get()
+            try:
+                # thread termination mechanism
+                if not call: return
+                self._dispatch(call)
+            finally:
+                # Every get() has to be matched by a task_done(),
+                # otherwise a join() on the queue never returns.  This
+                # also runs for the sentinel and when an exception is
+                # about to kill the thread.
+                self.queue.task_done()
+    def _dispatch(self, call):
+        """Runs one queued call tuple and invokes the matching callback."""
+        name, args, kwargs, on_success, on_error = call
+        try:
+            result = getattr(self.shell, name)(*args, **kwargs)
+            on_success(result)
+        except CallError as e:
+            if isinstance(e, PythonCallError):
+                # check if the subprocess got a KeyboardInterrupt
+                if e.name == "KeyboardInterrupt":
+                    raise KeyboardInterrupt
+            on_error(e)
+        except:
+            # XXX: This is really scary
+            # logger defaults to False; don't let the report itself
+            # blow up and hide the exception we are re-raising.
+            if self.logger:
+                self.logger.error("Uncaught exception in thread")
+            raise
+
+class ParallelShell(object):
+    """Commands are queued here, and executed in parallel (with
+    threading) in accordance with the maximum number of allowed
+    subprocesses, and result in callback execution when they finish.
+
+    Any attribute that is not otherwise defined on this object is
+    treated as the name of a shell method: accessing it returns a
+    function that queues the call instead of running it.  That function
+    requires the keyword arguments ``on_success`` and ``on_error``
+    (a KeyError is raised if either is missing); every other argument
+    is forwarded to the shell method when a worker runs the call.
+
+    Typical use is start(), queue calls, then join()."""
+    def __init__(self, logger = False, dry = False, max_threads = 10):
+        """Builds (but does not start) ``max_threads`` worker threads
+        that all read from one shared queue."""
+        self.logger = logger
+        self.dry = dry
+        # queue of call tuples (method name, args, kwargs and the two
+        # callbacks) to be threaded
+        self.queue = Queue.Queue()
+        # build our threads
+        self.threads = [ShellThread(self.queue, logger=logger, dry=dry)
+                        for _ in range(max_threads)]
+    def __getattr__(self, name):
+        # override call and callAsUser (and possibly others)
+        def thunk(*args, **kwargs):
+            on_success = kwargs.pop("on_success")
+            on_error   = kwargs.pop("on_error")
+            self.queue.put((name, args, kwargs, on_success, on_error))
+        return thunk
+    def start(self):
+        """Starts all worker threads."""
+        for thread in self.threads:
+            thread.start()
+    def join(self):
+        """Waits for all of our subthreads to terminate."""
+        for thread in self.threads:
+            # generate as many nops as we have threads to
+            # terminate them
+            self.queue.put(False)
+        # We deliberately do not wait on self.queue.join() here: it only
+        # returns once task_done() has been called for every item ever
+        # put on the queue, so it hangs forever if any item is dequeued
+        # without being acknowledged or a worker dies.  Joining the
+        # threads is sufficient: the queue is FIFO and each worker
+        # exits on exactly one nop, so all work queued before this
+        # call has been handed out by the time the last worker stops.
+        for thread in self.threads:
+            thread.join()
+    def terminate(self):
+        """Discards calls that no worker has picked up yet, then shuts
+        the workers down as join() does.  Calls that are already
+        running are allowed to finish."""
+        # empty the queue; get_nowait() + Queue.Empty instead of an
+        # empty() check, because a worker may grab the last item
+        # between the check and the get
+        while True:
+            try:
+                self.queue.get_nowait()
+            except Queue.Empty:
+                break
+            # keep the queue's unfinished-task count consistent
+            self.queue.task_done()
+        self.join()
+
+class DummyParallelShell(ParallelShell):
+    """Drop-in stand-in for ParallelShell with the same API; queued
+    calls are not run in parallel because the pool is limited to a
+    single worker thread."""
+    def __init__(self, logger = False, dry = False):
+        # one worker only, so calls are processed one at a time
+        single_worker = 1
+        super(DummyParallelShell, self).__init__(
+            logger=logger, dry=dry, max_threads=single_worker)
+