import re import os import os.path import sys import sha import cherrypy import random import threading from model.Persistent import Persistent from model.Notebook import Notebook class Connection_wrapper( object ): def __init__( self, connection ): self.connection = connection self.pending_saves = [] def __getattr__( self, name ): return getattr( self.connection, name ) def synchronized( method ): def lock( self, *args, **kwargs ): if self.lock: self.lock.acquire() try: return method( self, *args, **kwargs ) finally: if self.lock: self.lock.release() return lock class Database( object ): ID_BITS = 128 # number of bits within an id ID_DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz" # caching Notebooks causes problems because different users have different read_write/owner values CLASSES_NOT_TO_CACHE = ( Notebook, ) def __init__( self, connection = None, cache = None, host = None, ssl_mode = None, data_dir = None ): """ Create a new database and return it. @type connection: existing connection object with cursor()/close()/commit() methods, or NoneType @param connection: database connection to use (optional, defaults to making a connection pool) @type cache: cmemcache.Client or something with a similar API, or NoneType @param cache: existing memory cache to use (optional, defaults to making a cache) @type host: unicode or NoneType @param host: hostname of PostgreSQL database, or None to use a local SQLite database @type ssl_mode: unicode or NoneType @param ssl_mode: SSL mode for the database connection, one of "disallow", "allow", "prefer", or "require". ignored if host is None @type data_dir: unicode or NoneType @param data_dir: directory in which to store data (defaults to a reasonable directory). ignored if host is not None @rtype: Database @return: newly constructed Database """ # This tells PostgreSQL to give us timestamps in UTC. I'd use "set timezone" instead, but that # makes SQLite angry. os.putenv( "PGTZ", "UTC" ) if host is None: from pysqlite2 import dbapi2 as sqlite from datetime import datetime from pytz import utc TIMESTAMP_PATTERN = re.compile( "^(\d\d\d\d)-(\d\d)-(\d\d) (\d\d):(\d\d):(\d\d).(\d+)(?:\+\d\d:\d\d$)?" ) MICROSECONDS_PER_SECOND = 1000000 def convert_timestamp( value ): ( year, month, day, hours, minutes, seconds, fractional_seconds ) = \ TIMESTAMP_PATTERN.search( value ).groups( 0 ) # convert fractional seconds (with an arbitrary number of decimal places) to microseconds microseconds = int( fractional_seconds ) while microseconds > MICROSECONDS_PER_SECOND: fractional_seconds = fractional_seconds[ : -1 ] microseconds = int( fractional_seconds or 0 ) # ignore time zone in timestamp and assume UTC return datetime( int( year ), int( month ), int( day ), int( hours ), int( minutes ), int( seconds ), int( microseconds ), utc, ) sqlite.register_converter( "boolean", lambda value: value in ( "t", "True", "true" ) and True or False ) sqlite.register_converter( "timestamp", convert_timestamp ) if connection: self.__connection = connection else: if data_dir is None: if sys.platform.startswith( "win" ): data_dir = os.path.join( os.environ.get( "APPDATA" ), "Luminotes" ) else: data_dir = os.path.join( os.environ.get( "HOME", "" ), ".luminotes" ) data_filename = os.path.join( data_dir, "luminotes.db" ) # if the user doesn't yet have their own luminotes.db file, make them an initial copy if os.path.exists( "luminotes.db" ): if not os.path.exists( data_dir ): import stat os.makedirs( data_dir, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR ) if not os.path.exists( data_filename ): import shutil shutil.copyfile( "luminotes.db", data_filename ) self.__connection = \ Connection_wrapper( sqlite.connect( data_filename, detect_types = sqlite.PARSE_DECLTYPES, check_same_thread = False ) ) self.__pool = None self.__backend = Persistent.SQLITE_BACKEND self.lock = threading.Lock() # multiple simultaneous client threads make SQLite angry else: import psycopg2 as psycopg from psycopg2.pool import PersistentConnectionPool # forcibly replace psycopg's connect() function with another function that returns the psycopg # connection wrapped in a class with a pending_saves member, used in save() and commit() below original_connect = psycopg.connect def connect( *args, **kwargs ): return Connection_wrapper( original_connect( *args, **kwargs ) ) psycopg.connect = connect if connection: self.__connection = connection self.__pool = None else: self.__connection = None self.__pool = PersistentConnectionPool( 1, # minimum connections 50, # maximum connections "host=%s sslmode=%s dbname=luminotes user=luminotes password=%s" % ( host or "localhost", ssl_mode or "allow", os.getenv( "PGPASSWORD", "dev" ) ), ) self.__backend = Persistent.POSTGRESQL_BACKEND self.lock = None # PostgreSQL does its own synchronization self.__cache = cache try: if self.__cache is None: import cmemcache print "using memcached" except ImportError: return None def get_connection( self ): if self.__connection: return self.__connection else: return self.__pool.getconn() def __get_cache_connection( self ): if self.__cache is not None: return self.__cache try: import cmemcache return cmemcache.Client( [ "127.0.0.1:11211" ], debug = 0 ) except ImportError: return None def unescape( self, sql_command ): """ For backends that don't treat backslashes specially, un-double all backslashes in the given sql_command. """ if self.__backend == Persistent.SQLITE_BACKEND: return sql_command.replace( "\\\\", "\\" ) return sql_command @synchronized def save( self, obj, commit = True ): """ Save the given object to the database. @type obj: Persistent @param obj: object to save @type commit: bool @param commit: True to automatically commit after the save """ connection = self.get_connection() cursor = connection.cursor() cursor.execute( self.unescape( obj.sql_exists() ) ) if cursor.fetchone(): cursor.execute( self.unescape( obj.sql_update() ) ) else: cursor.execute( self.unescape( obj.sql_create() ) ) if isinstance( obj, self.CLASSES_NOT_TO_CACHE ): cache = None else: cache = self.__get_cache_connection() if commit: connection.commit() if cache: cache.set( obj.cache_key, obj ) elif cache: # no commit yet, so don't touch the cache connection.pending_saves.append( obj ) @synchronized def commit( self ): connection = self.get_connection() connection.commit() # save any pending saves to the cache cache = self.__get_cache_connection() if cache: for obj in connection.pending_saves: cache.set( obj.cache_key, obj ) connection.pending_saves = [] @synchronized def rollback( self ): connection = self.get_connection() connection.rollback() def load( self, Object_type, object_id, revision = None ): """ Load the object corresponding to the given object id from the database and return it, or None if the object_id is unknown. If a revision is provided, a specific revision of the object will be loaded. @type Object_type: type @param Object_type: class of the object to load @type object_id: unicode @param object_id: id of the object to load @type revision: int or NoneType @param revision: revision of the object to load (optional) @rtype: Object_type or NoneType @return: loaded object, or None if no match """ if revision or Object_type in self.CLASSES_NOT_TO_CACHE: cache = None else: cache = self.__get_cache_connection() if cache: # don't bother caching old revisions obj = cache.get( Persistent.make_cache_key( Object_type, object_id ) ) if obj: return obj obj = self.select_one( Object_type, Object_type.sql_load( object_id, revision ) ) if obj and cache: cache.set( obj.cache_key, obj ) return obj @synchronized def select_one( self, Object_type, sql_command, use_cache = False ): """ Execute the given sql_command and return its results in the form of an object of Object_type, or None if there was no match. @type Object_type: type @param Object_type: class of the object to load @type sql_command: unicode @param sql_command: SQL command to execute @type use_cache: bool @param use_cache: whether to look for and store objects in the cache @rtype: Object_type or NoneType @return: loaded object, or None if no match """ if not use_cache or Object_type in self.CLASSES_NOT_TO_CACHE: cache = None else: cache = self.__get_cache_connection() if cache: cache_key = sha.new( sql_command ).hexdigest() obj = cache.get( cache_key ) if obj: return obj connection = self.get_connection() cursor = connection.cursor() cursor.execute( self.unescape( sql_command ) ) row = self.__row_to_unicode( cursor.fetchone() ) if not row: return None if Object_type in ( tuple, list ): obj = Object_type( row ) else: obj = Object_type( *row ) if obj and cache: cache.set( cache_key, obj ) return obj @synchronized def select_many( self, Object_type, sql_command ): """ Execute the given sql_command and return its results in the form of a list of objects of Object_type. @type Object_type: type @param Object_type: class of the object to load @type sql_command: unicode @param sql_command: SQL command to execute @rtype: list of Object_type @return: loaded objects """ connection = self.get_connection() cursor = connection.cursor() cursor.execute( self.unescape( sql_command ) ) objects = [] row = self.__row_to_unicode( cursor.fetchone() ) while row: if Object_type in ( tuple, list ): obj = Object_type( row ) else: obj = Object_type( *row ) objects.append( obj ) row = self.__row_to_unicode( cursor.fetchone() ) return objects def __row_to_unicode( self, row ): if row is None: return None return [ isinstance( item, str ) and unicode( item, encoding = "utf8" ) or item for item in row ] @synchronized def execute( self, sql_command, commit = True ): """ Execute the given sql_command. @type sql_command: unicode @param sql_command: SQL command to execute @type commit: bool @param commit: True to automatically commit after the command """ connection = self.get_connection() cursor = connection.cursor() cursor.execute( self.unescape( sql_command ) ) if commit: connection.commit() @synchronized def execute_script( self, sql_commands, commit = True ): """ Execute the given sql_commands. @type sql_command: unicode @param sql_command: multiple SQL commands to execute @type commit: bool @param commit: True to automatically commit after the command """ connection = self.get_connection() cursor = connection.cursor() if self.__backend == Persistent.SQLITE_BACKEND: cursor.executescript( sql_commands ) else: cursor.execute( self.unescape( sql_commands ) ) if commit: connection.commit() def uncache_command( self, sql_command ): cache = self.__get_cache_connection() if not cache: return cache_key = sha.new( sql_command ).hexdigest() cache.delete( cache_key ) def uncache( self, obj ): cache = self.__get_cache_connection() if not cache: return cache.delete( obj.cache_key ) def uncache_many( self, Object_type, obj_ids ): cache = self.__get_cache_connection() if not cache: return for obj_id in obj_ids: cache.delete( Persistent.make_cache_key( Object_type, obj_id ) ) @staticmethod def generate_id(): int_id = random.getrandbits( Database.ID_BITS ) base = len( Database.ID_DIGITS ) digits = [] while True: index = int_id % base digits.insert( 0, Database.ID_DIGITS[ index ] ) int_id = int_id / base if int_id == 0: break return "".join( digits ) @synchronized def next_id( self, Object_type, commit = True ): """ Generate the next available object id and return it. @type Object_type: type @param Object_type: class of the object that the id is for @type commit: bool @param commit: True to automatically commit after storing the next id """ connection = self.get_connection() cursor = connection.cursor() # generate a random id, but on the off-chance that it collides with something else already in # the database, try again next_id = Database.generate_id() cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) ) while cursor.fetchone() is not None: next_id = Database.generate_id() cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) ) # save a new object with the next_id to the database obj = Object_type( next_id ) cursor.execute( self.unescape( obj.sql_create() ) ) if commit: connection.commit() return next_id @synchronized def close( self ): """ Shutdown the database. """ if self.__connection: self.__connection.close() if self.__pool: self.__pool.closeall() backend = property( lambda self: self.__backend ) def end_transaction( function ): """ Decorator that prevents transaction leaks by rolling back any transactions left open when the wrapped function returns or raises. """ def rollback( *args, **kwargs ): try: return function( *args, **kwargs ) finally: cherrypy.root.database.rollback() return rollback class Valid_id( object ): """ Validator for an object id. """ ID_PATTERN = re.compile( "^[%s]+$" % Database.ID_DIGITS ) def __init__( self, none_okay = False ): self.__none_okay = none_okay def __call__( self, value ): if value in ( None, "None", "" ): if self.__none_okay: return None else: raise ValueError() if self.ID_PATTERN.search( value ): return str( value ) raise ValueError() class Valid_revision( object ): """ Validator for an object revision timestamp. """ REVISION_PATTERN = re.compile( "^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+[+-]\d\d(:)?\d\d$" ) def __init__( self, none_okay = False ): self.__none_okay = none_okay def __call__( self, value ): if self.__none_okay and value in ( None, "None", "" ): return None if self.REVISION_PATTERN.search( value ): return str( value ) raise ValueError()