![Dan Helfman](/assets/img/avatar_default.png)
regardless of commit flag. Without this, the file upload progress bar breaks because it gets a stale File object out of the cache. Eventually, it would be better if cache sets were only done conditionally based on the commit flag, and also whenever Database.commit() is called.
297 lines
8.4 KiB
Python
297 lines
8.4 KiB
Python
import re
|
|
import os
|
|
import sha
|
|
import psycopg2 as psycopg
|
|
from psycopg2.pool import PersistentConnectionPool
|
|
import random
|
|
from model.Persistent import Persistent
|
|
|
|
|
|
class Database( object ):
  """
  Database access layer: wraps a psycopg2 connection (or a connection pool) and an optional
  memcached-style object cache used to speed up loads of Persistent objects.
  """
  ID_BITS = 128 # number of bits within an id
  ID_DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz" # digits used for base-36 id encoding

  def __init__( self, connection = None, cache = None ):
    """
    Create a new database and return it.

    @type connection: existing connection object with cursor()/close()/commit() methods, or NoneType
    @param connection: database connection to use (optional, defaults to making a connection pool)
    @type cache: cmemcache.Client or something with a similar API, or NoneType
    @param cache: existing memory cache to use (optional, defaults to making a cache)
    @rtype: Database
    @return: newly constructed Database
    """
    # This tells PostgreSQL to give us timestamps in UTC. I'd use "set timezone" instead, but that
    # makes SQLite angry.
    # NOTE(review): os.putenv() updates the process environment but not os.environ — confirm
    # nothing else reads os.environ[ "PGTZ" ] and expects to see this value.
    os.putenv( "PGTZ", "UTC" )

    if connection:
      # an explicit connection was provided, so no pool is needed
      self.__connection = connection
      self.__pool = None
    else:
      self.__connection = None
      self.__pool = PersistentConnectionPool(
        1, # minimum connections
        50, # maximum connections
        "dbname=luminotes user=luminotes password=%s" % os.getenv( "PGPASSWORD", "dev" ),
      )

    self.__cache = cache

    # if no cache was provided, try to connect to a local memcached. caching is simply skipped
    # when neither a cache nor the cmemcache module is available
    if not cache:
      try:
        import cmemcache
        self.__cache = cmemcache.Client( [ "127.0.0.1:11211" ], debug = 0 )
        print "using memcached"
      except ImportError:
        pass

  def __get_connection( self ):
    # use the explicit connection when one was given to __init__(); otherwise lease one from the pool
    if self.__connection:
      return self.__connection
    else:
      return self.__pool.getconn()

  def save( self, obj, commit = True ):
    """
    Save the given object to the database.

    @type obj: Persistent
    @param obj: object to save
    @type commit: bool
    @param commit: True to automatically commit after the save
    """
    connection = self.__get_connection()
    cursor = connection.cursor()

    # update the object if a row for it already exists; otherwise create it
    cursor.execute( obj.sql_exists() )
    if cursor.fetchone():
      cursor.execute( obj.sql_update() )
    else:
      cursor.execute( obj.sql_create() )

    if commit:
      connection.commit()

    # FIXME: we shouldn't touch the cache unless there's actually a commit.
    # the problem is that in self.commit() below, we don't know which objects
    # to actually save into the cache
    if self.__cache:
      self.__cache.set( obj.cache_key, obj )

  def commit( self ):
    """
    Commit the current transaction on the underlying connection. Note that this does not update
    the cache for objects saved with commit = False (see the FIXME in save() above).
    """
    self.__get_connection().commit()

  def load( self, Object_type, object_id, revision = None ):
    """
    Load the object corresponding to the given object id from the database and return it, or None if
    the object_id is unknown. If a revision is provided, a specific revision of the object will be
    loaded.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type object_id: unicode
    @param object_id: id of the object to load
    @type revision: int or NoneType
    @param revision: revision of the object to load (optional)
    @rtype: Object_type or NoneType
    @return: loaded object, or None if no match
    """
    if revision is None and self.__cache: # don't bother caching old revisions
      obj = self.__cache.get( Persistent.make_cache_key( Object_type, object_id ) )
      if obj:
        return obj

    # cache miss (or an old revision was requested): load from the database, then cache the
    # result if it's a current revision
    obj = self.select_one( Object_type, Object_type.sql_load( object_id, revision ) )
    if obj and revision is None and self.__cache:
      self.__cache.set( obj.cache_key, obj )

    return obj

  def select_one( self, Object_type, sql_command, use_cache = False ):
    """
    Execute the given sql_command and return its results in the form of an object of Object_type,
    or None if there was no match.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @type use_cache: bool
    @param use_cache: whether to look for and store objects in the cache
    @rtype: Object_type or NoneType
    @return: loaded object, or None if no match
    """
    # query results are cached under a hash of the SQL text itself (see uncache_command() below)
    if use_cache and self.__cache:
      cache_key = sha.new( sql_command ).hexdigest()
      obj = self.__cache.get( cache_key )
      if obj:
        return obj

    connection = self.__get_connection()
    cursor = connection.cursor()

    cursor.execute( sql_command )

    row = self.__row_to_unicode( cursor.fetchone() )
    if not row:
      return None

    # for tuple/list, wrap the whole row; otherwise splat the row's columns into the constructor
    if Object_type in ( tuple, list ):
      obj = Object_type( row )
    else:
      obj = Object_type( *row )

    if obj and use_cache and self.__cache:
      self.__cache.set( cache_key, obj )

    return obj

  def select_many( self, Object_type, sql_command ):
    """
    Execute the given sql_command and return its results in the form of a list of objects of
    Object_type.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @rtype: list of Object_type
    @return: loaded objects
    """
    connection = self.__get_connection()
    cursor = connection.cursor()

    cursor.execute( sql_command )

    objects = []
    row = self.__row_to_unicode( cursor.fetchone() )

    # convert every result row into an object, same convention as select_one() above
    while row:
      if Object_type in ( tuple, list ):
        obj = Object_type( row )
      else:
        obj = Object_type( *row )

      objects.append( obj )
      row = self.__row_to_unicode( cursor.fetchone() )

    return objects

  def __row_to_unicode( self, row ):
    # Decode every str column of the given database row from utf8 to unicode, passing other
    # values (ints, None, already-unicode strings, etc.) through unchanged. Returns None when the
    # row itself is None (no more results).
    if row is None:
      return None

    return [ isinstance( item, str ) and unicode( item, encoding = "utf8" ) or item for item in row ]

  def execute( self, sql_command, commit = True ):
    """
    Execute the given sql_command.

    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @type commit: bool
    @param commit: True to automatically commit after the command
    """
    connection = self.__get_connection()
    cursor = connection.cursor()

    cursor.execute( sql_command )

    if commit:
      connection.commit()

  def uncache_command( self, sql_command ):
    """
    Remove the cached result (if any) for the given sql_command, as cached by select_one() with
    use_cache = True.

    @type sql_command: unicode
    @param sql_command: SQL command whose cached result should be discarded
    """
    if not self.__cache: return

    # must match the cache key scheme used in select_one()
    cache_key = sha.new( sql_command ).hexdigest()
    self.__cache.delete( cache_key )

  @staticmethod
  def generate_id():
    """
    Generate and return a random object id: ID_BITS bits of randomness encoded as a string of
    ID_DIGITS characters.

    NOTE(review): random.getrandbits() is not cryptographically secure — confirm these ids don't
    need to be unguessable.

    @rtype: str
    @return: the generated id
    """
    int_id = random.getrandbits( Database.ID_BITS )

    # repeatedly divide by the base, collecting one encoded digit per iteration (most significant
    # digit ends up first)
    base = len( Database.ID_DIGITS )
    digits = []

    while True:
      index = int_id % base
      digits.insert( 0, Database.ID_DIGITS[ index ] )
      int_id = int_id / base # integer division under Python 2
      if int_id == 0:
        break

    return "".join( digits )

  def next_id( self, Object_type, commit = True ):
    """
    Generate the next available object id and return it.

    @type Object_type: type
    @param Object_type: class of the object that the id is for
    @type commit: bool
    @param commit: True to automatically commit after storing the next id
    @rtype: str
    @return: the newly generated id
    """
    connection = self.__get_connection()
    cursor = connection.cursor()

    # generate a random id, but on the off-chance that it collides with something else already in
    # the database, try again
    next_id = Database.generate_id()
    cursor.execute( Object_type.sql_id_exists( next_id ) )

    while cursor.fetchone() is not None:
      next_id = Database.generate_id()
      cursor.execute( Object_type.sql_id_exists( next_id ) )

    # save a new object with the next_id to the database, thereby reserving the id
    obj = Object_type( next_id )
    cursor.execute( obj.sql_create() )

    if commit:
      connection.commit()

    return next_id

  def close( self ):
    """
    Shutdown the database.
    """
    if self.__connection:
      self.__connection.close()

    if self.__pool:
      self.__pool.closeall()
|
|
|
|
|
|
class Valid_id( object ):
  """
  Validator for an object id: a non-empty string composed entirely of Database.ID_DIGITS
  characters (lowercase base-36 digits).
  """
  ID_PATTERN = re.compile( "^[%s]+$" % Database.ID_DIGITS )

  def __init__( self, none_okay = False ):
    """
    @type none_okay: bool
    @param none_okay: True if None (or "None" or the empty string) should validate and be
                      normalized to None
    """
    self.__none_okay = none_okay

  def __call__( self, value ):
    """
    Validate the given value, returning it as a str, or None for a missing value when none_okay
    was set. Raises ValueError when the value doesn't look like an id.
    """
    missing = value in ( None, "None", "" )
    if self.__none_okay and missing:
      return None

    if self.ID_PATTERN.search( value ):
      return str( value )

    raise ValueError()
|
|
|
|
|
|
class Valid_revision( object ):
  """
  Validator for an object revision timestamp in PostgreSQL format, e.g.
  "2008-01-01 12:34:56.789012+00:00" (the trailing ":" in the UTC offset is optional).
  """
  # raw string so "\d" and "\." reach the regex engine intact instead of being treated as
  # (invalid) string escapes
  REVISION_PATTERN = re.compile( r"^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+[+-]\d\d(:)?\d\d$" )

  def __init__( self, none_okay = False ):
    """
    @type none_okay: bool
    @param none_okay: True if None (or "None" or the empty string) should validate and be
                      normalized to None
    """
    self.__none_okay = none_okay

  def __call__( self, value ):
    """
    Validate the given value, returning it as a str, or None for a missing value when none_okay
    was set. Raises ValueError when the value doesn't look like a revision timestamp.
    """
    if self.__none_okay and value in ( None, "None", "" ): return None
    if self.REVISION_PATTERN.search( value ): return str( value )

    raise ValueError()
|