]> scripts.mit.edu Git - wizard.git/blob - wizard/util.py
Refactor more boilerplate out.
[wizard.git] / wizard / util.py
1 """
2 Miscellaneous utility functions and classes.
3
4 .. testsetup:: *
5
6     from wizard.util import *
7 """
8
9 import os.path
10 import os
11 import subprocess
12 import pwd
13 import sys
14 import socket
15 import errno
16 import itertools
17 import signal
18 import httplib
19 import urllib
20
21 import wizard
22
23 class ChangeDirectory(object):
24     """
25     Context for temporarily changing the working directory.
26
27         >>> with ChangeDirectory("/tmp"):
28         ...    print os.getcwd()
29         /tmp
30     """
31     def __init__(self, dir):
32         self.dir = dir
33         self.olddir = None
34     def __enter__(self):
35         self.olddir = os.getcwd()
36         chdir(self.dir)
37     def __exit__(self, *args):
38         chdir(self.olddir)
39
40 class Counter(object):
41     """
42     Object for counting different values when you don't know what
43     they are a priori.  Supports index access and iteration.
44
45         >>> counter = Counter()
46         >>> counter.count("foo")
47         >>> print counter["foo"]
48         1
49     """
50     def __init__(self):
51         self.dict = {}
52     def count(self, value):
53         """Increments count for ``value``."""
54         self.dict.setdefault(value, 0)
55         self.dict[value] += 1
56     def __getitem__(self, key):
57         return self.dict[key]
58     def __iter__(self):
59         return self.dict.__iter__()
60     def max(self):
61         """Returns the max counter value seen."""
62         return max(self.dict.values())
63     def sum(self):
64         """Returns the sum of all counter values."""
65         return sum(self.dict.values())
66     def keys(self):
67         """Returns the keys of counters."""
68         return self.dict.keys()
69
70 class PipeToLess(object):
71     """
72     Context for printing output to a pager.  Use this if output
73     is expected to be long.
74     """
75     def __enter__(self):
76         self.proc = subprocess.Popen("less", stdin=subprocess.PIPE)
77         self.old_stdout = sys.stdout
78         sys.stdout = self.proc.stdin
79     def __exit__(self, *args):
80         if self.proc:
81             self.proc.stdin.close()
82             self.proc.wait()
83             sys.stdout = self.old_stdout
84
85 class IgnoreKeyboardInterrupts(object):
86     """
87     Context for temporarily ignoring keyboard interrupts.  Use this
88     if aborting would cause more harm than finishing the job.
89     """
90     def __enter__(self):
91         signal.signal(signal.SIGINT,signal.SIG_IGN)
92     def __exit__(self, *args):
93         signal.signal(signal.SIGINT, signal.default_int_handler)
94
95 class LockDirectory(object):
96     """
97     Context for locking a directory.
98     """
99     def __init__(self, lockfile):
100         self.lockfile = lockfile
101     def __enter__(self):
102         try:
103             os.open(self.lockfile, os.O_CREAT | os.O_EXCL)
104         except OSError as e:
105             if e.errno == errno.EEXIST:
106                 raise DirectoryLockedError(os.getcwd())
107             elif e.errno == errno.EACCES:
108                 raise PermissionsError(os.getcwd())
109             raise
110     def __exit__(self, *args):
111         try:
112             os.unlink(self.lockfile)
113         except OSError:
114             pass
115
116 def chdir(dir):
117     """
118     Changes a directory, but has special exceptions for certain
119     classes of errors.
120     """
121     try:
122         os.chdir(dir)
123     except OSError as e:
124         if e.errno == errno.EACCES:
125             raise PermissionsError()
126         elif e.errno == errno.ENOENT:
127             raise NoSuchDirectoryError()
128         else: raise e
129
130 def dictmap(f, d):
131     """
132     A map function for dictionaries.  Only changes values.
133
134         >>> dictmap(lambda x: x + 2, {'a': 1, 'b': 2})
135         {'a': 3, 'b': 4}
136     """
137     return dict((k,f(v)) for k,v in d.items())
138
139 def dictkmap(f, d):
140     """
141     A map function for dictionaries that passes key and value.
142
143         >>> dictkmap(lambda x, y: x + y, {1: 4, 3: 4})
144         {1: 5, 3: 7}
145     """
146     return dict((k,f(k,v)) for k,v in d.items())
147
148 def get_exception_name(output):
149     """
150     Reads the traceback from a Python program and grabs the
151     fully qualified exception name.
152     """
153     lines = output.split("\n")
154     cue = False
155     result = "(unknown)"
156     for line in lines[1:]:
157         line = line.rstrip()
158         if not line: continue
159         if line[0] == ' ':
160             cue = True
161             continue
162         if cue:
163             cue = False
164             return line.partition(':')[0]
165     return result
166
167 def get_dir_uid(dir):
168     """Finds the uid of the person who owns this directory."""
169     return os.stat(dir).st_uid
170
171 def get_dir_owner(dir = "."):
172     """
173     Finds the name of the locker this directory is in.
174
175     .. note::
176
177         This function uses the passwd database and thus
178         only works on scripts servers when querying directories
179         that live on AFS.
180     """
181     uid = get_dir_uid(dir)
182     try:
183         pwentry = pwd.getpwuid(uid)
184         return pwentry.pw_name
185     except KeyError:
186         # do an pts query to get the name
187         return subprocess.Popen(['pts', 'examine', str(uid)], stdout=subprocess.PIPE).communicate()[0].partition(",")[0].partition(": ")[2]
188
189 def get_revision():
190     """Returns the commit ID of the current Wizard install."""
191     # If you decide to convert this to use wizard.shell, be warned
192     # that there is a circular dependency, so this function would
193     # probably have to live somewhere else, probably wizard.git
194     wizard_git = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".git")
195     return subprocess.Popen(["git", "--git-dir=" + wizard_git, "rev-parse", "HEAD"], stdout=subprocess.PIPE).communicate()[0].rstrip()
196
197 def get_operator_info():
198     """
199     Returns tuple of ``(realname, email)`` about the person running
200     the script.  If run from a scripts server, get info from Hesiod.
201     Otherwise, use the passwd database (email generated probably won't
202     actually accept mail).  Useful when generating commit messages.
203     """
204     username = get_operator_name_from_gssapi()
205     if username:
206         # scripts approach
207         hesinfo = subprocess.Popen(["hesinfo", username, "passwd"],stdout=subprocess.PIPE).communicate()[0]
208         fields = hesinfo.partition(",")[0]
209         realname = fields.rpartition(":")[2]
210         return realname, username + "@mit.edu"
211     else:
212         # more traditional approach, but the email probably doesn't work
213         uid = os.getuid()
214         if not uid:
215             # since root isn't actually a useful designation, but maybe
216             # SUDO_USER contains something helpful
217             sudo_user = os.getenv("SUDO_USER")
218             if not sudo_user:
219                 raise NoOperatorInfo
220             pwdentry = pwd.getpwnam(sudo_user)
221         else:
222             pwdentry = pwd.getpwuid(uid)
223         # XXX: error checking might be nice
224         # We follow the Ubuntu convention of gecos being a comma split field
225         # with the person's realname being the first entry.
226         return pwdentry.pw_gecos.split(",")[0], pwdentry.pw_name + "@" + socket.gethostname()
227
228 def get_operator_git():
229     """
230     Returns ``Real Name <username@mit.edu>`` suitable for use in
231     Git ``Something-by:`` string.
232     """
233     return "%s <%s>" % get_operator_info()
234
235 def get_operator_name_from_gssapi():
236     """
237     Returns username of the person operating this script based
238     off of the :envvar:`SSH_GSSAPI_NAME` environment variable.
239
240     .. note::
241
242         :envvar:`SSH_GSSAPI_NAME` is not set by a vanilla OpenSSH
243         distributions.  Scripts servers are patched to support this
244         environment variable.
245     """
246     principal = os.getenv("SSH_GSSAPI_NAME")
247     if not principal:
248         return None
249     instance, _, _ = principal.partition("@")
250     if instance.endswith("/root"):
251         username, _, _ = principal.partition("/")
252     else:
253         username = instance
254     return username
255
256 def set_operator_env():
257     """
258     Sets :envvar:`GIT_COMMITTER_NAME` and :envvar:`GIT_COMMITTER_EMAIL`
259     environment variables if applicable.  Does nothing if
260     :func:`get_operator_info` throws :exc:`NoOperatorInfo`.
261     """
262     try:
263         op_realname, op_email = get_operator_info()
264         os.putenv("GIT_COMMITTER_NAME", op_realname)
265         os.putenv("GIT_COMMITTER_EMAIL", op_email)
266     except NoOperatorInfo:
267         pass
268
269 def set_author_env():
270     """
271     Sets :envvar:`GIT_AUTHOR_NAME` and :envvar:`GIT_AUTHOR_EMAIL` environment
272     variables if applicable. Does nothing if :func:`get_dir_owner` fails.
273     """
274     try:
275         # XXX: should check if the directory is in AFS, and if not, use
276         # a more traditional metric
277         lockername = get_dir_owner()
278         os.putenv("GIT_AUTHOR_NAME", "%s locker" % lockername)
279         os.putenv("GIT_AUTHOR_EMAIL", "%s@scripts.mit.edu" % lockername)
280     except KeyError: # XXX: This doesn't actually make sense
281         pass
282
283 def set_git_env():
284     """Sets all appropriate environment variables for Git commits."""
285     set_operator_env()
286     set_author_env()
287
288 def get_git_footer():
289     """Returns strings for placing in Git log info about Wizard."""
290     return "\n".join(["Wizard-revision: %s" % get_revision()
291         ,"Wizard-args: %s" % " ".join(sys.argv)
292         ])
293
294 def safe_unlink(file):
295     """Moves a file/dir to a backup location."""
296     if not os.path.exists(file):
297         return None
298     prefix = "%s.bak" % file
299     name = None
300     for i in itertools.count():
301         name = "%s.%d" % (prefix, i)
302         if not os.path.exists(name):
303             break
304     os.rename(file, name)
305     return name
306
307 def soft_unlink(file):
308     """Unlink a file, but don't complain if it doesn't exist."""
309     try:
310         os.unlink(file)
311     except OSError:
312         pass
313
314 def fetch(host, path, subpath, post=None):
315     h = httplib.HTTPConnection(host)
316     fullpath = path.rstrip("/") + "/" + subpath.lstrip("/") # to be lenient about input we accept
317     if post:
318         headers = {"Content-type": "application/x-www-form-urlencoded"}
319         h.request("POST", fullpath, urllib.urlencode(post), headers)
320     else:
321         h.request("GET", fullpath)
322     r = h.getresponse()
323     data = r.read()
324     h.close()
325     return data
326
327 class NoOperatorInfo(wizard.Error):
328     """No information could be found about the operator from Kerberos."""
329     pass
330
331 class PermissionsError(IOError):
332     errno = errno.EACCES
333
334 class NoSuchDirectoryError(IOError):
335     errno = errno.ENOENT
336
337 class DirectoryLockedError(wizard.Error):
338     def __init__(self, dir):
339         self.dir = dir
340     def __str__(self):
341         return """
342
343 ERROR: Could not acquire lock on directory.  Maybe there is
344 another migration process running?
345 """
346