X-Git-Url: https://scripts.mit.edu/gitweb/wizard.git/blobdiff_plain/a19b4fb59bafd03043b0cee1014dd454ac342e60..f3096b4a2a5498c38c7451397eebdc4ac5ed0788:/wizard/util.py diff --git a/wizard/util.py b/wizard/util.py index 2588d17..5f09ca3 100644 --- a/wizard/util.py +++ b/wizard/util.py @@ -1,82 +1,285 @@ +""" +Miscellaneous utility functions and classes. + +.. testsetup:: * + + from wizard.util import * +""" + import os.path import os import subprocess import pwd import sys +import socket +import errno +import itertools +import signal +import httplib +import urllib +import time +import logging +import random +import string import wizard class ChangeDirectory(object): - """Context for temporarily changing directory""" + """ + Context for temporarily changing the working directory. + + >>> with ChangeDirectory("/tmp"): + ... print os.getcwd() + /tmp + """ def __init__(self, dir): self.dir = dir self.olddir = None def __enter__(self): self.olddir = os.getcwd() - os.chdir(self.dir) + chdir(self.dir) def __exit__(self, *args): - os.chdir(self.olddir) + chdir(self.olddir) class Counter(object): + """ + Object for counting different values when you don't know what + they are a priori. Supports index access and iteration. + + >>> counter = Counter() + >>> counter.count("foo") + >>> print counter["foo"] + 1 + """ def __init__(self): self.dict = {} def count(self, value): + """Increments count for ``value``.""" self.dict.setdefault(value, 0) self.dict[value] += 1 def __getitem__(self, key): return self.dict[key] def __iter__(self): return self.dict.__iter__() + def max(self): + """Returns the max counter value seen.""" + return max(self.dict.values()) + def sum(self): + """Returns the sum of all counter values.""" + return sum(self.dict.values()) + def keys(self): + """Returns the keys of counters.""" + return self.dict.keys() + +class PipeToLess(object): + """ + Context for printing output to a pager. Use this if output + is expected to be long. + """ + def __enter__(self): + self.proc = subprocess.Popen("less", stdin=subprocess.PIPE) + self.old_stdout = sys.stdout + sys.stdout = self.proc.stdin + def __exit__(self, *args): + if self.proc: + self.proc.stdin.close() + self.proc.wait() + sys.stdout = self.old_stdout + +class IgnoreKeyboardInterrupts(object): + """ + Context for temporarily ignoring keyboard interrupts. Use this + if aborting would cause more harm than finishing the job. + """ + def __enter__(self): + signal.signal(signal.SIGINT,signal.SIG_IGN) + def __exit__(self, *args): + signal.signal(signal.SIGINT, signal.default_int_handler) + +class LockDirectory(object): + """ + Context for locking a directory. + """ + def __init__(self, lockfile, expiry = 3600): + self.lockfile = lockfile + self.expiry = expiry # by default an hour + def __enter__(self): + # It's A WAVY + for i in range(0, 3): + try: + os.open(self.lockfile, os.O_CREAT | os.O_EXCL) + open(self.lockfile, "w").write("%d" % os.getpid()) + except OSError as e: + if e.errno == errno.EEXIST: + # There is a possibility of infinite recursion, but we + # expect it to be unlikely, and not harmful if it does happen + with LockDirectory(self.lockfile + "_"): + # See if we can break the lock + try: + pid = open(self.lockfile, "r").read().strip() + if not os.path.exists("/proc/%s" % pid): + # break the lock, try again + logging.warning("Breaking orphaned lock at %s", self.lockfile) + os.unlink(self.lockfile) + continue + try: + # check if the file is expiry old, if so, break the lock, try again + if time.time() - os.stat(self.lockfile).st_mtime > self.expiry: + logging.warning("Breaking stale lock at %s", self.lockfile) + os.unlink(self.lockfile) + continue + except OSError as e: + if e.errno == errno.ENOENT: + continue + raise + except IOError: + # oh hey, it went away; try again + continue + raise DirectoryLockedError(os.getcwd()) + elif e.errno == errno.EACCES: + raise PermissionsError(os.getcwd()) + raise + return + raise DirectoryLockedError(os.getcwd()) + def __exit__(self, *args): + try: + os.unlink(self.lockfile) + except OSError: + pass + +def chdir(dir): + """ + Changes a directory, but has special exceptions for certain + classes of errors. + """ + try: + os.chdir(dir) + except OSError as e: + if e.errno == errno.EACCES: + raise PermissionsError() + elif e.errno == errno.ENOENT: + raise NoSuchDirectoryError() + else: raise e def dictmap(f, d): - """A map function for dictionaries. Does not allow changing keys, only - values""" + """ + A map function for dictionaries. Only changes values. + + >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2}) + {'a': 3, 'b': 4} + """ return dict((k,f(v)) for k,v in d.items()) +def dictkmap(f, d): + """ + A map function for dictionaries that passes key and value. + + >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4}) + {1: 5, 3: 7} + """ + return dict((k,f(k,v)) for k,v in d.items()) + def get_exception_name(output): - """Reads the stderr output of another Python command and grabs the - fully qualified exception name""" + """ + Reads the traceback from a Python program and grabs the + fully qualified exception name. + """ lines = output.split("\n") - for line in lines[1:]: # skip the "traceback" line + cue = False + result = "(unknown)" + for line in lines[1:]: line = line.rstrip() - if line[0] == ' ': continue - if line[-1] == ":": - return line[:-1] - else: - return line + if not line: continue + if line[0] == ' ': + cue = True + continue + if cue: + cue = False + return line.partition(':')[0] + return result def get_dir_uid(dir): """Finds the uid of the person who owns this directory.""" return os.stat(dir).st_uid def get_dir_owner(dir = "."): - """Finds the name of the locker this directory is in.""" - return pwd.getpwuid(get_dir_uid(dir)).pw_name + """ + Finds the name of the locker this directory is in. + + .. note:: + + When querying AFS servers, this function only works if + you're on a Scripts server (which has the correct passwd + database) or if you're on a Debathena machine. + """ + uid = get_dir_uid(dir) + try: + pwentry = pwd.getpwuid(uid) + return pwentry.pw_name + except KeyError: + # do an pts query to get the name + return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2] def get_revision(): """Returns the commit ID of the current Wizard install.""" + # If you decide to convert this to use wizard.shell, be warned + # that there is a circular dependency, so this function would + # probably have to live somewhere else, probably wizard.git wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git") return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip() def get_operator_info(): - """Returns tuple of (realname, email) of person who is operating - this script, as told to use by the Kerberos principal name. - Useful for commit messages.""" - username = get_operator_name() - hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0] - fields = hesinfo.partition(",")[0] - realname = fields.rpartition(":")[2] - return realname, username + "@mit.edu" + """ + Returns tuple of ``(realname, email)`` about the person running + the script. If run from a scripts server, get info from Hesiod. + Otherwise, use the passwd database (email generated probably won't + actually accept mail). Useful when generating commit messages. + """ + username = get_operator_name_from_gssapi() + if username: + # scripts approach + hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0] + fields = hesinfo.partition(",")[0] + realname = fields.rpartition(":")[2] + return realname, username + "@mit.edu" + else: + # more traditional approach, but the email probably doesn't work + uid = os.getuid() + if not uid: + # since root isn't actually a useful designation, but maybe + # SUDO_USER contains something helpful + sudo_user = os.getenv("SUDO_USER") + if not sudo_user: + raise NoOperatorInfo + pwdentry = pwd.getpwnam(sudo_user) + else: + pwdentry = pwd.getpwuid(uid) + # XXX: error checking might be nice + # We follow the Ubuntu convention of gecos being a comma split field + # with the person's realname being the first entry. + return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname() def get_operator_git(): - """Returns Real Name suitable for use in - Git Something-by: string.""" + """ + Returns ``Real Name `` suitable for use in + Git ``Something-by:`` string. + """ return "%s <%s>" % get_operator_info() -def get_operator_name(): - """Returns username of the person operating this script.""" +def get_operator_name_from_gssapi(): + """ + Returns username of the person operating this script based + off of the :envvar:`SSH_GSSAPI_NAME` environment variable. + + .. note:: + + :envvar:`SSH_GSSAPI_NAME` is not set by a vanilla OpenSSH + distributions. Scripts servers are patched to support this + environment variable. + """ principal = os.getenv("SSH_GSSAPI_NAME") - if not principal: raise NoOperatorInfo() + if not principal: + return None instance, _, _ = principal.partition("@") if instance.endswith("/root"): username, _, _ = principal.partition("/") @@ -85,6 +288,11 @@ def get_operator_name(): return username def set_operator_env(): + """ + Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL` + environment variables if applicable. Does nothing if + :func:`get_operator_info` throws :exc:`NoOperatorInfo`. + """ try: op_realname, op_email = get_operator_info() os.putenv("GIT_COMMITTER_NAME", op_realname) @@ -93,25 +301,124 @@ def set_operator_env(): pass def set_author_env(): + """ + Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment + variables if applicable. Does nothing if :func:`get_dir_owner` fails. + """ try: + # XXX: should check if the directory is in AFS, and if not, use + # a more traditional metric lockername = get_dir_owner() os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername) os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername) - except KeyError: + except KeyError: # XXX: This doesn't actually make sense pass def set_git_env(): + """Sets all appropriate environment variables for Git commits.""" set_operator_env() set_author_env() def get_git_footer(): + """Returns strings for placing in Git log info about Wizard.""" return "\n".join(["Wizard-revision: %s" % get_revision() ,"Wizard-args: %s" % " ".join(sys.argv) ]) -class Error(wizard.Error): - pass +def safe_unlink(file): + """Moves a file/dir to a backup location.""" + if not os.path.lexists(file): + return None + prefix = "%s.bak" % file + name = None + for i in itertools.count(): + name = "%s.%d" % (prefix, i) + if not os.path.lexists(name): + break + os.rename(file, name) + return name + +def soft_unlink(file): + """Unlink a file, but don't complain if it doesn't exist.""" + try: + os.unlink(file) + except OSError: + pass + +def makedirs(path): + """ + Create a directory path (a la ``mkdir -p`` or ``os.makedirs``), + but don't complain if it already exists. + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST: + pass + else: + raise + +def fetch(host, path, subpath, post=None): + try: + # XXX: Special case if it's https; not sure why this data isn't + # passed + h = httplib.HTTPConnection(host) + fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept + if post: + headers = {"Content-type": "application/x-www-form-urlencoded"} + h.request("POST", fullpath, urllib.urlencode(post), headers) + else: + h.request("GET", fullpath) + r = h.getresponse() + data = r.read() + h.close() + return data + except socket.gaierror as e: + if e.errno == socket.EAI_NONAME: + raise DNSError(host) + else: + raise -class NoOperatorInfo(Error): +def mixed_newlines(filename): + """Returns ``True`` if ``filename`` has mixed newlines.""" + f = open(filename, "U") # requires universal newline support + f.read() + ret = isinstance(f.newlines, tuple) + f.close() # just to be safe + return ret + +def random_key(length=30): + """Generates a random alphanumeric key of ``length`` size.""" + return ''.join(random.choice(string.letters + string.digits) for i in xrange(length)) + +class NoOperatorInfo(wizard.Error): + """No information could be found about the operator from Kerberos.""" pass +class PermissionsError(IOError): + errno = errno.EACCES + +class NoSuchDirectoryError(IOError): + errno = errno.ENOENT + +class DirectoryLockedError(wizard.Error): + def __init__(self, dir): + self.dir = dir + def __str__(self): + return """ + +ERROR: Could not acquire lock on directory. Maybe there is +another migration process running? +""" + +class DNSError(socket.gaierror): + errno = socket.EAI_NONAME + #: Hostname that could not resolve name + host = None + def __init__(self, host): + self.host = host + def __str__(self): + return """ + +ERROR: Could not resolve hostname %s. +""" % self.host