# luminotes/controller/Database.py

import re
import os
import os.path
import sys
import sha
import cherrypy
import random
import threading
from model.Persistent import Persistent
from model.Notebook import Notebook


class Connection_wrapper( object ):
  """
  Wrapper for a database connection that adds a pending_saves list, used by Database.save() and
  Database.commit() below to defer cache updates until the transaction actually commits.
  """
  def __init__( self, connection ):
    self.connection = connection
    self.pending_saves = []

  def __getattr__( self, name ):
    return getattr( self.connection, name )


def synchronized( method ):
  """
  Decorator that serializes calls to the wrapped method through self.lock, if the instance has
  one. SQLite connections are guarded this way; for PostgreSQL, self.lock is None and the wrapper
  becomes a no-op.
  """
  def lock( self, *args, **kwargs ):
    if self.lock:
      self.lock.acquire()

    try:
      return method( self, *args, **kwargs )
    finally:
      if self.lock:
        self.lock.release()

  return lock
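
# A minimal sketch of the decorator in use (hedged; Example is hypothetical, while Database below
# provides the real .lock attribute):
#
#   class Example( object ):
#     def __init__( self ):
#       self.lock = threading.Lock()
#
#     @synchronized
#     def increment( self ):
#       ... # body runs with self.lock held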


class Database( object ):
  ID_BITS = 128 # number of bits within an id
  ID_DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz"

  # caching Notebooks causes problems because different users have different read_write/owner values
  CLASSES_NOT_TO_CACHE = ( Notebook, )
  def __init__( self, connection = None, cache = None, host = None, ssl_mode = None, data_dir = None ):
    """
    Create a new database and return it.

    @type connection: existing connection object with cursor()/close()/commit() methods, or NoneType
    @param connection: database connection to use (optional, defaults to making a connection pool)
    @type cache: cmemcache.Client or something with a similar API, or NoneType
    @param cache: existing memory cache to use (optional, defaults to making a cache)
    @type host: unicode or NoneType
    @param host: hostname of PostgreSQL database, or None to use a local SQLite database
    @type ssl_mode: unicode or NoneType
    @param ssl_mode: SSL mode for the database connection, one of "disallow", "allow", "prefer", or
                     "require". ignored if host is None
    @type data_dir: unicode or NoneType
    @param data_dir: directory in which to store data (defaults to a reasonable directory). ignored
                     if host is not None
    @rtype: Database
    @return: newly constructed Database
    """
    # This tells PostgreSQL to give us timestamps in UTC. I'd use "set timezone" instead, but that
    # makes SQLite angry.
    os.putenv( "PGTZ", "UTC" )

    if host is None:
      from pysqlite2 import dbapi2 as sqlite
      from datetime import datetime
      from pytz import utc

      TIMESTAMP_PATTERN = re.compile( r"^(\d\d\d\d)-(\d\d)-(\d\d) (\d\d):(\d\d):(\d\d).(\d+)(?:\+\d\d:\d\d$)?" )
      MICROSECONDS_PER_SECOND = 1000000

      def convert_timestamp( value ):
        ( year, month, day, hours, minutes, seconds, fractional_seconds ) = \
          TIMESTAMP_PATTERN.search( value ).groups( 0 )

        # convert fractional seconds (with an arbitrary number of decimal places) to microseconds,
        # truncating digits until the value fits within datetime's valid microsecond range
        microseconds = int( fractional_seconds )
        while microseconds >= MICROSECONDS_PER_SECOND:
          fractional_seconds = fractional_seconds[ : -1 ]
          microseconds = int( fractional_seconds or 0 )

        # ignore time zone in timestamp and assume UTC
        return datetime(
          int( year ), int( month ), int( day ),
          int( hours ), int( minutes ), int( seconds ), int( microseconds ),
          utc,
        )

      sqlite.register_converter( "boolean", lambda value: value in ( "t", "True", "true" ) )
      sqlite.register_converter( "timestamp", convert_timestamp )

      if connection:
        self.__connection = connection
      else:
        if data_dir is None:
          if sys.platform.startswith( "win" ):
            data_dir = os.path.join( os.environ.get( "APPDATA" ), "Luminotes" )
          else:
            data_dir = os.path.join( os.environ.get( "HOME", "" ), ".luminotes" )

        data_filename = os.path.join( data_dir, "luminotes.db" )

        # if the user doesn't yet have their own luminotes.db file, make them an initial copy
        if os.path.exists( "luminotes.db" ):
          if not os.path.exists( data_dir ):
            import stat
            os.makedirs( data_dir, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR )

          if not os.path.exists( data_filename ):
            import shutil
            shutil.copyfile( "luminotes.db", data_filename )

        self.__connection = \
          Connection_wrapper( sqlite.connect( data_filename, detect_types = sqlite.PARSE_DECLTYPES, check_same_thread = False ) )

      self.__pool = None
      self.__backend = Persistent.SQLITE_BACKEND
      self.lock = threading.Lock() # multiple simultaneous client threads make SQLite angry
    else:
      import psycopg2 as psycopg
      from psycopg2.pool import PersistentConnectionPool

      # forcibly replace psycopg's connect() function with another function that returns the psycopg
      # connection wrapped in a class with a pending_saves member, used in save() and commit() below
      original_connect = psycopg.connect

      def connect( *args, **kwargs ):
        return Connection_wrapper( original_connect( *args, **kwargs ) )

      psycopg.connect = connect

      if connection:
        self.__connection = connection
        self.__pool = None
      else:
        self.__connection = None
        self.__pool = PersistentConnectionPool(
          1,  # minimum connections
          50, # maximum connections
          "host=%s sslmode=%s dbname=luminotes user=luminotes password=%s" % (
            host or "localhost",
            ssl_mode or "allow",
            os.getenv( "PGPASSWORD", "dev" )
          ),
        )

      self.__backend = Persistent.POSTGRESQL_BACKEND
      self.lock = None # PostgreSQL does its own synchronization

    self.__cache = cache

    # if no cache was provided, check whether cmemcache is available, purely to report which
    # caching mode is in effect. the actual cache client is created lazily in
    # __get_cache_connection() below
    try:
      if self.__cache is None:
        import cmemcache
        print "using memcached"
    except ImportError:
      pass # cmemcache is unavailable, so fall back to operating without a cache
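
  # A minimal usage sketch (hedged): constructing a Database against the local SQLite backend
  # versus a PostgreSQL server. The hostname and ssl_mode values here are illustrative, not
  # defaults taken from any particular deployment:
  #
  #   database = Database()                                                # SQLite file under ~/.luminotes
  #   database = Database( host = u"db.example.com", ssl_mode = u"require" )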

  def get_connection( self ):
    if self.__connection:
      return self.__connection
    else:
      return self.__pool.getconn()

  def __get_cache_connection( self ):
    if self.__cache is not None:
      return self.__cache

    try:
      import cmemcache
      return cmemcache.Client( [ "127.0.0.1:11211" ], debug = 0 )
    except ImportError:
      return None

  def unescape( self, sql_command ):
    """
    For backends that don't treat backslashes specially, un-double all backslashes in the given
    sql_command.
    """
    if self.__backend == Persistent.SQLITE_BACKEND:
      return sql_command.replace( "\\\\", "\\" )

    return sql_command
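
  # For instance (illustrative values): a command containing the escaped text "a \\ b", as written
  # for PostgreSQL, becomes "a \ b" before being handed to SQLite, which treats backslashes as
  # plain characters.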

  @synchronized
  def save( self, obj, commit = True ):
    """
    Save the given object to the database.

    @type obj: Persistent
    @param obj: object to save
    @type commit: bool
    @param commit: True to automatically commit after the save
    """
    connection = self.get_connection()
    cursor = connection.cursor()

    cursor.execute( self.unescape( obj.sql_exists() ) )
    if cursor.fetchone():
      cursor.execute( self.unescape( obj.sql_update() ) )
    else:
      cursor.execute( self.unescape( obj.sql_create() ) )

    if isinstance( obj, self.CLASSES_NOT_TO_CACHE ):
      cache = None
    else:
      cache = self.__get_cache_connection()

    if commit:
      connection.commit()
      if cache:
        cache.set( obj.cache_key, obj )
    elif cache:
      # no commit yet, so don't touch the cache
      connection.pending_saves.append( obj )

  @synchronized
  def commit( self ):
    connection = self.get_connection()
    connection.commit()

    # save any pending saves to the cache
    cache = self.__get_cache_connection()
    if cache:
      for obj in connection.pending_saves:
        cache.set( obj.cache_key, obj )

    connection.pending_saves = []
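
  # A sketch of the deferred-save flow (hedged; "note" and "other" stand in for any Persistent
  # objects):
  #
  #   database.save( note, commit = False )   # queued on connection.pending_saves
  #   database.save( other, commit = False )
  #   database.commit()                       # commits, then writes both objects to the cache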

  @synchronized
  def rollback( self ):
    connection = self.get_connection()
    connection.rollback()

  def load( self, Object_type, object_id, revision = None ):
    """
    Load the object corresponding to the given object id from the database and return it, or None
    if the object_id is unknown. If a revision is provided, that specific revision of the object
    is loaded.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type object_id: unicode
    @param object_id: id of the object to load
    @type revision: int or NoneType
    @param revision: revision of the object to load (optional)
    @rtype: Object_type or NoneType
    @return: loaded object, or None if no match
    """
    if revision or Object_type in self.CLASSES_NOT_TO_CACHE:
      cache = None # don't bother caching old revisions or uncacheable classes
    else:
      cache = self.__get_cache_connection()

    if cache:
      obj = cache.get( Persistent.make_cache_key( Object_type, object_id ) )
      if obj:
        return obj

    obj = self.select_one( Object_type, Object_type.sql_load( object_id, revision ) )

    if obj and cache:
      cache.set( obj.cache_key, obj )

    return obj
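
  # Illustrative calls (hedged; notebook_id and some_revision are placeholders):
  #
  #   notebook = database.load( Notebook, notebook_id )
  #   old_copy = database.load( Notebook, notebook_id, revision = some_revision )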

  @synchronized
  def select_one( self, Object_type, sql_command, use_cache = False ):
    """
    Execute the given sql_command and return its results in the form of an object of Object_type,
    or None if there was no match.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @type use_cache: bool
    @param use_cache: whether to look for and store objects in the cache
    @rtype: Object_type or NoneType
    @return: loaded object, or None if no match
    """
    if not use_cache or Object_type in self.CLASSES_NOT_TO_CACHE:
      cache = None
    else:
      cache = self.__get_cache_connection()

    if cache:
      cache_key = sha.new( sql_command ).hexdigest()
      obj = cache.get( cache_key )
      if obj:
        return obj

    connection = self.get_connection()
    cursor = connection.cursor()
    cursor.execute( self.unescape( sql_command ) )

    row = self.__row_to_unicode( cursor.fetchone() )
    if not row:
      return None

    if Object_type in ( tuple, list ):
      obj = Object_type( row )
    else:
      obj = Object_type( *row )

    if obj and cache:
      cache.set( cache_key, obj )

    return obj

  @synchronized
  def select_many( self, Object_type, sql_command ):
    """
    Execute the given sql_command and return its results in the form of a list of objects of
    Object_type.

    @type Object_type: type
    @param Object_type: class of the object to load
    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @rtype: list of Object_type
    @return: loaded objects
    """
    connection = self.get_connection()
    cursor = connection.cursor()
    cursor.execute( self.unescape( sql_command ) )

    objects = []
    row = self.__row_to_unicode( cursor.fetchone() )

    while row:
      if Object_type in ( tuple, list ):
        obj = Object_type( row )
      else:
        obj = Object_type( *row )

      objects.append( obj )
      row = self.__row_to_unicode( cursor.fetchone() )

    return objects
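
  # Passing tuple or list as Object_type yields raw rows instead of model objects, e.g. (hedged,
  # with a hypothetical table name):
  #
  #   rows = database.select_many( tuple, u"select id, revision from some_table;" )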

  def __row_to_unicode( self, row ):
    if row is None:
      return None

    # decode byte strings as UTF-8. a conditional expression is used rather than the and/or idiom
    # so that empty strings decode correctly instead of falling through as falsy values
    return [ unicode( item, encoding = "utf8" ) if isinstance( item, str ) else item for item in row ]

  @synchronized
  def execute( self, sql_command, commit = True ):
    """
    Execute the given sql_command.

    @type sql_command: unicode
    @param sql_command: SQL command to execute
    @type commit: bool
    @param commit: True to automatically commit after the command
    """
    connection = self.get_connection()
    cursor = connection.cursor()

    cursor.execute( self.unescape( sql_command ) )

    if commit:
      connection.commit()

  @synchronized
  def execute_script( self, sql_commands, commit = True ):
    """
    Execute the given sql_commands.

    @type sql_commands: unicode
    @param sql_commands: multiple SQL commands to execute
    @type commit: bool
    @param commit: True to automatically commit after the commands
    """
    connection = self.get_connection()
    cursor = connection.cursor()

    if self.__backend == Persistent.SQLITE_BACKEND:
      cursor.executescript( sql_commands )
    else:
      cursor.execute( self.unescape( sql_commands ) )

    if commit:
      connection.commit()

  def uncache_command( self, sql_command ):
    cache = self.__get_cache_connection()
    if not cache: return

    cache_key = sha.new( sql_command ).hexdigest()
    cache.delete( cache_key )

  def uncache( self, obj ):
    cache = self.__get_cache_connection()
    if not cache: return

    cache.delete( obj.cache_key )

  def uncache_many( self, Object_type, obj_ids ):
    cache = self.__get_cache_connection()
    if not cache: return

    for obj_id in obj_ids:
      cache.delete( Persistent.make_cache_key( Object_type, obj_id ) )

  @staticmethod
  def generate_id():
    int_id = random.getrandbits( Database.ID_BITS )

    # encode the random integer in base 36 by repeatedly taking the least significant digit and
    # prepending it to the result
    base = len( Database.ID_DIGITS )
    digits = []

    while True:
      index = int_id % base
      digits.insert( 0, Database.ID_DIGITS[ index ] )
      int_id = int_id // base
      if int_id == 0:
        break

    return "".join( digits )

  @synchronized
  def next_id( self, Object_type, commit = True ):
    """
    Generate the next available object id, save a new object with that id to the database, and
    return the id.

    @type Object_type: type
    @param Object_type: class of the object that the id is for
    @type commit: bool
    @param commit: True to automatically commit after storing the next id
    @rtype: str
    @return: the newly generated object id
    """
    connection = self.get_connection()
    cursor = connection.cursor()

    # generate a random id, but on the off-chance that it collides with something else already in
    # the database, try again
    next_id = Database.generate_id()
    cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) )

    while cursor.fetchone() is not None:
      next_id = Database.generate_id()
      cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) )

    # save a new object with the next_id to the database
    obj = Object_type( next_id )
    cursor.execute( self.unescape( obj.sql_create() ) )

    if commit:
      connection.commit()

    return next_id

  @synchronized
  def close( self ):
    """
    Shut down the database, closing the connection or connection pool.
    """
    if self.__connection:
      self.__connection.close()

    if self.__pool:
      self.__pool.closeall()

  backend = property( lambda self: self.__backend )


def end_transaction( function ):
  """
  Decorator that prevents transaction leaks by rolling back any transactions left open when the
  wrapped function returns or raises.
  """
  def rollback( *args, **kwargs ):
    try:
      return function( *args, **kwargs )
    finally:
      cherrypy.root.database.rollback()

  return rollback
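
# Typical use (hedged sketch): decorating a CherryPy handler so a forgotten commit can't leave a
# transaction dangling across requests:
#
#   @end_transaction
#   def some_handler( self, **kwargs ):
#     ... # any transaction still open here is rolled back on the way out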


class Valid_id( object ):
  """
  Validator for an object id.
  """
  ID_PATTERN = re.compile( "^[%s]+$" % Database.ID_DIGITS )

  def __init__( self, none_okay = False ):
    self.__none_okay = none_okay

  def __call__( self, value ):
    if value in ( None, "None", "" ):
      if self.__none_okay:
        return None
      else:
        raise ValueError()

    if self.ID_PATTERN.search( value ): return str( value )

    raise ValueError()


class Valid_revision( object ):
  """
  Validator for an object revision timestamp.
  """
  REVISION_PATTERN = re.compile( r"^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+[+-]\d\d(:)?\d\d$" )

  def __init__( self, none_okay = False ):
    self.__none_okay = none_okay

  def __call__( self, value ):
    if self.__none_okay and value in ( None, "None", "" ): return None

    if self.REVISION_PATTERN.search( value ): return str( value )

    raise ValueError()
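
# Both validators follow the same callable protocol: return a cleaned value or raise ValueError.
# An illustrative sketch:
#
#   validate_id = Valid_id( none_okay = True )
#   validate_id( u"9ix" )   # -> "9ix"
#   validate_id( None )     # -> None
#   validate_id( u"no!" )   # raises ValueError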