]> scripts.mit.edu Git - wizard.git/blob - wizard/util.py
Fix non-zero shell exit code and --continue in wrong directory.
[wizard.git] / wizard / util.py
1 """
2 Miscellaneous utility functions and classes.
3
4 .. testsetup:: *
5
6     from wizard.util import *
7 """
8
9 import os.path
10 import os
11 import subprocess
12 import pwd
13 import sys
14 import socket
15 import errno
16 import itertools
17 import signal
18 import httplib
19 import urllib
20 import time
21 import logging
22
23 import wizard
24
25 class ChangeDirectory(object):
26     """
27     Context for temporarily changing the working directory.
28
29         >>> with ChangeDirectory("/tmp"):
30         ...    print os.getcwd()
31         /tmp
32     """
33     def __init__(self, dir):
34         self.dir = dir
35         self.olddir = None
36     def __enter__(self):
37         self.olddir = os.getcwd()
38         chdir(self.dir)
39     def __exit__(self, *args):
40         chdir(self.olddir)
41
42 class Counter(object):
43     """
44     Object for counting different values when you don't know what
45     they are a priori.  Supports index access and iteration.
46
47         >>> counter = Counter()
48         >>> counter.count("foo")
49         >>> print counter["foo"]
50         1
51     """
52     def __init__(self):
53         self.dict = {}
54     def count(self, value):
55         """Increments count for ``value``."""
56         self.dict.setdefault(value, 0)
57         self.dict[value] += 1
58     def __getitem__(self, key):
59         return self.dict[key]
60     def __iter__(self):
61         return self.dict.__iter__()
62     def max(self):
63         """Returns the max counter value seen."""
64         return max(self.dict.values())
65     def sum(self):
66         """Returns the sum of all counter values."""
67         return sum(self.dict.values())
68     def keys(self):
69         """Returns the keys of counters."""
70         return self.dict.keys()
71
72 class PipeToLess(object):
73     """
74     Context for printing output to a pager.  Use this if output
75     is expected to be long.
76     """
77     def __enter__(self):
78         self.proc = subprocess.Popen("less", stdin=subprocess.PIPE)
79         self.old_stdout = sys.stdout
80         sys.stdout = self.proc.stdin
81     def __exit__(self, *args):
82         if self.proc:
83             self.proc.stdin.close()
84             self.proc.wait()
85             sys.stdout = self.old_stdout
86
87 class IgnoreKeyboardInterrupts(object):
88     """
89     Context for temporarily ignoring keyboard interrupts.  Use this
90     if aborting would cause more harm than finishing the job.
91     """
92     def __enter__(self):
93         signal.signal(signal.SIGINT,signal.SIG_IGN)
94     def __exit__(self, *args):
95         signal.signal(signal.SIGINT, signal.default_int_handler)
96
97 class LockDirectory(object):
98     """
99     Context for locking a directory.
100     """
101     def __init__(self, lockfile, expiry = 3600):
102         self.lockfile = lockfile
103         self.expiry = expiry # by default an hour
104     def __enter__(self):
105         # It's A WAVY
106         for i in range(0, 3):
107             try:
108                 os.open(self.lockfile, os.O_CREAT | os.O_EXCL)
109                 open(self.lockfile, "w").write("%d" % os.getpid())
110             except OSError as e:
111                 if e.errno == errno.EEXIST:
112                     # There is a possibility of infinite recursion, but we
113                     # expect it to be unlikely, and not harmful if it does happen
114                     with LockDirectory(self.lockfile + "_"):
115                         # See if we can break the lock
116                         try:
117                             pid = open(self.lockfile, "r").read().strip()
118                             if not os.path.exists("/proc/%s" % pid):
119                                 # break the lock, try again
120                                 logging.warning("Breaking orphaned lock at %s", self.lockfile)
121                                 os.unlink(self.lockfile)
122                                 continue
123                             try:
124                                 # check if the file is expiry old, if so, break the lock, try again
125                                 if time.time() - os.stat(self.lockfile).st_mtime > self.expiry:
126                                     logging.warning("Breaking stale lock at %s", self.lockfile)
127                                     os.unlink(self.lockfile)
128                                     continue
129                             except OSError as e:
130                                 if e.errno == errno.ENOENT:
131                                     continue
132                                 raise
133                         except IOError:
134                             # oh hey, it went away; try again
135                             continue
136                     raise DirectoryLockedError(os.getcwd())
137                 elif e.errno == errno.EACCES:
138                     raise PermissionsError(os.getcwd())
139                 raise
140             return
141         raise DirectoryLockedError(os.getcwd())
142     def __exit__(self, *args):
143         try:
144             os.unlink(self.lockfile)
145         except OSError:
146             pass
147
148 def chdir(dir):
149     """
150     Changes a directory, but has special exceptions for certain
151     classes of errors.
152     """
153     try:
154         os.chdir(dir)
155     except OSError as e:
156         if e.errno == errno.EACCES:
157             raise PermissionsError()
158         elif e.errno == errno.ENOENT:
159             raise NoSuchDirectoryError()
160         else: raise e
161
162 def dictmap(f, d):
163     """
164     A map function for dictionaries.  Only changes values.
165
166         >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2})
167         {'a': 3, 'b': 4}
168     """
169     return dict((k,f(v)) for k,v in d.items())
170
171 def dictkmap(f, d):
172     """
173     A map function for dictionaries that passes key and value.
174
175         >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4})
176         {1: 5, 3: 7}
177     """
178     return dict((k,f(k,v)) for k,v in d.items())
179
180 def get_exception_name(output):
181     """
182     Reads the traceback from a Python program and grabs the
183     fully qualified exception name.
184     """
185     lines = output.split("\n")
186     cue = False
187     result = "(unknown)"
188     for line in lines[1:]:
189         line = line.rstrip()
190         if not line: continue
191         if line[0] == ' ':
192             cue = True
193             continue
194         if cue:
195             cue = False
196             return line.partition(':')[0]
197     return result
198
199 def get_dir_uid(dir):
200     """Finds the uid of the person who owns this directory."""
201     return os.stat(dir).st_uid
202
203 def get_dir_owner(dir = "."):
204     """
205     Finds the name of the locker this directory is in.
206
207     .. note::
208
209         This function uses the passwd database and thus
210         only works on scripts servers when querying directories
211         that live on AFS.
212     """
213     uid = get_dir_uid(dir)
214     try:
215         pwentry = pwd.getpwuid(uid)
216         return pwentry.pw_name
217     except KeyError:
218         # do an pts query to get the name
219         return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2]
220
221 def get_revision():
222     """Returns the commit ID of the current Wizard install."""
223     # If you decide to convert this to use wizard.shell, be warned
224     # that there is a circular dependency, so this function would
225     # probably have to live somewhere else, probably wizard.git
226     wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git")
227     return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip()
228
229 def get_operator_info():
230     """
231     Returns tuple of ``(realname, email)`` about the person running
232     the script.  If run from a scripts server, get info from Hesiod.
233     Otherwise, use the passwd database (email generated probably won't
234     actually accept mail).  Useful when generating commit messages.
235     """
236     username = get_operator_name_from_gssapi()
237     if username:
238         # scripts approach
239         hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
240         fields = hesinfo.partition(",")[0]
241         realname = fields.rpartition(":")[2]
242         return realname, username + "@mit.edu"
243     else:
244         # more traditional approach, but the email probably doesn't work
245         uid = os.getuid()
246         if not uid:
247             # since root isn't actually a useful designation, but maybe
248             # SUDO_USER contains something helpful
249             sudo_user = os.getenv("SUDO_USER")
250             if not sudo_user:
251                 raise NoOperatorInfo
252             pwdentry = pwd.getpwnam(sudo_user)
253         else:
254             pwdentry = pwd.getpwuid(uid)
255         # XXX: error checking might be nice
256         # We follow the Ubuntu convention of gecos being a comma split field
257         # with the person's realname being the first entry.
258         return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname()
259
260 def get_operator_git():
261     """
262     Returns ``Real Name <username@mit.edu>`` suitable for use in
263     Git ``Something-by:`` string.
264     """
265     return "%s <%s>" % get_operator_info()
266
267 def get_operator_name_from_gssapi():
268     """
269     Returns username of the person operating this script based
270     off of the :envvar:`SSH_GSSAPI_NAME` environment variable.
271
272     .. note::
273
274         :envvar:`SSH_GSSAPI_NAME` is not set by a vanilla OpenSSH
275         distributions.  Scripts servers are patched to support this
276         environment variable.
277     """
278     principal = os.getenv("SSH_GSSAPI_NAME")
279     if not principal:
280         return None
281     instance, _, _ = principal.partition("@")
282     if instance.endswith("/root"):
283         username, _, _ = principal.partition("/")
284     else:
285         username = instance
286     return username
287
288 def set_operator_env():
289     """
290     Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL`
291     environment variables if applicable.  Does nothing if
292     :func:`get_operator_info` throws :exc:`NoOperatorInfo`.
293     """
294     try:
295         op_realname, op_email = get_operator_info()
296         os.putenv("GIT_COMMITTER_NAME", op_realname)
297         os.putenv("GIT_COMMITTER_EMAIL", op_email)
298     except NoOperatorInfo:
299         pass
300
301 def set_author_env():
302     """
303     Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment
304     variables if applicable. Does nothing if :func:`get_dir_owner` fails.
305     """
306     try:
307         # XXX: should check if the directory is in AFS, and if not, use
308         # a more traditional metric
309         lockername = get_dir_owner()
310         os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername)
311         os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername)
312     except KeyError: # XXX: This doesn't actually make sense
313         pass
314
315 def set_git_env():
316     """Sets all appropriate environment variables for Git commits."""
317     set_operator_env()
318     set_author_env()
319
320 def get_git_footer():
321     """Returns strings for placing in Git log info about Wizard."""
322     return "\n".join(["Wizard-revision: %s" % get_revision()
323         ,"Wizard-args: %s" % " ".join(sys.argv)
324         ])
325
326 def safe_unlink(file):
327     """Moves a file/dir to a backup location."""
328     if not os.path.exists(file):
329         return None
330     prefix = "%s.bak" % file
331     name = None
332     for i in itertools.count():
333         name = "%s.%d" % (prefix, i)
334         if not os.path.exists(name):
335             break
336     os.rename(file, name)
337     return name
338
339 def soft_unlink(file):
340     """Unlink a file, but don't complain if it doesn't exist."""
341     try:
342         os.unlink(file)
343     except OSError:
344         pass
345
346 def fetch(host, path, subpath, post=None):
347     h = httplib.HTTPConnection(host)
348     fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept
349     if post:
350         headers = {"Content-type": "application/x-www-form-urlencoded"}
351         h.request("POST", fullpath, urllib.urlencode(post), headers)
352     else:
353         h.request("GET", fullpath)
354     r = h.getresponse()
355     data = r.read()
356     h.close()
357     return data
358
359 def mixed_newlines(filename):
360     """Returns ``True`` if ``filename`` has mixed newlines."""
361     f = open(filename, "U") # requires universal newline support
362     f.read()
363     ret = isinstance(f.newlines, tuple)
364     f.close() # just to be safe
365     return ret
366
367 class NoOperatorInfo(wizard.Error):
368     """No information could be found about the operator from Kerberos."""
369     pass
370
371 class PermissionsError(IOError):
372     errno = errno.EACCES
373
374 class NoSuchDirectoryError(IOError):
375     errno = errno.ENOENT
376
377 class DirectoryLockedError(wizard.Error):
378     def __init__(self, dir):
379         self.dir = dir
380     def __str__(self):
381         return """
382
383 ERROR: Could not acquire lock on directory.  Maybe there is
384 another migration process running?
385 """
386