diff --git a/controller/Database.py b/controller/Database.py index 8472e2b..0fb3a33 100644 --- a/controller/Database.py +++ b/controller/Database.py @@ -142,6 +142,57 @@ class Database( object ): return obj + @async + def reload( self, object_id, callback = None ): + """ + Load and immediately save the object corresponding to the given object id or database key. This + is useful when the object has a __setstate__() method that performs some sort of schema + evolution operation. + + @type object_id: unicode + @param object_id: id or key of the object to reload + @type callback: generator or NoneType + @param callback: generator to wakeup when the save is complete (optional) + """ + self.__reload( object_id ) + yield callback + + def __reload( self, object_id, revision = None ): + object_id = unicode( object_id ).encode( "utf8" ) + + # grab the object for the given id from the database + buffer = StringIO() + unpickler = cPickle.Unpickler( buffer ) + unpickler.persistent_load = self.__load + + pickled = self.__db.get( object_id ) + if pickled is None or pickled == "": + return + + buffer.write( pickled ) + buffer.flush() + buffer.seek( 0 ) + + # unpickle the object. this should trigger __setstate__() if the object has such a method + obj = unpickler.load() + if obj is None: + print "error unpickling %s: %s" % ( object_id, pickled ) + return + self.__cache[ object_id ] = obj + + # set the pickler up to save persistent ids for every object except for the obj passed in, which + # will be pickled normally + buffer = StringIO() + pickler = cPickle.Pickler( buffer, protocol = -1 ) + pickler.persistent_id = lambda o: self.__persistent_id( o, skip = obj ) + + # pickle the object and write it to the database under its id key + pickler.dump( obj ) + pickled = buffer.getvalue() + self.__db.put( object_id, pickled ) + + self.__db.sync() + @staticmethod def generate_id(): int_id = random.getrandbits( Database.ID_BITS ) diff --git a/controller/test/Test_database.py b/controller/test/Test_database.py index ae7b80d..c26421d 100644 --- a/controller/test/Test_database.py +++ b/controller/test/Test_database.py @@ -168,6 +168,73 @@ class Test_database( object ): self.scheduler.add( g ) self.scheduler.wait_for( g ) + def test_reload( self ): + def gen(): + basic_obj = Some_object( object_id = "5", value = 1 ) + + self.database.save( basic_obj, self.scheduler.thread ) + yield Scheduler.SLEEP + if self.clear_cache: self.database.clear_cache() + + def setstate( self, state ): + state[ "_Some_object__value" ] = 55 + self.__dict__.update( state ) + + Some_object.__setstate__ = setstate + + self.database.reload( basic_obj.object_id, self.scheduler.thread ) + yield Scheduler.SLEEP + delattr( Some_object, "__setstate__" ) + if self.clear_cache: self.database.clear_cache() + + self.database.load( basic_obj.object_id, self.scheduler.thread ) + obj = ( yield Scheduler.SLEEP ) + + assert obj.object_id == basic_obj.object_id + assert obj.value == 55 + + g = gen() + self.scheduler.add( g ) + self.scheduler.wait_for( g ) + + def test_reload_revision( self ): + def gen(): + basic_obj = Some_object( object_id = "5", value = 1 ) + original_revision = basic_obj.revision + original_revision_id = basic_obj.revision_id() + + self.database.save( basic_obj, self.scheduler.thread ) + yield Scheduler.SLEEP + if self.clear_cache: self.database.clear_cache() + + basic_obj.value = 2 + + self.database.save( basic_obj, self.scheduler.thread ) + yield Scheduler.SLEEP + if self.clear_cache: self.database.clear_cache() + + def setstate( self, state ): + state[ "_Some_object__value" ] = 55 + self.__dict__.update( state ) + + Some_object.__setstate__ = setstate + + self.database.reload( original_revision_id, self.scheduler.thread ) + yield Scheduler.SLEEP + delattr( Some_object, "__setstate__" ) + if self.clear_cache: self.database.clear_cache() + + self.database.load( basic_obj.object_id, self.scheduler.thread, revision = original_revision ) + obj = ( yield Scheduler.SLEEP ) + + assert obj.object_id == basic_obj.object_id + assert obj.revision == original_revision + assert obj.value == 55 + + g = gen() + self.scheduler.add( g ) + self.scheduler.wait_for( g ) + def test_next_id( self ): def gen(): self.database.next_id( self.scheduler.thread )