]> scripts.mit.edu Git - wizard.git/blobdiff - wizard/app/__init__.py
Convert ad hoc shell calls to singleton instance; fix upgrade bug.
[wizard.git] / wizard / app / __init__.py
index ee9f58deab43bc4990e0fc234033d540bd4f8c45..1d01966363bb45863db242c6f4120fdf7c7c372e 100644 (file)
@@ -27,6 +27,11 @@ import decorator
 import shlex
 import logging
 import shutil
+import sqlalchemy
+import random
+import string
+import urlparse
+import tempfile
 
 import wizard
 from wizard import resolve, scripts, shell, util
@@ -66,7 +71,9 @@ class Application(object):
     parametrized_files = []
     #: Keys that are used in older versions of the application, but
     #: not for the most recent version.
-    deprecated_keys = []
+    deprecated_keys = set()
+    #: Keys that we can simply generate random strings for if they're missing
+    random_keys = set()
     #: Dictionary of variable names to extractor functions.  These functions
     #: take a :class:`wizard.deploy.Deployment` as an argument and return the value of
     #: the variable, or ``None`` if it could not be found.
@@ -84,6 +91,12 @@ class Application(object):
     #: Instance of :class:`wizard.install.ArgSchema` that defines the arguments
     #: this application requires.
     install_schema = None
+    #: Name of the database that this application uses, i.e. ``mysql`` or
+    #: ``postgres``.  If we end up supporting multiple databases for a single
+    #: application, there should also be a value for this in
+    #: :class:`wizard.deploy.Deployment`; the value here is merely the preferred
+    #: value.
+    database = None
     def __init__(self, name):
         self.name = name
         self.versions = {}
@@ -117,7 +130,85 @@ class Application(object):
         result = {}
         for k,extractor in self.extractors.items():
             result[k] = extractor(deployment)
+        # XXX: ugh... we have to do quoting
+        for k in self.random_keys:
+            if result[k] is None:
+                result[k] = "'%s'" % ''.join(random.choice(string.letters + string.digits) for i in xrange(30))
         return result
+    def dsn(self, deployment):
+        """
+        Returns the deployment specific database URL as a
+        :class:`sqlalchemy.engine.url.URL`.  Uses the override file in
+        :file:`.scripts` if it exists, and otherwise attempts to extract the
+        variables from the source files.
+
+        In some cases, the database URL will contain only the database
+        property, and no other values.  This indicates that the actual DSN
+        should be determined from the environment.
+
+        This function might return ``None``.
+
+        .. note::
+
+            We are allowed to batch these two together, because the full precedence
+            chain for determining the database of an application combines these
+            two together.  If this was not the case, we would have to call
+            :meth:`dsnFromOverride` and :meth:`dsnFromExtract` manually.
+        """
+        url = self.dsnFromOverride(deployment)
+        if url is not None:
+            return url
+        return self.dsnFromExtract(deployment)
+    def dsnFromOverride(self, deployment):
+        """
+        Extracts database URL from an explicit dsn override file
+        (``deployment.dsn_file``) and returns it as a
+        :class:`sqlalchemy.engine.url.URL`.  Returns ``None`` if the file
+        cannot be read or is empty.
+        """
+        try:
+            with open(deployment.dsn_file) as f:
+                contents = f.read().strip()
+        except IOError:
+            return None
+        # An empty override file means "no override"; make_url would
+        # reject an empty string instead.
+        if not contents:
+            return None
+        return sqlalchemy.engine.url.make_url(contents)
+    def dsnFromExtract(self, deployment):
+        """
+        Extracts database URL from a deployment, and returns it as
+        a :class:`sqlalchemy.engine.url.URL` whose driver is :attr:`database`.
+        Returns ``None`` if this application does not declare a
+        :attr:`database`.  Any of the conventional variables
+        (``WIZARD_DBSERVER``, ``WIZARD_DBUSER``, ``WIZARD_DBPASSWORD``,
+        ``WIZARD_DBNAME``) that has no extractor, could not be extracted, or
+        is empty is left unset in the returned URL.
+        """
+        if not self.database:
+            return None
+        extracted = self.extract(deployment)
+        def unquote(name):
+            # Extracted values are quoted literals (see the quoting done in
+            # :meth:`extract`), so strip the quoting with shlex.  A missing
+            # key (no extractor), a ``None`` value (not found), or an empty
+            # value all yield ``None`` instead of KeyError/IndexError.
+            value = extracted.get(name)
+            if value is None:
+                return None
+            tokens = shlex.split(value)
+            return tokens[0] if tokens else None
+        names = ("WIZARD_DBSERVER", "WIZARD_DBUSER", "WIZARD_DBPASSWORD", "WIZARD_DBNAME")
+        host, user, password, database = [unquote(name) for name in names]
+        # XXX: You'd have to put support for an explicit different database
+        # type here
+        return sqlalchemy.engine.url.URL(self.database, username=user, password=password, host=host, database=database)
+    def url(self, deployment):
+        """
+        Returns the deployment specific web URL, or ``None`` when it cannot
+        be determined.  An override file in :file:`.scripts` takes
+        precedence; failing that, the URL is extracted from the source
+        files.
+        """
+        override = self.urlFromOverride(deployment)
+        return override if override else self.urlFromExtract(deployment)
+    def urlFromOverride(self, deployment):
+        """
+        Extracts URL from the explicit url override file
+        (``deployment.url_file``) and returns it parsed by
+        :func:`urlparse.urlparse`.  Returns ``None`` if the file cannot be
+        read or is empty.
+        """
+        try:
+            with open(deployment.url_file) as f:
+                contents = f.read().strip()
+        except IOError:
+            return None
+        # urlparse("") yields an all-empty but truthy result, which would
+        # otherwise shadow urlFromExtract in url(); treat it as no override.
+        if not contents:
+            return None
+        return urlparse.urlparse(contents)
+    def urlFromExtract(self, deployment):
+        """
+        Hook for deriving the web URL from a deployment's extractable
+        variables.  The base implementation cannot determine it and always
+        answers ``None``; something smarter may be added in the future.
+        """
+        return None
     def parametrize(self, deployment, ref_deployment):
         """
         Takes a generic source checkout and parametrizes it according to the
@@ -146,9 +237,60 @@ class Application(object):
         default implementation uses :attr:`resolutions`.
         """
         resolved = True
-        sh = shell.Shell()
-        for status in sh.eval("git", "ls-files", "--unmerged").splitlines():
-            file = status.split()[-1]
+        files = set()
+        for status in shell.eval("git", "ls-files", "--unmerged").splitlines():
+            files.add(status.split()[-1])
+        for file in files:
+            # check for newline mismatch
+            # HACK: using git diff to tell if files are binary or not
+            if not len(shell.eval("git", "diff", file).splitlines()) == 1 and util.mixed_newlines(file):
+                # this code only works on Unix
+                def get_newline(filename):
+                    f = open(filename, "U")
+                    # for some reason I need two
+                    s = f.readline()
+                    if s != "" and f.newlines is None:
+                        f.readline()
+                    if not isinstance(f.newlines, str):
+                        raise Exception("Assert: expected newlines to be string, instead was %s in %s" % (repr(f.newlines), file))
+                    return f.newlines
+                def create_reference(id):
+                    f = tempfile.NamedTemporaryFile(prefix="wizardResolve", delete=False)
+                    shell.call("git", "cat-file", "blob", ":%d:%s" % (id, file), stdout=f)
+                    f.close()
+                    return get_newline(f.name), f.name
+                def convert(filename, dest_nl):
+                    contents = open(filename, "U").read().replace("\n", dest_nl)
+                    open(filename, "wb").write(contents)
+                logging.info("Mixed newlines detected in %s", file)
+                common_nl, common_file = create_reference(1)
+                our_nl,    our_file    = create_reference(2)
+                their_nl,  their_file  = create_reference(3)
+                remerge = False
+                if common_nl != their_nl:
+                    # upstream can't keep their newlines straight
+                    logging.info("Converting common file (1) from %s to %s newlines", repr(common_nl), repr(their_nl))
+                    convert(common_file, their_nl)
+                    remerge = True
+                if our_nl != their_nl:
+                    # common case
+                    logging.info("Converting our file (2) from %s to %s newlines", repr(our_nl), repr(their_nl))
+                    convert(our_file, their_nl)
+                    remerge = True
+                if remerge:
+                    logging.info("Remerging %s", file)
+                    with open(file, "wb") as f:
+                        try:
+                            shell.call("git", "merge-file", "--stdout", our_file, common_file, their_file, stdout=f)
+                            logging.info("New merge was clean")
+                            shell.call("git", "add", file)
+                            continue
+                        except shell.CallError:
+                            pass
+                    logging.info("Merge was still unclean")
+                else:
+                    logging.warning("Mixed newlines detected in %s, but no remerge possible", file)
+            # manual resolutions
             if file in self.resolutions:
                 contents = open(file, "r").read()
                 for spec, result in self.resolutions[file]:
@@ -158,7 +300,7 @@ class Application(object):
                         logging.info("Did resolution with spec:\n" + spec)
                 open(file, "w").write(contents)
                 if not resolve.is_conflict(contents):
-                    sh.call("git", "add", file)
+                    shell.call("git", "add", file)
                 else:
                     resolved = False
             else:
@@ -223,6 +365,13 @@ class Application(object):
         should provide an implementation.
         """
         raise NotImplementedError
+    def remove(self, deployment, options):
+        """
+        Hook for 'wizard remove': deletes all database and non-local file
+        data belonging to the deployment, with the current working
+        directory assumed to be the deployment.  The base class provides
+        no implementation; subclasses should supply one.
+        """
+        raise NotImplementedError
     def detectVersion(self, deployment):
         """
         Checks source files to determine the version manually.  This assumes
@@ -280,14 +429,36 @@ class Application(object):
         # bogus config files in the -scripts versions of installs.  Maybe
         # we should check a hash or something?
         raise NotImplementedError
+    def researchFilter(self, filename, added, deleted):
+        """
+        Allows an application to selectively ignore certain diffstat signatures
+        during research; for example, configuration files will have a very
+        specific set of changes, so ignore them; certain installation files
+        may be removed, etc.  Return ``True`` if the diffstat signature of
+        ``filename`` (``added``/``deleted``, presumably its added and deleted
+        line counts) should be ignored.  The default implementation ignores
+        nothing.
+        """
+        return False
+    def researchVerbose(self, filename):
+        """
+        Decides whether a dirty file is reported only in verbose mode during
+        research.  By default that is exactly the parametrized files, which
+        are guaranteed to differ; answer ``True`` to relegate ``filename``
+        to verbose output.
+        """
+        verbose_only = self.parametrized_files
+        return filename in verbose_only
     @staticmethod
     def make(name):
         """Makes an application, but uses the correct subtype if available."""
         try:
             __import__("wizard.app." + name)
             return getattr(wizard.app, name).Application(name)
-        except ImportError:
-            return Application(name)
+        except ImportError as error:
+            # XXX ugly hack to check if the import error is from the top level
+            # module we care about or a submodule. should be an archetectural change.
+            if error.args[0].split()[-1]==name:
+                return Application(name)
+            else:
+                raise
 
 class ApplicationVersion(object):
     """Represents an abstract notion of a version for an application, where
@@ -496,52 +667,76 @@ def filename_regex_substitution(key, files, regex):
         return subs
     return h
 
-# XXX: rename to show that it's mysql specific
 def backup_database(outdir, deployment):
     """
-    Generic database backup function for MySQL.  Assumes that ``WIZARD_DBNAME``
-    is extractable, and that :func:`wizard.scripts.get_sql_credentials`
-    works.
+    Generic database backup function for MySQL.
+    """
+    # XXX: Change this once deployments support multiple dbs
+    if deployment.application.database == "mysql":
+        return backup_mysql_database(outdir, deployment)
+    else:
+        raise NotImplementedError
+
+def backup_mysql_database(outdir, deployment):
+    """
+    Dumps the deployment's MySQL database with :command:`mysqldump` into
+    ``db.sql`` inside ``outdir``, then compresses it with :command:`gzip`.
+    A failing command is reported as :exc:`BackupFailure` carrying its
+    stderr.
     """
-    sh = shell.Shell()
     outfile = os.path.join(outdir, "db.sql")
     try:
-        sh.call("mysqldump", "--compress", "-r", outfile, *get_mysql_args(deployment))
-        sh.call("gzip", "--best", outfile)
+        mysql_args = get_mysql_args(deployment.dsn)
+        shell.call("mysqldump", "--compress", "-r", outfile, *mysql_args)
+        shell.call("gzip", "--best", outfile)
     except shell.CallError as e:
-        shutil.rmtree(outdir)
         raise BackupFailure(e.stderr)
 
 def restore_database(backup_dir, deployment):
     """
-    Generic database restoration function for MySQL.  See :func:`backup_database`
-    for the assumptions that we make.
+    Generic database restoration function for MySQL.
+    """
+    # XXX: see backup_database
+    if deployment.application.database == "mysql":
+        return restore_mysql_database(backup_dir, deployment)
+    else:
+        raise NotImplementedError
+
+def restore_mysql_database(backup_dir, deployment):
+    """
+    Database restoration for MySQL: decompresses ``db.sql.gz`` from
+    ``backup_dir`` into ``db.sql`` next to it and pipes those SQL commands
+    into :command:`mysql`.
+
+    Raises :exc:`RestoreFailure` if ``backup_dir`` does not exist.
     """
-    sh = shell.Shell()
     if not os.path.exists(backup_dir):
-        raise RestoreFailure("Backup %s doesn't exist", backup_dir.rpartition("/")[2])
-    sql = open(os.path.join(backup_dir, "db.sql"), 'w+')
-    sh.call("gunzip", "-c", os.path.join(backup_dir, "db.sql.gz"), stdout=sql)
-    sql.seek(0)
-    sh.call("mysql", *get_mysql_args(deployment), stdin=sql)
-    sql.close()
+        # Format the message here; passing the argument separately left the
+        # "%s" unformatted and gave RestoreFailure two arguments.
+        raise RestoreFailure("Backup %s doesn't exist" % backup_dir.rpartition("/")[2])
+    # The with block closes the staged SQL file even if gunzip or mysql fails.
+    with open(os.path.join(backup_dir, "db.sql"), 'w+') as sql:
+        shell.call("gunzip", "-c", os.path.join(backup_dir, "db.sql.gz"), stdout=sql)
+        sql.seek(0)
+        shell.call("mysql", *get_mysql_args(deployment.dsn), stdin=sql)
 
-def get_mysql_args(d):
+def remove_database(deployment):
+    """
+    Drops the deployment's database.  Not entirely generic: if the database
+    lives on ``sql.mit.edu`` we first try the scripts-specific
+    :command:`/mit/scripts/sql/bin/drop-database` helper.  If that fails, or
+    the host is elsewhere, we issue ``DROP DATABASE`` through SQLAlchemy
+    using MySQL backtick identifier quoting.
+
+    Raises :exc:`RemoveFailure` if the database name cannot be determined.
+    """
+    dsn = deployment.dsn
+    if dsn is None or not dsn.database:
+        raise RemoveFailure("Could not determine database name")
+    if dsn.host == "sql.mit.edu":
+        try:
+            shell.call("/mit/scripts/sql/bin/drop-database", dsn.database)
+            return
+        except shell.CallError:
+            # Best effort: fall through and drop the database directly.
+            logging.warning("drop-database failed for %s; dropping directly", dsn.database)
+    engine = sqlalchemy.create_engine(dsn)
+    try:
+        # Double any embedded backticks so the name cannot escape the
+        # quoted identifier.
+        quoted = "`%s`" % dsn.database.replace("`", "``")
+        engine.execute("DROP DATABASE %s" % quoted)
+    finally:
+        engine.dispose()
+
+def get_mysql_args(dsn):
     """
     Extracts arguments that would be passed to the command line mysql utility
-    from a deployment.
-    """
-    # XXX: add support for getting these out of options
-    vars = d.extract()
-    if 'WIZARD_DBNAME' not in vars:
-        raise BackupFailure("Could not determine database name")
-    triplet = scripts.get_sql_credentials(vars)
+    from a :class:`sqlalchemy.engine.url.URL` (e.g. a deployment's ``dsn``).
+    Host, user and password flags are emitted only for the components that
+    are set; the database name is always the last argument.
+
+    Raises :exc:`BackupFailure` if the database name cannot be determined.
+    """
+    if dsn is None or not dsn.database:
+        raise BackupFailure("Could not determine database name")
     args = []
-    if triplet is not None:
-        server, user, password = triplet
-        args += ["-h", server, "-u", user, "-p" + password]
-    name = shlex.split(vars['WIZARD_DBNAME'])[0]
-    args.append(name)
+    if dsn.host:
+        args += ["-h", dsn.host]
+    if dsn.username:
+        args += ["-u", dsn.username]
+    if dsn.password:
+        # NOTE(review): the password ends up visible in the process list;
+        # consider passing it via a MySQL defaults file instead.
+        args += ["-p" + dsn.password]
+    args += [dsn.database]
     return args
 
 class Error(wizard.Error):
@@ -596,7 +791,11 @@ class Failure(Error):
 class InstallFailure(Error):
     """Installation failed for unknown reason."""
     def __str__(self):
-        return """Installation failed for unknown reason."""
+        return """
+
+ERROR: Installation failed for unknown reason.  You can
+retry the installation by appending --retry to the installation
+command."""
 
 class RecoverableInstallFailure(InstallFailure):
     """
@@ -609,7 +808,12 @@ class RecoverableInstallFailure(InstallFailure):
     def __init__(self, errors):
         self.errors = errors
     def __str__(self):
-        return """Installation failed due to the following errors: %s""" % ", ".join(self.errors)
+        return """
+
+ERROR: Installation failed due to the following errors:  %s
+
+You can retry the installation by appending --retry to the
+installation command.""" % ", ".join(self.errors)
 
 class UpgradeFailure(Failure):
     """Upgrade script failed."""
@@ -657,3 +861,16 @@ class RestoreFailure(Failure):
 ERROR: Restore script failed, details:
 
 %s""" % self.details
+
+class RemoveFailure(Failure):
+    """Raised when the remove script fails."""
+    #: Human-readable description of what went wrong
+    details = None
+    def __init__(self, details):
+        self.details = details
+    def __str__(self):
+        banner = "\n\nERROR: Remove script failed, details:\n\n"
+        return banner + "%s" % self.details