]> scripts.mit.edu Git - wizard.git/blob - wizard/util.py
Rewrite parametrize to use new parametrizeWithVars
[wizard.git] / wizard / util.py
1 """
2 Miscellaneous utility functions and classes.
3
4 .. testsetup:: *
5
6     from wizard.util import *
7 """
8
9 import os.path
10 import os
11 import subprocess
12 import pwd
13 import sys
14 import socket
15 import errno
16 import itertools
17 import signal
18 import httplib
19 import urllib
20 import time
21 import logging
22 import random
23 import string
24
25 import wizard
26
27 class ChangeDirectory(object):
28     """
29     Context for temporarily changing the working directory.
30
31         >>> with ChangeDirectory("/tmp"):
32         ...    print os.getcwd()
33         /tmp
34     """
35     def __init__(self, dir):
36         self.dir = dir
37         self.olddir = None
38     def __enter__(self):
39         self.olddir = os.getcwd()
40         chdir(self.dir)
41     def __exit__(self, *args):
42         chdir(self.olddir)
43
44 class Counter(object):
45     """
46     Object for counting different values when you don't know what
47     they are a priori.  Supports index access and iteration.
48
49         >>> counter = Counter()
50         >>> counter.count("foo")
51         >>> print counter["foo"]
52         1
53     """
54     def __init__(self):
55         self.dict = {}
56     def count(self, value):
57         """Increments count for ``value``."""
58         self.dict.setdefault(value, 0)
59         self.dict[value] += 1
60     def __getitem__(self, key):
61         return self.dict[key]
62     def __iter__(self):
63         return self.dict.__iter__()
64     def max(self):
65         """Returns the max counter value seen."""
66         return max(self.dict.values())
67     def sum(self):
68         """Returns the sum of all counter values."""
69         return sum(self.dict.values())
70     def keys(self):
71         """Returns the keys of counters."""
72         return self.dict.keys()
73
74 class PipeToLess(object):
75     """
76     Context for printing output to a pager.  Use this if output
77     is expected to be long.
78     """
79     def __enter__(self):
80         self.proc = subprocess.Popen("less", stdin=subprocess.PIPE)
81         self.old_stdout = sys.stdout
82         sys.stdout = self.proc.stdin
83     def __exit__(self, *args):
84         if self.proc:
85             self.proc.stdin.close()
86             self.proc.wait()
87             sys.stdout = self.old_stdout
88
89 class IgnoreKeyboardInterrupts(object):
90     """
91     Context for temporarily ignoring keyboard interrupts.  Use this
92     if aborting would cause more harm than finishing the job.
93     """
94     def __enter__(self):
95         signal.signal(signal.SIGINT,signal.SIG_IGN)
96     def __exit__(self, *args):
97         signal.signal(signal.SIGINT, signal.default_int_handler)
98
99 class LockDirectory(object):
100     """
101     Context for locking a directory.
102     """
103     def __init__(self, lockfile, expiry = 3600):
104         self.lockfile = lockfile
105         self.expiry = expiry # by default an hour
106     def __enter__(self):
107         # It's A WAVY
108         for i in range(0, 3):
109             try:
110                 os.open(self.lockfile, os.O_CREAT | os.O_EXCL)
111                 open(self.lockfile, "w").write("%d" % os.getpid())
112             except OSError as e:
113                 if e.errno == errno.EEXIST:
114                     # There is a possibility of infinite recursion, but we
115                     # expect it to be unlikely, and not harmful if it does happen
116                     with LockDirectory(self.lockfile + "_"):
117                         # See if we can break the lock
118                         try:
119                             pid = open(self.lockfile, "r").read().strip()
120                             if not os.path.exists("/proc/%s" % pid):
121                                 # break the lock, try again
122                                 logging.warning("Breaking orphaned lock at %s", self.lockfile)
123                                 os.unlink(self.lockfile)
124                                 continue
125                             try:
126                                 # check if the file is expiry old, if so, break the lock, try again
127                                 if time.time() - os.stat(self.lockfile).st_mtime > self.expiry:
128                                     logging.warning("Breaking stale lock at %s", self.lockfile)
129                                     os.unlink(self.lockfile)
130                                     continue
131                             except OSError as e:
132                                 if e.errno == errno.ENOENT:
133                                     continue
134                                 raise
135                         except IOError:
136                             # oh hey, it went away; try again
137                             continue
138                     raise DirectoryLockedError(os.getcwd())
139                 elif e.errno == errno.EACCES:
140                     raise PermissionsError(os.getcwd())
141                 raise
142             return
143         raise DirectoryLockedError(os.getcwd())
144     def __exit__(self, *args):
145         try:
146             os.unlink(self.lockfile)
147         except OSError:
148             pass
149
150 def chdir(dir):
151     """
152     Changes a directory, but has special exceptions for certain
153     classes of errors.
154     """
155     try:
156         os.chdir(dir)
157     except OSError as e:
158         if e.errno == errno.EACCES:
159             raise PermissionsError()
160         elif e.errno == errno.ENOENT:
161             raise NoSuchDirectoryError()
162         else: raise e
163
164 def dictmap(f, d):
165     """
166     A map function for dictionaries.  Only changes values.
167
168         >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2})
169         {'a': 3, 'b': 4}
170     """
171     return dict((k,f(v)) for k,v in d.items())
172
173 def dictkmap(f, d):
174     """
175     A map function for dictionaries that passes key and value.
176
177         >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4})
178         {1: 5, 3: 7}
179     """
180     return dict((k,f(k,v)) for k,v in d.items())
181
182 def get_exception_name(output):
183     """
184     Reads the traceback from a Python program and grabs the
185     fully qualified exception name.
186     """
187     lines = output.split("\n")
188     cue = False
189     result = "(unknown)"
190     for line in lines[1:]:
191         line = line.rstrip()
192         if not line: continue
193         if line[0] == ' ':
194             cue = True
195             continue
196         if cue:
197             cue = False
198             return line.partition(':')[0]
199     return result
200
201 def get_dir_uid(dir):
202     """Finds the uid of the person who owns this directory."""
203     return os.stat(dir).st_uid
204
205 def get_dir_owner(dir = "."):
206     """
207     Finds the name of the locker this directory is in.
208
209     .. note::
210
211         This function uses the passwd database and thus
212         only works on scripts servers when querying directories
213         that live on AFS.
214     """
215     uid = get_dir_uid(dir)
216     try:
217         pwentry = pwd.getpwuid(uid)
218         return pwentry.pw_name
219     except KeyError:
220         # do an pts query to get the name
221         return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2]
222
223 def get_revision():
224     """Returns the commit ID of the current Wizard install."""
225     # If you decide to convert this to use wizard.shell, be warned
226     # that there is a circular dependency, so this function would
227     # probably have to live somewhere else, probably wizard.git
228     wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git")
229     return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip()
230
231 def get_operator_info():
232     """
233     Returns tuple of ``(realname, email)`` about the person running
234     the script.  If run from a scripts server, get info from Hesiod.
235     Otherwise, use the passwd database (email generated probably won't
236     actually accept mail).  Useful when generating commit messages.
237     """
238     username = get_operator_name_from_gssapi()
239     if username:
240         # scripts approach
241         hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
242         fields = hesinfo.partition(",")[0]
243         realname = fields.rpartition(":")[2]
244         return realname, username + "@mit.edu"
245     else:
246         # more traditional approach, but the email probably doesn't work
247         uid = os.getuid()
248         if not uid:
249             # since root isn't actually a useful designation, but maybe
250             # SUDO_USER contains something helpful
251             sudo_user = os.getenv("SUDO_USER")
252             if not sudo_user:
253                 raise NoOperatorInfo
254             pwdentry = pwd.getpwnam(sudo_user)
255         else:
256             pwdentry = pwd.getpwuid(uid)
257         # XXX: error checking might be nice
258         # We follow the Ubuntu convention of gecos being a comma split field
259         # with the person's realname being the first entry.
260         return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname()
261
262 def get_operator_git():
263     """
264     Returns ``Real Name <username@mit.edu>`` suitable for use in
265     Git ``Something-by:`` string.
266     """
267     return "%s <%s>" % get_operator_info()
268
269 def get_operator_name_from_gssapi():
270     """
271     Returns username of the person operating this script based
272     off of the :envvar:`SSH_GSSAPI_NAME` environment variable.
273
274     .. note::
275
276         :envvar:`SSH_GSSAPI_NAME` is not set by a vanilla OpenSSH
277         distributions.  Scripts servers are patched to support this
278         environment variable.
279     """
280     principal = os.getenv("SSH_GSSAPI_NAME")
281     if not principal:
282         return None
283     instance, _, _ = principal.partition("@")
284     if instance.endswith("/root"):
285         username, _, _ = principal.partition("/")
286     else:
287         username = instance
288     return username
289
290 def set_operator_env():
291     """
292     Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL`
293     environment variables if applicable.  Does nothing if
294     :func:`get_operator_info` throws :exc:`NoOperatorInfo`.
295     """
296     try:
297         op_realname, op_email = get_operator_info()
298         os.putenv("GIT_COMMITTER_NAME", op_realname)
299         os.putenv("GIT_COMMITTER_EMAIL", op_email)
300     except NoOperatorInfo:
301         pass
302
303 def set_author_env():
304     """
305     Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment
306     variables if applicable. Does nothing if :func:`get_dir_owner` fails.
307     """
308     try:
309         # XXX: should check if the directory is in AFS, and if not, use
310         # a more traditional metric
311         lockername = get_dir_owner()
312         os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername)
313         os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername)
314     except KeyError: # XXX: This doesn't actually make sense
315         pass
316
317 def set_git_env():
318     """Sets all appropriate environment variables for Git commits."""
319     set_operator_env()
320     set_author_env()
321
322 def get_git_footer():
323     """Returns strings for placing in Git log info about Wizard."""
324     return "\n".join(["Wizard-revision: %s" % get_revision()
325         ,"Wizard-args: %s" % " ".join(sys.argv)
326         ])
327
328 def safe_unlink(file):
329     """Moves a file/dir to a backup location."""
330     if not os.path.exists(file):
331         return None
332     prefix = "%s.bak" % file
333     name = None
334     for i in itertools.count():
335         name = "%s.%d" % (prefix, i)
336         if not os.path.exists(name):
337             break
338     os.rename(file, name)
339     return name
340
341 def soft_unlink(file):
342     """Unlink a file, but don't complain if it doesn't exist."""
343     try:
344         os.unlink(file)
345     except OSError:
346         pass
347
348 def fetch(host, path, subpath, post=None):
349     try:
350         # XXX: Special case if it's https; not sure why this data isn't
351         # passed
352         h = httplib.HTTPConnection(host)
353         fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept
354         if post:
355             headers = {"Content-type": "application/x-www-form-urlencoded"}
356             h.request("POST", fullpath, urllib.urlencode(post), headers)
357         else:
358             h.request("GET", fullpath)
359         r = h.getresponse()
360         data = r.read()
361         h.close()
362         return data
363     except socket.gaierror as e:
364         if e.errno == socket.EAI_NONAME:
365             raise DNSError(host)
366         else:
367             raise
368
369 def mixed_newlines(filename):
370     """Returns ``True`` if ``filename`` has mixed newlines."""
371     f = open(filename, "U") # requires universal newline support
372     f.read()
373     ret = isinstance(f.newlines, tuple)
374     f.close() # just to be safe
375     return ret
376
377 def random_key(length=30):
378     """Generates a random alphanumeric key of ``length`` size."""
379     return ''.join(random.choice(string.letters + string.digits) for i in xrange(length))
380
381 class NoOperatorInfo(wizard.Error):
382     """No information could be found about the operator from Kerberos."""
383     pass
384
385 class PermissionsError(IOError):
386     errno = errno.EACCES
387
388 class NoSuchDirectoryError(IOError):
389     errno = errno.ENOENT
390
391 class DirectoryLockedError(wizard.Error):
392     def __init__(self, dir):
393         self.dir = dir
394     def __str__(self):
395         return """
396
397 ERROR: Could not acquire lock on directory.  Maybe there is
398 another migration process running?
399 """
400
401 class DNSError(socket.gaierror):
402     errno = socket.EAI_NONAME
403     #: Hostname that could not resolve name
404     host = None
405     def __init__(self, host):
406         self.host = host
407     def __str__(self):
408         return """
409
410 ERROR: Could not resolve hostname %s.
411 """ % self.host