diff --git a/controller/Database.py b/controller/Database.py index 9b22f59..262fb7e 100644 --- a/controller/Database.py +++ b/controller/Database.py @@ -7,6 +7,15 @@ import random from model.Persistent import Persistent +class Connection_wrapper( object ): + def __init__( self, connection ): + self.connection = connection + self.pending_saves = [] + + def __getattr__( self, name ): + return getattr( self.connection, name ) + + class Database( object ): ID_BITS = 128 # number of bits within an id ID_DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz" @@ -26,6 +35,15 @@ class Database( object ): # makes SQLite angry. os.putenv( "PGTZ", "UTC" ) + # forcibly replace psycopg's connect() function with another function that returns the psycopg + # connection wrapped in a class with a pending_saves member, used in save() and commit() below + original_connect = psycopg.connect + + def connect( *args, **kwargs ): + return Connection_wrapper( original_connect( *args, **kwargs ) ) + + psycopg.connect = connect + if connection: self.__connection = connection self.__pool = None @@ -72,15 +90,21 @@ class Database( object ): if commit: connection.commit() - - # FIXME: we shouldn't touch the cache unless there's actually a commit. - # the problem is that in self.commit() below, we don't know which objects - # to actually save into the cache - if self.__cache: - self.__cache.set( obj.cache_key, obj ) + if self.__cache: + self.__cache.set( obj.cache_key, obj ) + else: + # no commit yet, so don't touch the cache + connection.pending_saves.append( obj ) def commit( self ): - self.__get_connection().commit() + connection = self.__get_connection() + connection.commit() + + # save any pending saves to the cache + for obj in connection.pending_saves: + self.__cache.set( obj.cache_key, obj ) + + connection.pending_saves = [] def load( self, Object_type, object_id, revision = None ): """