diff --git a/controller/Schema_upgrader.py b/controller/Schema_upgrader.py index 7ab2192..fce39de 100644 --- a/controller/Schema_upgrader.py +++ b/controller/Schema_upgrader.py @@ -45,17 +45,8 @@ class Schema_upgrader: @param to_version: the desired version to upgrade to, as a string """ to_version = self.version_string_to_tuple( to_version ) - - try: - from_version = self.__database.select_one( tuple, "select * from schema_version;" ); - # if there's no schema version table, assume the from_version is 1.5.4, which was the last - # version not to include a schema_version table - except: - self.__database.rollback() - from_version = ( 1, 5, 4 ) - self.__database.execute( "create table schema_version ( major numeric, minor numeric, release numeric );", commit = False ); - self.__database.execute( "insert into schema_version values ( %s, %s, %s );" % from_version, commit = False ); - self.__database.commit() + from_version = self.schema_version( self.__database ) + self.__database.commit() # if the database schema version is already equal to to_version, there's nothing to do if to_version == from_version: @@ -96,6 +87,22 @@ class Schema_upgrader: self.__database.commit() print "successfully upgraded database schema" + @staticmethod + def schema_version( database, default_version = None ): + try: + schema_version = database.select_one( tuple, "select * from schema_version;" ); + # if there's no schema version table, then use the default version given. if there's no default + # version, then assume the from_version is 1.5.4, which was the last version not to include a + # schema_version table + except: + database.rollback() + schema_version = default_version or ( 1, 5, 4 ) + + database.execute( "create table schema_version ( major numeric, minor numeric, release numeric );", commit = False ); + database.execute( "insert into schema_version values ( %s, %s, %s );" % schema_version, commit = False ); + + return schema_version + def apply_schema_delta( self, version, filename ): """ Upgrade the database from its current version to a given version, applying only the named diff --git a/tools/initdb.py b/tools/initdb.py index 3b32c95..7123887 100644 --- a/tools/initdb.py +++ b/tools/initdb.py @@ -4,9 +4,11 @@ import os import os.path import sys from controller.Database import Database +from controller.Schema_upgrader import Schema_upgrader from model.Notebook import Notebook from model.Note import Note from model.User import User +from config.Version import VERSION class Initializer( object ): @@ -45,6 +47,9 @@ class Initializer( object ): if desktop is True: self.create_desktop_user() + version = Schema_upgrader.version_string_to_tuple( VERSION ) + Schema_upgrader.schema_version( database, default_version = version ) + self.database.commit() def create_main_notebook( self ): @@ -96,6 +101,17 @@ class Initializer( object ): users.create_user( u"desktopuser" ) + def set_schema_version( self ): + try: + from_version = self.__database.select_one( tuple, "select * from schema_version;" ); + # if there's no schema version table, set the schema to the current version + except: + self.__database.rollback() + from_version = ( 1, 5, 4 ) + self.__database.execute( "create table schema_version ( major numeric, minor numeric, release numeric );", commit = False ); + self.__database.execute( "insert into schema_version values ( %s, %s, %s );" % from_version, commit = False ); + self.__database.commit() + def main( args = None ): nuke = False