Personal wiki notebook (not under development)

Database.py (15 KB)
import functools
import os
import os.path
import random
import re
import sha
import sys
import threading

import cherrypy

from model.Persistent import Persistent
from model.Notebook import Notebook
  11. class Connection_wrapper( object ):
  12. def __init__( self, connection ):
  13. self.connection = connection
  14. self.pending_saves = []
  15. def __getattr__( self, name ):
  16. return getattr( self.connection, name )
  17. def synchronized( method ):
  18. def lock( self, *args, **kwargs ):
  19. if self.lock:
  20. self.lock.acquire()
  21. try:
  22. return method( self, *args, **kwargs )
  23. finally:
  24. if self.lock:
  25. self.lock.release()
  26. return lock
  27. class Database( object ):
  28. ID_BITS = 128 # number of bits within an id
  29. ID_DIGITS = "0123456789abcdefghijklmnopqrstuvwxyz"
  30. # caching Notebooks causes problems because different users have different read_write/owner values
  31. CLASSES_NOT_TO_CACHE = ( Notebook, )
  32. def __init__( self, connection = None, cache = None, host = None, ssl_mode = None, data_dir = None ):
  33. """
  34. Create a new database and return it.
  35. @type connection: existing connection object with cursor()/close()/commit() methods, or NoneType
  36. @param connection: database connection to use (optional, defaults to making a connection pool)
  37. @type cache: cmemcache.Client or something with a similar API, or NoneType
  38. @param cache: existing memory cache to use (optional, defaults to making a cache)
  39. @type host: unicode or NoneType
  40. @param host: hostname of PostgreSQL database, or None to use a local SQLite database
  41. @type ssl_mode: unicode or NoneType
  42. @param ssl_mode: SSL mode for the database connection, one of "disallow", "allow", "prefer", or
  43. "require". ignored if host is None
  44. @type data_dir: unicode or NoneType
  45. @param data_dir: directory in which to store data (defaults to a reasonable directory). ignored
  46. if host is not None
  47. @rtype: Database
  48. @return: newly constructed Database
  49. """
  50. # This tells PostgreSQL to give us timestamps in UTC. I'd use "set timezone" instead, but that
  51. # makes SQLite angry.
  52. os.putenv( "PGTZ", "UTC" )
  53. if host is None:
  54. from pysqlite2 import dbapi2 as sqlite
  55. from datetime import datetime
  56. from pytz import utc
  57. TIMESTAMP_PATTERN = re.compile( "^(\d\d\d\d)-(\d\d)-(\d\d) (\d\d):(\d\d):(\d\d).(\d+)(?:\+\d\d:\d\d$)?" )
  58. MICROSECONDS_PER_SECOND = 1000000
  59. def convert_timestamp( value ):
  60. ( year, month, day, hours, minutes, seconds, fractional_seconds ) = \
  61. TIMESTAMP_PATTERN.search( value ).groups( 0 )
  62. # convert fractional seconds (with an arbitrary number of decimal places) to microseconds
  63. microseconds = int( fractional_seconds )
  64. while microseconds > MICROSECONDS_PER_SECOND:
  65. fractional_seconds = fractional_seconds[ : -1 ]
  66. microseconds = int( fractional_seconds or 0 )
  67. # ignore time zone in timestamp and assume UTC
  68. return datetime(
  69. int( year ), int( month ), int( day ),
  70. int( hours ), int( minutes ), int( seconds ), int( microseconds ),
  71. utc,
  72. )
  73. sqlite.register_converter( "boolean", lambda value: value in ( "t", "True", "true" ) and True or False )
  74. sqlite.register_converter( "timestamp", convert_timestamp )
  75. if connection:
  76. self.__connection = connection
  77. else:
  78. if data_dir is None:
  79. if sys.platform.startswith( "win" ):
  80. data_dir = os.path.join( os.environ.get( "APPDATA" ), "Luminotes" )
  81. else:
  82. data_dir = os.path.join( os.environ.get( "HOME", "" ), ".luminotes" )
  83. data_filename = os.path.join( data_dir, "luminotes.db" )
  84. # if the user doesn't yet have their own luminotes.db file, make them an initial copy
  85. if os.path.exists( "luminotes.db" ):
  86. if not os.path.exists( data_dir ):
  87. import stat
  88. os.makedirs( data_dir, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR )
  89. if not os.path.exists( data_filename ):
  90. import shutil
  91. shutil.copyfile( "luminotes.db", data_filename )
  92. self.__connection = \
  93. Connection_wrapper( sqlite.connect( data_filename, detect_types = sqlite.PARSE_DECLTYPES, check_same_thread = False ) )
  94. self.__pool = None
  95. self.__backend = Persistent.SQLITE_BACKEND
  96. self.lock = threading.Lock() # multiple simultaneous client threads make SQLite angry
  97. else:
  98. import psycopg2 as psycopg
  99. from psycopg2.pool import PersistentConnectionPool
  100. # forcibly replace psycopg's connect() function with another function that returns the psycopg
  101. # connection wrapped in a class with a pending_saves member, used in save() and commit() below
  102. original_connect = psycopg.connect
  103. def connect( *args, **kwargs ):
  104. return Connection_wrapper( original_connect( *args, **kwargs ) )
  105. psycopg.connect = connect
  106. if connection:
  107. self.__connection = connection
  108. self.__pool = None
  109. else:
  110. self.__connection = None
  111. self.__pool = PersistentConnectionPool(
  112. 1, # minimum connections
  113. 50, # maximum connections
  114. "host=%s sslmode=%s dbname=luminotes user=luminotes password=%s" % (
  115. host or "localhost",
  116. ssl_mode or "allow",
  117. os.getenv( "PGPASSWORD", "dev" )
  118. ),
  119. )
  120. self.__backend = Persistent.POSTGRESQL_BACKEND
  121. self.lock = None # PostgreSQL does its own synchronization
  122. self.__cache = cache
  123. try:
  124. if self.__cache is None:
  125. import cmemcache
  126. print "using memcached"
  127. except ImportError:
  128. return None
  129. def get_connection( self ):
  130. if self.__connection:
  131. return self.__connection
  132. else:
  133. return self.__pool.getconn()
  134. def __get_cache_connection( self ):
  135. if self.__cache is not None:
  136. return self.__cache
  137. try:
  138. import cmemcache
  139. return cmemcache.Client( [ "127.0.0.1:11211" ], debug = 0 )
  140. except ImportError:
  141. return None
  142. def unescape( self, sql_command ):
  143. """
  144. For backends that don't treat backslashes specially, un-double all backslashes in the given
  145. sql_command.
  146. """
  147. if self.__backend == Persistent.SQLITE_BACKEND:
  148. return sql_command.replace( "\\\\", "\\" )
  149. return sql_command
  150. @synchronized
  151. def save( self, obj, commit = True ):
  152. """
  153. Save the given object to the database.
  154. @type obj: Persistent
  155. @param obj: object to save
  156. @type commit: bool
  157. @param commit: True to automatically commit after the save
  158. """
  159. connection = self.get_connection()
  160. cursor = connection.cursor()
  161. cursor.execute( self.unescape( obj.sql_exists() ) )
  162. if cursor.fetchone():
  163. cursor.execute( self.unescape( obj.sql_update() ) )
  164. else:
  165. cursor.execute( self.unescape( obj.sql_create() ) )
  166. if isinstance( obj, self.CLASSES_NOT_TO_CACHE ):
  167. cache = None
  168. else:
  169. cache = self.__get_cache_connection()
  170. if commit:
  171. connection.commit()
  172. if cache:
  173. cache.set( obj.cache_key, obj )
  174. elif cache:
  175. # no commit yet, so don't touch the cache
  176. connection.pending_saves.append( obj )
  177. @synchronized
  178. def commit( self ):
  179. connection = self.get_connection()
  180. connection.commit()
  181. # save any pending saves to the cache
  182. cache = self.__get_cache_connection()
  183. if cache:
  184. for obj in connection.pending_saves:
  185. cache.set( obj.cache_key, obj )
  186. connection.pending_saves = []
  187. @synchronized
  188. def rollback( self ):
  189. connection = self.get_connection()
  190. connection.rollback()
  191. def load( self, Object_type, object_id, revision = None ):
  192. """
  193. Load the object corresponding to the given object id from the database and return it, or None if
  194. the object_id is unknown. If a revision is provided, a specific revision of the object will be
  195. loaded.
  196. @type Object_type: type
  197. @param Object_type: class of the object to load
  198. @type object_id: unicode
  199. @param object_id: id of the object to load
  200. @type revision: int or NoneType
  201. @param revision: revision of the object to load (optional)
  202. @rtype: Object_type or NoneType
  203. @return: loaded object, or None if no match
  204. """
  205. if revision or Object_type in self.CLASSES_NOT_TO_CACHE:
  206. cache = None
  207. else:
  208. cache = self.__get_cache_connection()
  209. if cache: # don't bother caching old revisions
  210. obj = cache.get( Persistent.make_cache_key( Object_type, object_id ) )
  211. if obj:
  212. return obj
  213. obj = self.select_one( Object_type, Object_type.sql_load( object_id, revision ) )
  214. if obj and cache:
  215. cache.set( obj.cache_key, obj )
  216. return obj
  217. @synchronized
  218. def select_one( self, Object_type, sql_command, use_cache = False ):
  219. """
  220. Execute the given sql_command and return its results in the form of an object of Object_type,
  221. or None if there was no match.
  222. @type Object_type: type
  223. @param Object_type: class of the object to load
  224. @type sql_command: unicode
  225. @param sql_command: SQL command to execute
  226. @type use_cache: bool
  227. @param use_cache: whether to look for and store objects in the cache
  228. @rtype: Object_type or NoneType
  229. @return: loaded object, or None if no match
  230. """
  231. if not use_cache or Object_type in self.CLASSES_NOT_TO_CACHE:
  232. cache = None
  233. else:
  234. cache = self.__get_cache_connection()
  235. if cache:
  236. cache_key = sha.new( sql_command ).hexdigest()
  237. obj = cache.get( cache_key )
  238. if obj:
  239. return obj
  240. connection = self.get_connection()
  241. cursor = connection.cursor()
  242. cursor.execute( self.unescape( sql_command ) )
  243. row = self.__row_to_unicode( cursor.fetchone() )
  244. if not row:
  245. return None
  246. if Object_type in ( tuple, list ):
  247. obj = Object_type( row )
  248. else:
  249. obj = Object_type( *row )
  250. if obj and cache:
  251. cache.set( cache_key, obj )
  252. return obj
  253. @synchronized
  254. def select_many( self, Object_type, sql_command ):
  255. """
  256. Execute the given sql_command and return its results in the form of a list of objects of
  257. Object_type.
  258. @type Object_type: type
  259. @param Object_type: class of the object to load
  260. @type sql_command: unicode
  261. @param sql_command: SQL command to execute
  262. @rtype: list of Object_type
  263. @return: loaded objects
  264. """
  265. connection = self.get_connection()
  266. cursor = connection.cursor()
  267. cursor.execute( self.unescape( sql_command ) )
  268. objects = []
  269. row = self.__row_to_unicode( cursor.fetchone() )
  270. while row:
  271. if Object_type in ( tuple, list ):
  272. obj = Object_type( row )
  273. else:
  274. obj = Object_type( *row )
  275. objects.append( obj )
  276. row = self.__row_to_unicode( cursor.fetchone() )
  277. return objects
  278. def __row_to_unicode( self, row ):
  279. if row is None:
  280. return None
  281. return [ isinstance( item, str ) and unicode( item, encoding = "utf8" ) or item for item in row ]
  282. @synchronized
  283. def execute( self, sql_command, commit = True ):
  284. """
  285. Execute the given sql_command.
  286. @type sql_command: unicode
  287. @param sql_command: SQL command to execute
  288. @type commit: bool
  289. @param commit: True to automatically commit after the command
  290. """
  291. connection = self.get_connection()
  292. cursor = connection.cursor()
  293. cursor.execute( self.unescape( sql_command ) )
  294. if commit:
  295. connection.commit()
  296. @synchronized
  297. def execute_script( self, sql_commands, commit = True ):
  298. """
  299. Execute the given sql_commands.
  300. @type sql_command: unicode
  301. @param sql_command: multiple SQL commands to execute
  302. @type commit: bool
  303. @param commit: True to automatically commit after the command
  304. """
  305. connection = self.get_connection()
  306. cursor = connection.cursor()
  307. if self.__backend == Persistent.SQLITE_BACKEND:
  308. cursor.executescript( sql_commands )
  309. else:
  310. cursor.execute( self.unescape( sql_commands ) )
  311. if commit:
  312. connection.commit()
  313. def uncache_command( self, sql_command ):
  314. cache = self.__get_cache_connection()
  315. if not cache: return
  316. cache_key = sha.new( sql_command ).hexdigest()
  317. cache.delete( cache_key )
  318. def uncache( self, obj ):
  319. cache = self.__get_cache_connection()
  320. if not cache: return
  321. cache.delete( obj.cache_key )
  322. def uncache_many( self, Object_type, obj_ids ):
  323. cache = self.__get_cache_connection()
  324. if not cache: return
  325. for obj_id in obj_ids:
  326. cache.delete( Persistent.make_cache_key( Object_type, obj_id ) )
  327. @staticmethod
  328. def generate_id():
  329. int_id = random.getrandbits( Database.ID_BITS )
  330. base = len( Database.ID_DIGITS )
  331. digits = []
  332. while True:
  333. index = int_id % base
  334. digits.insert( 0, Database.ID_DIGITS[ index ] )
  335. int_id = int_id / base
  336. if int_id == 0:
  337. break
  338. return "".join( digits )
  339. @synchronized
  340. def next_id( self, Object_type, commit = True ):
  341. """
  342. Generate the next available object id and return it.
  343. @type Object_type: type
  344. @param Object_type: class of the object that the id is for
  345. @type commit: bool
  346. @param commit: True to automatically commit after storing the next id
  347. """
  348. connection = self.get_connection()
  349. cursor = connection.cursor()
  350. # generate a random id, but on the off-chance that it collides with something else already in
  351. # the database, try again
  352. next_id = Database.generate_id()
  353. cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) )
  354. while cursor.fetchone() is not None:
  355. next_id = Database.generate_id()
  356. cursor.execute( self.unescape( Object_type.sql_id_exists( next_id ) ) )
  357. # save a new object with the next_id to the database
  358. obj = Object_type( next_id )
  359. cursor.execute( self.unescape( obj.sql_create() ) )
  360. if commit:
  361. connection.commit()
  362. return next_id
  363. @synchronized
  364. def close( self ):
  365. """
  366. Shutdown the database.
  367. """
  368. if self.__connection:
  369. self.__connection.close()
  370. if self.__pool:
  371. self.__pool.closeall()
  372. backend = property( lambda self: self.__backend )
  373. def end_transaction( function ):
  374. """
  375. Decorator that prevents transaction leaks by rolling back any transactions left open when the
  376. wrapped function returns or raises.
  377. """
  378. def rollback( *args, **kwargs ):
  379. try:
  380. return function( *args, **kwargs )
  381. finally:
  382. cherrypy.root.database.rollback()
  383. return rollback
  384. class Valid_id( object ):
  385. """
  386. Validator for an object id.
  387. """
  388. ID_PATTERN = re.compile( "^[%s]+$" % Database.ID_DIGITS )
  389. def __init__( self, none_okay = False ):
  390. self.__none_okay = none_okay
  391. def __call__( self, value ):
  392. if value in ( None, "None", "" ):
  393. if self.__none_okay:
  394. return None
  395. else:
  396. raise ValueError()
  397. if self.ID_PATTERN.search( value ): return str( value )
  398. raise ValueError()
  399. class Valid_revision( object ):
  400. """
  401. Validator for an object revision timestamp.
  402. """
  403. REVISION_PATTERN = re.compile( "^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+[+-]\d\d(:)?\d\d$" )
  404. def __init__( self, none_okay = False ):
  405. self.__none_okay = none_okay
  406. def __call__( self, value ):
  407. if self.__none_okay and value in ( None, "None", "" ): return None
  408. if self.REVISION_PATTERN.search( value ): return str( value )
  409. raise ValueError()