]> scripts.mit.edu Git - wizard.git/blob - wizard/util.py
419e9d92485cd21004b1426dbf180cd52d868462
[wizard.git] / wizard / util.py
1 """
2 Miscellaneous utility functions and classes.
3
4 .. testsetup:: *
5
6     from wizard.util import *
7 """
8
9 import os.path
10 import os
11 import subprocess
12 import pwd
13 import sys
14 import socket
15 import errno
16 import itertools
17 import signal
18 import httplib
19 import urllib
20 import time
21 import logging
22 import random
23 import string
24
25 import wizard
26
27 class ChangeDirectory(object):
28     """
29     Context for temporarily changing the working directory.
30
31         >>> with ChangeDirectory("/tmp"):
32         ...    print os.getcwd()
33         /tmp
34     """
35     def __init__(self, dir):
36         self.dir = dir
37         self.olddir = None
38     def __enter__(self):
39         self.olddir = os.getcwd()
40         chdir(self.dir)
41     def __exit__(self, *args):
42         chdir(self.olddir)
43
44 class Counter(object):
45     """
46     Object for counting different values when you don't know what
47     they are a priori.  Supports index access and iteration.
48
49         >>> counter = Counter()
50         >>> counter.count("foo")
51         >>> print counter["foo"]
52         1
53     """
54     def __init__(self):
55         self.dict = {}
56     def count(self, value):
57         """Increments count for ``value``."""
58         self.dict.setdefault(value, 0)
59         self.dict[value] += 1
60     def __getitem__(self, key):
61         return self.dict[key]
62     def __iter__(self):
63         return self.dict.__iter__()
64     def max(self):
65         """Returns the max counter value seen."""
66         return max(self.dict.values())
67     def sum(self):
68         """Returns the sum of all counter values."""
69         return sum(self.dict.values())
70     def keys(self):
71         """Returns the keys of counters."""
72         return self.dict.keys()
73
74 class PipeToLess(object):
75     """
76     Context for printing output to a pager.  Use this if output
77     is expected to be long.
78     """
79     def __enter__(self):
80         self.proc = subprocess.Popen("less", stdin=subprocess.PIPE)
81         self.old_stdout = sys.stdout
82         sys.stdout = self.proc.stdin
83     def __exit__(self, *args):
84         if self.proc:
85             self.proc.stdin.close()
86             self.proc.wait()
87             sys.stdout = self.old_stdout
88
89 class IgnoreKeyboardInterrupts(object):
90     """
91     Context for temporarily ignoring keyboard interrupts.  Use this
92     if aborting would cause more harm than finishing the job.
93     """
94     def __enter__(self):
95         signal.signal(signal.SIGINT,signal.SIG_IGN)
96     def __exit__(self, *args):
97         signal.signal(signal.SIGINT, signal.default_int_handler)
98
99 class LockDirectory(object):
100     """
101     Context for locking a directory.
102     """
103     def __init__(self, lockfile, expiry = 3600):
104         self.lockfile = lockfile
105         self.expiry = expiry # by default an hour
106     def __enter__(self):
107         # It's A WAVY
108         for i in range(0, 3):
109             try:
110                 os.open(self.lockfile, os.O_CREAT | os.O_EXCL)
111                 open(self.lockfile, "w").write("%d" % os.getpid())
112             except OSError as e:
113                 if e.errno == errno.EEXIST:
114                     # There is a possibility of infinite recursion, but we
115                     # expect it to be unlikely, and not harmful if it does happen
116                     with LockDirectory(self.lockfile + "_"):
117                         # See if we can break the lock
118                         try:
119                             pid = open(self.lockfile, "r").read().strip()
120                             if not os.path.exists("/proc/%s" % pid):
121                                 # break the lock, try again
122                                 logging.warning("Breaking orphaned lock at %s", self.lockfile)
123                                 os.unlink(self.lockfile)
124                                 continue
125                             try:
126                                 # check if the file is expiry old, if so, break the lock, try again
127                                 if time.time() - os.stat(self.lockfile).st_mtime > self.expiry:
128                                     logging.warning("Breaking stale lock at %s", self.lockfile)
129                                     os.unlink(self.lockfile)
130                                     continue
131                             except OSError as e:
132                                 if e.errno == errno.ENOENT:
133                                     continue
134                                 raise
135                         except IOError:
136                             # oh hey, it went away; try again
137                             continue
138                     raise DirectoryLockedError(os.getcwd())
139                 elif e.errno == errno.EACCES:
140                     raise PermissionsError(os.getcwd())
141                 raise
142             return
143         raise DirectoryLockedError(os.getcwd())
144     def __exit__(self, *args):
145         try:
146             os.unlink(self.lockfile)
147         except OSError:
148             pass
149
150 def chdir(dir):
151     """
152     Changes a directory, but has special exceptions for certain
153     classes of errors.
154     """
155     try:
156         os.chdir(dir)
157     except OSError as e:
158         if e.errno == errno.EACCES:
159             raise PermissionsError()
160         elif e.errno == errno.ENOENT:
161             raise NoSuchDirectoryError()
162         else: raise e
163
164 def dictmap(f, d):
165     """
166     A map function for dictionaries.  Only changes values.
167
168         >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2})
169         {'a': 3, 'b': 4}
170     """
171     return dict((k,f(v)) for k,v in d.items())
172
173 def dictkmap(f, d):
174     """
175     A map function for dictionaries that passes key and value.
176
177         >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4})
178         {1: 5, 3: 7}
179     """
180     return dict((k,f(k,v)) for k,v in d.items())
181
182 def get_exception_name(output):
183     """
184     Reads the traceback from a Python program and grabs the
185     fully qualified exception name.
186     """
187     lines = output.split("\n")
188     cue = False
189     result = "(unknown)"
190     for line in lines[1:]:
191         line = line.rstrip()
192         if not line: continue
193         if line[0] == ' ':
194             cue = True
195             continue
196         if cue:
197             cue = False
198             return line.partition(':')[0]
199     return result
200
201 def get_dir_uid(dir):
202     """Finds the uid of the person who owns this directory."""
203     return os.stat(dir).st_uid
204
205 def get_dir_owner(dir = "."):
206     """
207     Finds the name of the locker this directory is in.
208
209     .. note::
210
211         When querying AFS servers, this function only works if
212         you're on a Scripts server (which has the correct passwd
213         database) or if you're on a Debathena machine.
214     """
215     uid = get_dir_uid(dir)
216     try:
217         pwentry = pwd.getpwuid(uid)
218         return pwentry.pw_name
219     except KeyError:
220         # do an pts query to get the name
221         return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2]
222
223 def get_revision():
224     """Returns the commit ID of the current Wizard install."""
225     # If you decide to convert this to use wizard.shell, be warned
226     # that there is a circular dependency, so this function would
227     # probably have to live somewhere else, probably wizard.git
228     wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git")
229     return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip()
230
231 def get_operator_info():
232     """
233     Returns tuple of ``(realname, email)`` about the person running
234     the script.  If run from a scripts server, get info from Hesiod.
235     Otherwise, use the passwd database (email generated probably won't
236     actually accept mail).  Useful when generating commit messages.
237     """
238     username = get_operator_name_from_gssapi()
239     if username:
240         # scripts approach
241         hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
242         fields = hesinfo.partition(",")[0]
243         realname = fields.rpartition(":")[2]
244         return realname, username + "@mit.edu"
245     else:
246         # more traditional approach, but the email probably doesn't work
247         uid = os.getuid()
248         if not uid:
249             # since root isn't actually a useful designation, but maybe
250             # SUDO_USER contains something helpful
251             sudo_user = os.getenv("SUDO_USER")
252             if not sudo_user:
253                 raise NoOperatorInfo
254             pwdentry = pwd.getpwnam(sudo_user)
255         else:
256             pwdentry = pwd.getpwuid(uid)
257         # XXX: error checking might be nice
258         # We follow the Ubuntu convention of gecos being a comma split field
259         # with the person's realname being the first entry.
260         return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname()
261
262 def get_operator_git():
263     """
264     Returns ``Real Name <username@mit.edu>`` suitable for use in
265     Git ``Something-by:`` string.
266     """
267     return "%s <%s>" % get_operator_info()
268
269 def get_operator_name_from_gssapi():
270     """
271     Returns username of the person operating this script based
272     off of the :envvar:`SSH_GSSAPI_NAME` environment variable.
273
274     .. note::
275
276         :envvar:`SSH_GSSAPI_NAME` is not set by a vanilla OpenSSH
277         distributions.  Scripts servers are patched to support this
278         environment variable.
279     """
280     principal = os.getenv("SSH_GSSAPI_NAME")
281     if not principal:
282         return None
283     instance, _, _ = principal.partition("@")
284     if instance.endswith("/root"):
285         username, _, _ = principal.partition("/")
286     else:
287         username = instance
288     return username
289
290 def set_operator_env():
291     """
292     Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL`
293     environment variables if applicable.  Does nothing if
294     :func:`get_operator_info` throws :exc:`NoOperatorInfo`.
295     """
296     try:
297         op_realname, op_email = get_operator_info()
298         os.putenv("GIT_COMMITTER_NAME", op_realname)
299         os.putenv("GIT_COMMITTER_EMAIL", op_email)
300     except NoOperatorInfo:
301         pass
302
303 def set_author_env():
304     """
305     Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment
306     variables if applicable. Does nothing if :func:`get_dir_owner` fails.
307     """
308     try:
309         # XXX: should check if the directory is in AFS, and if not, use
310         # a more traditional metric
311         lockername = get_dir_owner()
312         os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername)
313         os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername)
314     except KeyError: # XXX: This doesn't actually make sense
315         pass
316
317 def set_git_env():
318     """Sets all appropriate environment variables for Git commits."""
319     set_operator_env()
320     set_author_env()
321
322 def get_git_footer():
323     """Returns strings for placing in Git log info about Wizard."""
324     return "\n".join(["Wizard-revision: %s" % get_revision()
325         ,"Wizard-args: %s" % " ".join(sys.argv)
326         ])
327
328 def safe_unlink(file):
329     """Moves a file/dir to a backup location."""
330     if not os.path.lexists(file):
331         return None
332     prefix = "%s.bak" % file
333     name = None
334     for i in itertools.count():
335         name = "%s.%d" % (prefix, i)
336         if not os.path.lexists(name):
337             break
338     os.rename(file, name)
339     return name
340
341 def soft_unlink(file):
342     """Unlink a file, but don't complain if it doesn't exist."""
343     try:
344         os.unlink(file)
345     except OSError:
346         pass
347
348 def makedirs(path):
349     """
350     Create a directory path (a la ``mkdir -p`` or ``os.makedirs``),
351     but don't complain if it already exists.
352     """
353     try:
354         os.makedirs(path)
355     except OSError as exc:
356         if exc.errno == errno.EEXIST:
357             pass
358         else:
359             raise
360
361 def fetch(host, path, subpath, post=None):
362     try:
363         # XXX: Special case if it's https; not sure why this data isn't
364         # passed
365         h = httplib.HTTPConnection(host)
366         fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept
367         if post:
368             headers = {"Content-type": "application/x-www-form-urlencoded"}
369             logging.info("POST request to http://%s%s", host, fullpath)
370             logging.debug("POST contents:\n" + urllib.urlencode(post))
371             h.request("POST", fullpath, urllib.urlencode(post), headers)
372         else:
373             logging.info("GET request to http://%s%s", host, fullpath)
374             h.request("GET", fullpath)
375         r = h.getresponse()
376         data = r.read()
377         h.close()
378         return data
379     except socket.gaierror as e:
380         if e.errno == socket.EAI_NONAME:
381             raise DNSError(host)
382         else:
383             raise
384
385 def mixed_newlines(filename):
386     """Returns ``True`` if ``filename`` has mixed newlines."""
387     f = open(filename, "U") # requires universal newline support
388     f.read()
389     ret = isinstance(f.newlines, tuple)
390     f.close() # just to be safe
391     return ret
392
393 def disk_usage(dir=None, excluded_dir=".git"):
394     """
395     Recursively determines the disk usage of a directory, excluding
396     .git directories.  Value is in bytes.  If ``dir`` is omitted, the
397     current working directory is assumed.
398     """
399     if dir is None: dir = os.getcwd()
400     sum_sizes = 0
401     for root, _, files in os.walk(dir):
402         for name in files:
403             if not os.path.join(root, name).startswith(os.path.join(dir, excluded_dir)):
404                 file = os.path.join(root, name)
405                 try:
406                     if os.path.islink(file): continue
407                     sum_sizes += os.path.getsize(file)
408                 except OSError as e:
409                     if e.errno == errno.ENOENT:
410                         logging.warning("%s disappeared before we could stat", file)
411                     else:
412                         raise
413     return sum_sizes
414
415 def random_key(length=30):
416     """Generates a random alphanumeric key of ``length`` size."""
417     return ''.join(random.choice(string.letters + string.digits) for i in xrange(length))
418
419 class NoOperatorInfo(wizard.Error):
420     """No information could be found about the operator from Kerberos."""
421     pass
422
423 class PermissionsError(IOError):
424     errno = errno.EACCES
425
426 class NoSuchDirectoryError(IOError):
427     errno = errno.ENOENT
428
429 class DirectoryLockedError(wizard.Error):
430     def __init__(self, dir):
431         self.dir = dir
432     def __str__(self):
433         return """
434
435 ERROR: Could not acquire lock on directory.  Maybe there is
436 another migration process running?
437 """
438
439 class DNSError(socket.gaierror):
440     errno = socket.EAI_NONAME
441     #: Hostname that could not resolve name
442     host = None
443     def __init__(self, host):
444         self.host = host
445     def __str__(self):
446         return """
447
448 ERROR: Could not resolve hostname %s.
449 """ % self.host