diff --git a/eos/db/migration.py b/eos/db/migration.py index b04df38ae..63d3a515a 100644 --- a/eos/db/migration.py +++ b/eos/db/migration.py @@ -3,17 +3,7 @@ import shutil import time import re import os - -def getAppVersion(): - # calculate app version based on upgrade files we have - appVersion = 0 - for fname in os.listdir(os.path.join(os.path.dirname(__file__), "migrations")): - m = re.match("^upgrade(?P\d+)\.py$", fname) - if not m: - continue - index = int(m.group("index")) - appVersion = max(appVersion, index) - return appVersion +import migrations def getVersion(db): cursor = db.execute('PRAGMA user_version') @@ -21,7 +11,9 @@ def getVersion(db): def update(saveddata_engine): dbVersion = getVersion(saveddata_engine) - appVersion = getAppVersion() + appVersion = migrations.appVersion + + print dbVersion, appVersion if dbVersion == appVersion: return @@ -37,10 +29,11 @@ def update(saveddata_engine): shutil.copyfile(config.saveDB, toFile) for version in xrange(dbVersion, appVersion): - module = __import__("eos.db.migrations.upgrade{}".format(version + 1), fromlist=True) - upgrade = getattr(module, "upgrade", False) - if upgrade: - upgrade(saveddata_engine) + + func = migrations.updates[version+1] + if func: + print "applying update",version+1 + func(saveddata_engine) # when all is said and done, set version to current saveddata_engine.execute("PRAGMA user_version = {}".format(appVersion)) diff --git a/eos/db/migrations/__init__.py b/eos/db/migrations/__init__.py index 16c939868..87988fb3d 100644 --- a/eos/db/migrations/__init__.py +++ b/eos/db/migrations/__init__.py @@ -7,3 +7,25 @@ define an upgrade() function with the logic. Please note that there must be as many upgrade files as there are database versions (version 5 would include upgrade files 1-5) """ + +import pkgutil +import re + + +updates = {} +appVersion = 0 + +prefix = __name__ + "." +for importer, modname, ispkg in pkgutil.iter_modules(__path__, prefix): + # loop through python files, extracting update number and function, and + # adding it to a list + modname_tail = modname.rsplit('.', 1)[-1] + module = __import__(modname, fromlist=True) + m = re.match("^upgrade(?P\d+)$", modname_tail) + if not m: + continue + index = int(m.group("index")) + appVersion = max(appVersion, index) + upgrade = getattr(module, "upgrade", False) + if upgrade: + updates[index] = upgrade