scripts.mit.edu Git - wizard.git/blobdiff - wizard/util.py
Rewrite parametrize to use new parametrizeWithVars
[wizard.git] / wizard / util.py
index 2588d173bba79bd185ee82f18ef62fbeb95b2a1d..6c9764d317d88d4dd8f6f05b4546d91778665ec4 100644 (file)
+"""
+Miscellaneous utility functions and classes.
+
+.. testsetup:: *
+
+    from wizard.util import *
+"""
+
 import os.path
 import os
 import subprocess
 import pwd
 import sys
+import socket
+import errno
+import itertools
+import signal
+import httplib
+import urllib
+import time
+import logging
+import random
+import string
 
 import wizard
 
 class ChangeDirectory(object):
-    """Context for temporarily changing directory"""
+    """
+    Context for temporarily changing the working directory.
+
+        >>> with ChangeDirectory("/tmp"):
+        ...    print os.getcwd()
+        /tmp
+    """
     def __init__(self, dir):
         self.dir = dir
         self.olddir = None
     def __enter__(self):
         self.olddir = os.getcwd()
-        os.chdir(self.dir)
+        chdir(self.dir)
     def __exit__(self, *args):
-        os.chdir(self.olddir)
+        chdir(self.olddir)
 
 class Counter(object):
+    """
+    Object for counting different values when you don't know what
+    they are a priori.  Supports index access and iteration.
+
+        >>> counter = Counter()
+        >>> counter.count("foo")
+        >>> print counter["foo"]
+        1
+    """
     def __init__(self):
         self.dict = {}
     def count(self, value):
+        """Increments count for ``value``."""
         self.dict.setdefault(value, 0)
         self.dict[value] += 1
     def __getitem__(self, key):
         return self.dict[key]
     def __iter__(self):
         return self.dict.__iter__()
+    def max(self):
+        """Returns the max counter value seen."""
+        return max(self.dict.values())
+    def sum(self):
+        """Returns the sum of all counter values."""
+        return sum(self.dict.values())
+    def keys(self):
+        """Returns the keys of counters."""
+        return self.dict.keys()
+
+class PipeToLess(object):
+    """
+    Context for printing output to a pager.  Use this if output
+    is expected to be long.
+    """
+    def __enter__(self):
+        self.proc = subprocess.Popen("less", stdin=subprocess.PIPE)
+        self.old_stdout = sys.stdout
+        sys.stdout = self.proc.stdin
+    def __exit__(self, *args):
+        if self.proc:
+            self.proc.stdin.close()
+            self.proc.wait()
+            sys.stdout = self.old_stdout
+
+class IgnoreKeyboardInterrupts(object):
+    """
+    Context for temporarily ignoring keyboard interrupts.  Use this
+    if aborting would cause more harm than finishing the job.
+    """
+    def __enter__(self):
+        signal.signal(signal.SIGINT,signal.SIG_IGN)
+    def __exit__(self, *args):
+        signal.signal(signal.SIGINT, signal.default_int_handler)
+
+class LockDirectory(object):
+    """
+    Context for locking a directory.
+    """
+    def __init__(self, lockfile, expiry = 3600):
+        self.lockfile = lockfile
+        self.expiry = expiry # by default an hour
+    def __enter__(self):
+        # It's A WAVY
+        for i in range(0, 3):
+            try:
+                os.open(self.lockfile, os.O_CREAT | os.O_EXCL)
+                open(self.lockfile, "w").write("%d" % os.getpid())
+            except OSError as e:
+                if e.errno == errno.EEXIST:
+                    # There is a possibility of infinite recursion, but we
+                    # expect it to be unlikely, and not harmful if it does happen
+                    with LockDirectory(self.lockfile + "_"):
+                        # See if we can break the lock
+                        try:
+                            pid = open(self.lockfile, "r").read().strip()
+                            if not os.path.exists("/proc/%s" % pid):
+                                # break the lock, try again
+                                logging.warning("Breaking orphaned lock at %s", self.lockfile)
+                                os.unlink(self.lockfile)
+                                continue
+                            try:
+                                # if the lock file is older than self.expiry seconds, break the lock and try again
+                                if time.time() - os.stat(self.lockfile).st_mtime > self.expiry:
+                                    logging.warning("Breaking stale lock at %s", self.lockfile)
+                                    os.unlink(self.lockfile)
+                                    continue
+                            except OSError as e:
+                                if e.errno == errno.ENOENT:
+                                    continue
+                                raise
+                        except IOError:
+                            # oh hey, it went away; try again
+                            continue
+                    raise DirectoryLockedError(os.getcwd())
+                elif e.errno == errno.EACCES:
+                    raise PermissionsError(os.getcwd())
+                raise
+            return
+        raise DirectoryLockedError(os.getcwd())
+    def __exit__(self, *args):
+        try:
+            os.unlink(self.lockfile)
+        except OSError:
+            pass
+
+def chdir(dir):
+    """
+    Changes a directory, but has special exceptions for certain
+    classes of errors.
+    """
+    try:
+        os.chdir(dir)
+    except OSError as e:
+        if e.errno == errno.EACCES:
+            raise PermissionsError()
+        elif e.errno == errno.ENOENT:
+            raise NoSuchDirectoryError()
+        else: raise e
 
 def dictmap(f, d):
-    """A map function for dictionaries.  Does not allow changing keys, only
-    values"""
+    """
+    A map function for dictionaries.  Only changes values.
+
+        >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2})
+        {'a': 3, 'b': 4}
+    """
     return dict((k,f(v)) for k,v in d.items())
 
+def dictkmap(f, d):
+    """
+    A map function for dictionaries that passes key and value.
+
+        >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4})
+        {1: 5, 3: 7}
+    """
+    return dict((k,f(k,v)) for k,v in d.items())
+
 def get_exception_name(output):
-    """Reads the stderr output of another Python command and grabs the
-    fully qualified exception name"""
+    """
+    Reads the traceback from a Python program and grabs the
+    fully qualified exception name.
+    """
     lines = output.split("\n")
-    for line in lines[1:]: # skip the "traceback" line
+    cue = False
+    result = "(unknown)"
+    for line in lines[1:]:
         line = line.rstrip()
-        if line[0] == ' ': continue
-        if line[-1] == ":":
-            return line[:-1]
-        else:
-            return line
+        if not line: continue
+        if line[0] == ' ':
+            cue = True
+            continue
+        if cue:
+            cue = False
+            return line.partition(':')[0]
+    return result
 
 def get_dir_uid(dir):
     """Finds the uid of the person who owns this directory."""
     return os.stat(dir).st_uid
 
 def get_dir_owner(dir = "."):
-    """Finds the name of the locker this directory is in."""
-    return pwd.getpwuid(get_dir_uid(dir)).pw_name
+    """
+    Finds the name of the locker this directory is in.
+
+    .. note::
+
+        This function uses the passwd database and thus
+        only works on scripts servers when querying directories
+        that live on AFS.
+    """
+    uid = get_dir_uid(dir)
+    try:
+        pwentry = pwd.getpwuid(uid)
+        return pwentry.pw_name
+    except KeyError:
+        # uid is not in the passwd database; do a pts query to get the name
+        return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2]
 
 def get_revision():
     """Returns the commit ID of the current Wizard install."""
+    # If you decide to convert this to use wizard.shell, be warned
+    # that there is a circular dependency, so this function would
+    # probably have to live somewhere else, probably wizard.git
     wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git")
     return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip()
 
 def get_operator_info():
-    """Returns tuple of (realname, email) of person who is operating
-    this script, as told to use by the Kerberos principal name.
-    Useful for commit messages."""
-    username = get_operator_name()
-    hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
-    fields = hesinfo.partition(",")[0]
-    realname = fields.rpartition(":")[2]
-    return realname, username + "@mit.edu"
+    """
+    Returns tuple of ``(realname, email)`` about the person running
+    the script.  If run from a scripts server, get info from Hesiod.
+    Otherwise, use the passwd database (email generated probably won't
+    actually accept mail).  Useful when generating commit messages.
+    """
+    username = get_operator_name_from_gssapi()
+    if username:
+        # scripts approach
+        hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
+        fields = hesinfo.partition(",")[0]
+        realname = fields.rpartition(":")[2]
+        return realname, username + "@mit.edu"
+    else:
+        # more traditional approach, but the email probably doesn't work
+        uid = os.getuid()
+        if not uid:
+            # root isn't actually a useful designation, but maybe
+            # SUDO_USER contains something helpful
+            sudo_user = os.getenv("SUDO_USER")
+            if not sudo_user:
+                raise NoOperatorInfo
+            pwdentry = pwd.getpwnam(sudo_user)
+        else:
+            pwdentry = pwd.getpwuid(uid)
+        # XXX: error checking might be nice
+        # We follow the Ubuntu convention of gecos being a comma split field
+        # with the person's realname being the first entry.
+        return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname()
 
 def get_operator_git():
-    """Returns Real Name <username@mit.edu> suitable for use in
-    Git Something-by: string."""
+    """
+    Returns ``Real Name <username@mit.edu>`` suitable for use in
+    Git ``Something-by:`` string.
+    """
     return "%s <%s>" % get_operator_info()
 
-def get_operator_name():
-    """Returns username of the person operating this script."""
+def get_operator_name_from_gssapi():
+    """
+    Returns username of the person operating this script based
+    off of the :envvar:`SSH_GSSAPI_NAME` environment variable.
+
+    .. note::
+
+        :envvar:`SSH_GSSAPI_NAME` is not set by vanilla OpenSSH
+        distributions.  Scripts servers are patched to support this
+        environment variable.
+    """
     principal = os.getenv("SSH_GSSAPI_NAME")
-    if not principal: raise NoOperatorInfo()
+    if not principal:
+        return None
     instance, _, _ = principal.partition("@")
     if instance.endswith("/root"):
         username, _, _ = principal.partition("/")
@@ -85,6 +288,11 @@ def get_operator_name():
     return username
 
 def set_operator_env():
+    """
+    Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL`
+    environment variables if applicable.  Does nothing if
+    :func:`get_operator_info` throws :exc:`NoOperatorInfo`.
+    """
     try:
         op_realname, op_email = get_operator_info()
         os.putenv("GIT_COMMITTER_NAME", op_realname)
@@ -93,25 +301,111 @@ def set_operator_env():
         pass
 
 def set_author_env():
+    """
+    Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment
+    variables if applicable. Does nothing if :func:`get_dir_owner` fails.
+    """
     try:
+        # XXX: should check if the directory is in AFS, and if not, use
+        # a more traditional metric
         lockername = get_dir_owner()
         os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername)
         os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername)
-    except KeyError:
+    except KeyError: # XXX: This doesn't actually make sense
         pass
 
 def set_git_env():
+    """Sets all appropriate environment variables for Git commits."""
     set_operator_env()
     set_author_env()
 
 def get_git_footer():
+    """Returns a string with info about Wizard for placing in the Git log."""
     return "\n".join(["Wizard-revision: %s" % get_revision()
         ,"Wizard-args: %s" % " ".join(sys.argv)
         ])
 
-class Error(wizard.Error):
-    pass
+def safe_unlink(file):
+    """Moves a file/dir to a backup location."""
+    if not os.path.exists(file):
+        return None
+    prefix = "%s.bak" % file
+    name = None
+    for i in itertools.count():
+        name = "%s.%d" % (prefix, i)
+        if not os.path.exists(name):
+            break
+    os.rename(file, name)
+    return name
+
+def soft_unlink(file):
+    """Unlink a file, but don't complain if it doesn't exist."""
+    try:
+        os.unlink(file)
+    except OSError:
+        pass
 
-class NoOperatorInfo(Error):
+def fetch(host, path, subpath, post=None):
+    try:
+        # XXX: Special case if it's https; not sure why this data isn't
+        # passed
+        h = httplib.HTTPConnection(host)
+        fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept
+        if post:
+            headers = {"Content-type": "application/x-www-form-urlencoded"}
+            h.request("POST", fullpath, urllib.urlencode(post), headers)
+        else:
+            h.request("GET", fullpath)
+        r = h.getresponse()
+        data = r.read()
+        h.close()
+        return data
+    except socket.gaierror as e:
+        if e.errno == socket.EAI_NONAME:
+            raise DNSError(host)
+        else:
+            raise
+
+def mixed_newlines(filename):
+    """Returns ``True`` if ``filename`` has mixed newlines."""
+    f = open(filename, "U") # requires universal newline support
+    f.read()
+    ret = isinstance(f.newlines, tuple)
+    f.close() # just to be safe
+    return ret
+
+def random_key(length=30):
+    """Generates a random alphanumeric key of ``length`` size."""
+    return ''.join(random.choice(string.letters + string.digits) for i in xrange(length))
+
+class NoOperatorInfo(wizard.Error):
+    """No information could be found about the operator from Kerberos."""
     pass
 
+class PermissionsError(IOError):
+    errno = errno.EACCES
+
+class NoSuchDirectoryError(IOError):
+    errno = errno.ENOENT
+
+class DirectoryLockedError(wizard.Error):
+    def __init__(self, dir):
+        self.dir = dir
+    def __str__(self):
+        return """
+
+ERROR: Could not acquire lock on directory.  Maybe there is
+another migration process running?
+"""
+
+class DNSError(socket.gaierror):
+    errno = socket.EAI_NONAME
+    #: Hostname that could not be resolved
+    host = None
+    def __init__(self, host):
+        self.host = host
+    def __str__(self):
+        return """
+
+ERROR: Could not resolve hostname %s.
+""" % self.host