]> scripts.mit.edu Git - wizard.git/blobdiff - wizard/app/__init__.py
Remove string exception from remaster.
[wizard.git] / wizard / app / __init__.py
index de127369060dfa413deea9bccf1e68c211edbb22..6a5ee6a05ac00588e33df8af2961cf9e5020dfa6 100644 (file)
@@ -6,6 +6,18 @@ You'll need to know how to overload the :class:`Application` class
 and use some of the functions in this module in order to specify
 new applications.
 
 and use some of the functions in this module in order to specify
 new applications.
 
+To specify custom applications as plugins,  add the following ``entry_points``
+configuration::
+
+    [wizard.app]
+    yourappname = your.module:Application
+    otherappname = your.other.module:Application
+
+.. note::
+
+    Wizard will complain loudly if ``yourappname`` conflicts with an
+    application name defined by someone else.
+
 There are some submodules for programming languages that define common
 functions and data that may be used by applications in that language.  See:
 
 There are some submodules for programming languages that define common
 functions and data that may be used by applications in that language.  See:
 
@@ -21,6 +33,7 @@ functions and data that may be used by applications in that language.  See:
 """
 
 import os.path
 """
 
 import os.path
+import subprocess
 import re
 import distutils.version
 import decorator
 import re
 import distutils.version
 import decorator
@@ -28,29 +41,49 @@ import shlex
 import logging
 import shutil
 import sqlalchemy
 import logging
 import shutil
 import sqlalchemy
-import random
+import sqlalchemy.exc
 import string
 import urlparse
 import tempfile
 import string
 import urlparse
 import tempfile
+import pkg_resources
+import traceback
 
 import wizard
 
 import wizard
-from wizard import resolve, scripts, shell, util
-
-_application_list = [
-    "mediawiki", "wordpress", "joomla", "e107", "gallery2",
-    "phpBB", "advancedbook", "phpical", "trac", "turbogears", "django",
-    # these are technically deprecated
-    "advancedpoll", "gallery",
-]
-_applications = None
+from wizard import plugin, resolve, shell, sql, util
 
 
+_applications = None
 def applications():
     """Hash table for looking up string application name to instance"""
     global _applications
     if not _applications:
 def applications():
     """Hash table for looking up string application name to instance"""
     global _applications
     if not _applications:
-        _applications = dict([(n,Application.make(n)) for n in _application_list ])
+        _applications = dict()
+        for dist in pkg_resources.working_set:
+            for appname, entry in dist.get_entry_map("wizard.app").items():
+                if appname in _applications:
+                    newname = dist.key + ":" + appname
+                    if newname in _applications:
+                        raise Exception("Unrecoverable application name conflict for %s from %s", appname, dist.key)
+                    logging.warning("Could not overwrite %s, used %s instead", appname, newname)
+                    appname = newname
+                appclass = entry.load()
+                _applications[appname] = appclass(appname)
+        # setup dummy apps
+        for entry in pkg_resources.iter_entry_points("wizard.dummy_apps"):
+            appfun = entry.load()
+            dummy_apps = appfun()
+            for appname in dummy_apps:
+                # a dummy app that already exists is not a fatal error
+                if appname in _applications:
+                    continue
+                _applications[appname] = Application(appname)
     return _applications
 
     return _applications
 
+def getApplication(appname):
+    """Retrieves application instance given a name"""
+    try:
+        return applications()[appname]
+    except KeyError:
+        raise NoSuchApplication(appname)
 
 class Application(object):
     """
 
 class Application(object):
     """
@@ -74,6 +107,9 @@ class Application(object):
     deprecated_keys = set()
     #: Keys that we can simply generate random strings for if they're missing
     random_keys = set()
     deprecated_keys = set()
     #: Keys that we can simply generate random strings for if they're missing
     random_keys = set()
+    #: Values that are not sufficiently random for a random key.  This can
+    #: include default values for a random configuration option,
+    random_blacklist = set()
     #: Dictionary of variable names to extractor functions.  These functions
     #: take a :class:`wizard.deploy.Deployment` as an argument and return the value of
     #: the variable, or ``None`` if it could not be found.
     #: Dictionary of variable names to extractor functions.  These functions
     #: take a :class:`wizard.deploy.Deployment` as an argument and return the value of
     #: the variable, or ``None`` if it could not be found.
@@ -97,6 +133,8 @@ class Application(object):
     #: :class:`wizard.deploy.Deployment`; the value here is merely the preferred
     #: value.
     database = None
     #: :class:`wizard.deploy.Deployment`; the value here is merely the preferred
     #: value.
     database = None
+    #: Indicates whether or not a web stub is necessary.
+    needs_web_stub = False
     def __init__(self, name):
         self.name = name
         self.versions = {}
     def __init__(self, name):
         self.name = name
         self.versions = {}
@@ -132,13 +170,13 @@ class Application(object):
             result[k] = extractor(deployment)
         # XXX: ugh... we have to do quoting
         for k in self.random_keys:
             result[k] = extractor(deployment)
         # XXX: ugh... we have to do quoting
         for k in self.random_keys:
-            if result[k] is None:
-                result[k] = "'%s'" % ''.join(random.choice(string.letters + string.digits) for i in xrange(30))
+            if result[k] is None or result[k] in self.random_blacklist:
+                result[k] = "'%s'" % util.random_key()
         return result
     def dsn(self, deployment):
         """
         Returns the deployment specific database URL.  Uses the override file
         return result
     def dsn(self, deployment):
         """
         Returns the deployment specific database URL.  Uses the override file
-        in :file:`.scripts` if it exists, and otherwise attempt to extract the
+        in :file:`.wizard` if it exists, and otherwise attempt to extract the
         variables from the source files.
 
         Under some cases, the database URL will contain only the database
         variables from the source files.
 
         Under some cases, the database URL will contain only the database
@@ -184,7 +222,7 @@ class Application(object):
     def url(self, deployment):
         """
         Returns the deployment specific web URL.  Uses the override file
     def url(self, deployment):
         """
         Returns the deployment specific web URL.  Uses the override file
-        in :file:`.scripts` if it exists, and otherwise attempt to extract
+        in :file:`.wizard` if it exists, and otherwise attempt to extract
         the variables from the source files.
 
         This function might return ``None``, which indicates we couldn't figure
         the variables from the source files.
 
         This function might return ``None``, which indicates we couldn't figure
@@ -214,15 +252,27 @@ class Application(object):
         Takes a generic source checkout and parametrizes it according to the
         values of ``deployment``.  This function operates on the current
         working directory.  ``deployment`` should **not** be the same as the
         Takes a generic source checkout and parametrizes it according to the
         values of ``deployment``.  This function operates on the current
         working directory.  ``deployment`` should **not** be the same as the
-        current working directory.  Default implementation uses
-        :attr:`parametrized_files` and a simple search and replace on those
-        files.
+        current working directory.  See :meth:`parametrizeWithVars` for details
+        on the parametrization.
         """
         """
+        # deployment is not used in this implementation, but note that
+        # we do have the invariant the current directory matches
+        # deployment's directory
         variables = ref_deployment.extract()
         variables = ref_deployment.extract()
+        self.parametrizeWithVars(variables)
+    def parametrizeWithVars(self, variables):
+        """
+        Takes a generic source checkout and parametrizes it according to
+        the values of ``variables``.  Default implementation uses
+        :attr:`parametrized_files` and a simple search and replace on
+        those files.
+        """
         for file in self.parametrized_files:
         for file in self.parametrized_files:
+            logging.debug("Parametrizing file '%s'\n" % (file, ))
             try:
                 contents = open(file, "r").read()
             except IOError:
             try:
                 contents = open(file, "r").read()
             except IOError:
+                logging.debug("Failed to open file '%s'\n" % (file, ))
                 continue
             for key, value in variables.items():
                 if value is None: continue
                 continue
             for key, value in variables.items():
                 if value is None: continue
@@ -237,61 +287,24 @@ class Application(object):
         default implementation uses :attr:`resolutions`.
         """
         resolved = True
         default implementation uses :attr:`resolutions`.
         """
         resolved = True
-        sh = shell.Shell()
         files = set()
         files = set()
-        for status in sh.eval("git", "ls-files", "--unmerged").splitlines():
-            files.add(status.split()[-1])
-        for file in files:
-            # check for newline mismatch
-            # HACK: using git diff to tell if files are binary or not
-            if not len(sh.eval("git", "diff", file).splitlines()) == 1 and util.mixed_newlines(file):
-                # this code only works on Unix
-                def get_newline(filename):
-                    f = open(filename, "U")
-                    # for some reason I need two
-                    s = f.readline()
-                    if s != "" and f.newlines is None:
-                        f.readline()
-                    if not isinstance(f.newlines, str):
-                        raise Exception("Assert: expected newlines to be string, instead was %s in %s" % (repr(f.newlines), file))
-                    return f.newlines
-                def create_reference(id):
-                    f = tempfile.NamedTemporaryFile(prefix="wizardResolve", delete=False)
-                    sh.call("git", "cat-file", "blob", ":%d:%s" % (id, file), stdout=f)
-                    f.close()
-                    return get_newline(f.name), f.name
-                def convert(filename, dest_nl):
-                    contents = open(filename, "U").read().replace("\n", dest_nl)
-                    open(filename, "wb").write(contents)
-                logging.info("Mixed newlines detected in %s", file)
-                common_nl, common_file = create_reference(1)
-                our_nl,    our_file    = create_reference(2)
-                their_nl,  their_file  = create_reference(3)
-                remerge = False
-                if common_nl != their_nl:
-                    # upstream can't keep their newlines straight
-                    logging.info("Converting common file (1) from %s to %s newlines", repr(common_nl), repr(their_nl))
-                    convert(common_file, their_nl)
-                    remerge = True
-                if our_nl != their_nl:
-                    # common case
-                    logging.info("Converting our file (2) from %s to %s newlines", repr(our_nl), repr(their_nl))
-                    convert(our_file, their_nl)
-                    remerge = True
-                if remerge:
-                    logging.info("Remerging %s", file)
-                    with open(file, "wb") as f:
-                        try:
-                            sh.call("git", "merge-file", "--stdout", our_file, common_file, their_file, stdout=f)
-                            logging.info("New merge was clean")
-                            sh.call("git", "add", file)
-                            continue
-                        except shell.CallError:
-                            pass
-                    logging.info("Merge was still unclean")
-                else:
-                    logging.warning("Mixed newlines detected in %s, but no remerge possible", file)
+        files = {}
+        for status in shell.eval("git", "ls-files", "--unmerged").splitlines():
+            mode, hash, role, name = status.split()
+            files.setdefault(name, set()).add(int(role))
+        for file, roles in files.items():
+            # some automatic resolutions
+            if 1 not in roles and 2 not in roles and 3 in roles:
+                # upstream added a file, but it conflicted for whatever reason
+                shell.call("git", "add", file)
+                continue
+            elif 1 in roles and 2 not in roles and 3 in roles:
+                # user deleted the file, but upstream changed it
+                shell.call("git", "rm", file)
+                continue
             # manual resolutions
             # manual resolutions
+            # XXX: this functionality is mostly subsumed by the rerere
+            # tricks we do
             if file in self.resolutions:
                 contents = open(file, "r").read()
                 for spec, result in self.resolutions[file]:
             if file in self.resolutions:
                 contents = open(file, "r").read()
                 for spec, result in self.resolutions[file]:
@@ -301,7 +314,7 @@ class Application(object):
                         logging.info("Did resolution with spec:\n" + spec)
                 open(file, "w").write(contents)
                 if not resolve.is_conflict(contents):
                         logging.info("Did resolution with spec:\n" + spec)
                 open(file, "w").write(contents)
                 if not resolve.is_conflict(contents):
-                    sh.call("git", "add", file)
+                    shell.call("git", "add", file)
                 else:
                     resolved = False
             else:
                 else:
                     resolved = False
             else:
@@ -326,7 +339,7 @@ class Application(object):
         """
         for key, subst in self.substitutions.items():
             subs = subst(deployment)
         """
         for key, subst in self.substitutions.items():
             subs = subst(deployment)
-            if not subs and key not in self.deprecated_keys:
+            if not subs and key not in self.deprecated_keys and key not in self.random_keys:
                 logging.warning("No substitutions for %s" % key)
     def install(self, version, options):
         """
                 logging.warning("No substitutions for %s" % key)
     def install(self, version, options):
         """
@@ -390,6 +403,20 @@ class Application(object):
         match = regex.search(contents)
         if not match: return None
         return distutils.version.LooseVersion(shlex.split(match.group(2))[0])
         match = regex.search(contents)
         if not match: return None
         return distutils.version.LooseVersion(shlex.split(match.group(2))[0])
+    # XXX: This signature doesn't really make too much sense...
+    def detectVersionFromGit(self, tagPattern, preStrip = ''):
+        """
+        Helper method that detects a version by using the most recent tag
+        in git that matches the specified pattern.
+        This assumes that the current working directory is the deployment.
+        """
+        sh = wizard.shell.Shell()
+        cmd = ['git', 'describe', '--tags', '--match', tagPattern, ]
+        tag = sh.call(*cmd, strip=True)
+        if tag and len(tag) > len(preStrip) and tag[:len(preStrip)] == preStrip:
+            tag = tag[len(preStrip):]
+        if not tag: return None
+        return distutils.version.LooseVersion(tag)
     def download(self, version):
         """
         Returns a URL that can be used to download a tarball of ``version`` of
     def download(self, version):
         """
         Returns a URL that can be used to download a tarball of ``version`` of
@@ -409,17 +436,33 @@ class Application(object):
             not to depend on pages that are not the main page.
         """
         raise NotImplementedError
             not to depend on pages that are not the main page.
         """
         raise NotImplementedError
-    def checkWebPage(self, deployment, page, output):
+    def checkDatabase(self, deployment):
+        """
+        Checks if the database is accessible.
+        """
+        try:
+            sql.connect(deployment.dsn)
+            return True
+        except sqlalchemy.exc.DBAPIError:
+            return False
+    def checkWebPage(self, deployment, page, outputs=[], exclude=[]):
         """
         Checks if a given page of an autoinstall contains a particular string.
         """
         page = deployment.fetch(page)
         """
         Checks if a given page of an autoinstall contains a particular string.
         """
         page = deployment.fetch(page)
-        result = page.find(output) != -1
-        if result:
+        for x in exclude:
+            if page.find(x) != -1:
+                logging.info("checkWebPage (failed due to %s):\n\n%s", x, page)
+                return False
+        votes = 0
+        for output in outputs:
+            votes += page.find(output) != -1
+        if votes > len(outputs) / 2:
             logging.debug("checkWebPage (passed):\n\n" + page)
             logging.debug("checkWebPage (passed):\n\n" + page)
+            return True
         else:
             logging.info("checkWebPage (failed):\n\n" + page)
         else:
             logging.info("checkWebPage (failed):\n\n" + page)
-        return result
+            return False
     def checkConfig(self, deployment):
         """
         Checks whether or not an autoinstall has been configured/installed
     def checkConfig(self, deployment):
         """
         Checks whether or not an autoinstall has been configured/installed
@@ -427,8 +470,8 @@ class Application(object):
         Subclasses should provide an implementation.
         """
         # XXX: Unfortunately, this doesn't quite work because we package
         Subclasses should provide an implementation.
         """
         # XXX: Unfortunately, this doesn't quite work because we package
-        # bogus config files in the -scripts versions of installs.  Maybe
-        # we should check a hash or something?
+        # bogus config files.  Maybe we should check a hash or
+        # something?
         raise NotImplementedError
     def researchFilter(self, filename, added, deleted):
         """
         raise NotImplementedError
     def researchFilter(self, filename, added, deleted):
         """
@@ -447,19 +490,6 @@ class Application(object):
         be displayed in verbose mode.
         """
         return filename in self.parametrized_files
         be displayed in verbose mode.
         """
         return filename in self.parametrized_files
-    @staticmethod
-    def make(name):
-        """Makes an application, but uses the correct subtype if available."""
-        try:
-            __import__("wizard.app." + name)
-            return getattr(wizard.app, name).Application(name)
-        except ImportError as error:
-            # XXX ugly hack to check if the import error is from the top level
-            # module we care about or a submodule. should be an archetectural change.
-            if error.args[0].split()[-1]==name:
-                return Application(name)
-            else:
-                raise
 
 class ApplicationVersion(object):
     """Represents an abstract notion of a version for an application, where
 
 class ApplicationVersion(object):
     """Represents an abstract notion of a version for an application, where
@@ -480,10 +510,11 @@ class ApplicationVersion(object):
         """
         return "%s-%s" % (self.application, self.version)
     @property
         """
         return "%s-%s" % (self.application, self.version)
     @property
-    def scripts_tag(self):
+    def wizard_tag(self):
         """
         Returns the name of the Git tag for this version.
         """
         """
         Returns the name of the Git tag for this version.
         """
+        # XXX: Scripts specific
         end = str(self.version).partition('-scripts')[2].partition('-')[0]
         return "%s-scripts%s" % (self.pristine_tag, end)
     @property
         end = str(self.version).partition('-scripts')[2].partition('-')[0]
         return "%s-scripts%s" % (self.pristine_tag, end)
     @property
@@ -533,12 +564,9 @@ class ApplicationVersion(object):
         Makes/retrieves a singleton :class:`ApplicationVersion` from
         a``app`` and ``version`` string.
         """
         Makes/retrieves a singleton :class:`ApplicationVersion` from
         a``app`` and ``version`` string.
         """
-        try:
-            # defer to the application for version creation to enforce
-            # singletons
-            return applications()[app].makeVersion(version)
-        except KeyError:
-            raise NoSuchApplication(app)
+        # defer to the application for version creation to enforce
+        # singletons
+        return getApplication(app).makeVersion(version)
 
 def expand_re(val):
     """
 
 def expand_re(val):
     """
@@ -668,81 +696,24 @@ def filename_regex_substitution(key, files, regex):
         return subs
     return h
 
         return subs
     return h
 
-def backup_database(outdir, deployment):
-    """
-    Generic database backup function for MySQL.
+@decorator.decorator
+def throws_database_errors(f, self, *args, **kwargs):
     """
     """
-    # XXX: Change this once deployments support multiple dbs
-    if deployment.application.database == "mysql":
-        return backup_mysql_database(outdir, deployment)
-    else:
-        raise NotImplementedError
-
-def backup_mysql_database(outdir, deployment):
+    Decorator that takes database errors from :mod:`wizard.sql` and
+    converts them into application script failures from
+    :mod:`wizard.app`.  We can't throw application errors directly from
+    :mod:`wizard.sql` because that would result in a cyclic import;
+    also, it's cleaner to distinguish between a database error and an
+    application script failure.
     """
     """
-    Database backups for MySQL using the :command:`mysqldump` utility.
-    """
-    sh = shell.Shell()
-    outfile = os.path.join(outdir, "db.sql")
     try:
     try:
-        sh.call("mysqldump", "--compress", "-r", outfile, *get_mysql_args(deployment.dsn))
-        sh.call("gzip", "--best", outfile)
-    except shell.CallError as e:
-        shutil.rmtree(outdir)
-        raise BackupFailure(e.stderr)
-
-def restore_database(backup_dir, deployment):
-    """
-    Generic database restoration function for MySQL.
-    """
-    # XXX: see backup_database
-    if deployment.application.database == "mysql":
-        return restore_mysql_database(backup_dir, deployment)
-    else:
-        raise NotImplementedError
-
-def restore_mysql_database(backup_dir, deployment):
-    """
-    Database restoration for MySQL by piping SQL commands into :command:`mysql`.
-    """
-    sh = shell.Shell()
-    if not os.path.exists(backup_dir):
-        raise RestoreFailure("Backup %s doesn't exist", backup_dir.rpartition("/")[2])
-    sql = open(os.path.join(backup_dir, "db.sql"), 'w+')
-    sh.call("gunzip", "-c", os.path.join(backup_dir, "db.sql.gz"), stdout=sql)
-    sql.seek(0)
-    sh.call("mysql", *get_mysql_args(deployment.dsn), stdin=sql)
-    sql.close()
-
-def remove_database(deployment):
-    """
-    Generic database removal function.  Actually, not so generic because we
-    go and check if we're on scripts and if we are run a different command.
-    """
-    sh = shell.Shell()
-    if deployment.dsn.host == "sql.mit.edu":
-        try:
-            sh.call("/mit/scripts/sql/bin/drop-database", deployment.dsn.database)
-            return
-        except shell.CallError:
-            pass
-    engine = sqlalchemy.create_engine(deployment.dsn)
-    engine.execute("DROP DATABASE `%s`" % deployment.dsn.database)
-
-def get_mysql_args(dsn):
-    """
-    Extracts arguments that would be passed to the command line mysql utility
-    from a deployment.
-    """
-    args = []
-    if dsn.host:
-        args += ["-h", dsn.host]
-    if dsn.username:
-        args += ["-u", dsn.username]
-    if dsn.password:
-        args += ["-p" + dsn.password]
-    args += [dsn.database]
-    return args
+        return f(self, *args, **kwargs)
+    except sql.BackupDatabaseError:
+        raise BackupFailure(traceback.format_exc())
+    except sql.RestoreDatabaseError:
+        raise RestoreFailure(traceback.format_exc())
+    except sql.RemoveDatabaseError:
+        raise RemoveFailure(traceback.format_exc())
 
 class Error(wizard.Error):
     """Generic error class for this module."""
 
 class Error(wizard.Error):
     """Generic error class for this module."""
@@ -771,6 +742,8 @@ class DeploymentParseError(Error):
     location = None
     def __init__(self, value):
         self.value = value
     location = None
     def __init__(self, value):
         self.value = value
+    def __str__(self):
+        return "Could not parse '%s' from versions store in '%s'" % (self.value, self.location)
 
 class NoSuchApplication(Error):
     """
 
 class NoSuchApplication(Error):
     """
@@ -784,6 +757,8 @@ class NoSuchApplication(Error):
     location = None
     def __init__(self, app):
         self.app = app
     location = None
     def __init__(self, app):
         self.app = app
+    def __str__(self):
+        return "Wizard doesn't know about an application named '%s'." % self.app
 
 class Failure(Error):
     """
 
 class Failure(Error):
     """