diff --git a/NEWS b/NEWS index a5b7e45..4bab68c 100644 --- a/NEWS +++ b/NEWS @@ -1,3 +1,8 @@ +1.3.1: April 18, 2008 + * Now if you try to load a page where access is required, and you're not + logged in, you'll be redirected to a login page. After you login, you'll + be redirected to the page you were originally trying to access. + 1.3.0: April 16, 2008 * Created a new hierarchical note tree area for browsing notes. * Added a list of recent notes. diff --git a/controller/Expose.py b/controller/Expose.py index 6ef0cd8..b257443 100644 --- a/controller/Expose.py +++ b/controller/Expose.py @@ -58,6 +58,8 @@ def expose( view = None, rss = None ): if hasattr( error, "to_dict" ): if not view: raise error result = error.to_dict() + elif isinstance( error, cherrypy.HTTPRedirect ): + raise else: import traceback traceback.print_exc() diff --git a/controller/Files.py b/controller/Files.py index ed7cd98..dddb0d9 100644 --- a/controller/Files.py +++ b/controller/Files.py @@ -11,7 +11,7 @@ from threading import Lock, Event from Expose import expose from Validate import validate, Valid_int, Valid_bool, Validation_error from Database import Valid_id, end_transaction -from Users import grab_user_id +from Users import grab_user_id, Access_error from Expire import strongly_expire from model.File import File from model.User import User @@ -22,20 +22,6 @@ from view.Progress_bar import stream_progress, stream_quota_error, quota_error_s from view.File_preview_page import File_preview_page -class Access_error( Exception ): - def __init__( self, message = None ): - if message is None: - message = u"Sorry, you don't have access to do that. Please make sure you're logged in as the correct user." - - Exception.__init__( self, message ) - self.__message = message - - def to_dict( self ): - return dict( - error = self.__message - ) - - class Upload_error( Exception ): def __init__( self, message = None ): if message is None: diff --git a/controller/Notebooks.py b/controller/Notebooks.py index 87aaa8a..1bd816f 100644 --- a/controller/Notebooks.py +++ b/controller/Notebooks.py @@ -5,7 +5,7 @@ from datetime import datetime from Expose import expose from Validate import validate, Valid_string, Validation_error, Valid_bool from Database import Valid_id, Valid_revision, end_transaction -from Users import grab_user_id +from Users import grab_user_id, Access_error from Expire import strongly_expire from Html_nuker import Html_nuker from model.Notebook import Notebook @@ -19,20 +19,6 @@ from view.Html_file import Html_file from view.Note_tree_area import Note_tree_area -class Access_error( Exception ): - def __init__( self, message = None ): - if message is None: - message = u"Sorry, you don't have access to do that. Please make sure you're logged in as the correct user." - - Exception.__init__( self, message ) - self.__message = message - - def to_dict( self ): - return dict( - error = self.__message - ) - - class Notebooks( object ): WHITESPACE_PATTERN = re.compile( u"\s+" ) LINK_PATTERN = re.compile( u']+\s)?href="([^"]+)"(?:\s+target="([^"]*)")?[^>]*)>([^<]+)', re.IGNORECASE ) diff --git a/controller/Users.py b/controller/Users.py index 1f50b82..2dfb1cb 100644 --- a/controller/Users.py +++ b/controller/Users.py @@ -131,7 +131,17 @@ def grab_user_id( function ): else: kwargs[ "user_id" ] = cherrypy.session.get( "user_id" ) - return function( *args, **kwargs ) + try: + return function( *args, **kwargs ) + except Access_error: + # if there was an Access_error, and the user isn't logged in, and this is an HTTP GET request, + # redirect to the login page + if cherrypy.session.get( "user_id" ) is None and cherrypy.request.method == "GET": + original_path = cherrypy.request.path + \ + ( cherrypy.request.query_string and u"?%s" % cherrypy.request.query_string or "" ) + raise cherrypy.HTTPRedirect( u"/login?after_login=%s" % urllib.quote( original_path ) ) + else: + raise return get_id diff --git a/controller/test/Test_files.py b/controller/test/Test_files.py index eb7ab11..5834dcc 100644 --- a/controller/test/Test_files.py +++ b/controller/test/Test_files.py @@ -286,11 +286,12 @@ class Test_files( Test_controller ): session_id = self.session_id, ) - result = self.http_get( - "/files/download?file_id=%s" % self.file_id, - ) + path = "/files/download?file_id=%s" % self.file_id + result = self.http_get( path ) - assert u"access" in result[ u"body" ][ 0 ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_download_without_access( self ): self.login() @@ -413,11 +414,12 @@ class Test_files( Test_controller ): session_id = self.session_id, ) - result = self.http_get( - "/files/preview?file_id=%s" % self.file_id, - ) + path = "/files/preview?file_id=%s" % self.file_id + result = self.http_get( path ) - assert u"access" in result[ u"error" ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_preview_without_access( self ): self.login() @@ -611,11 +613,12 @@ class Test_files( Test_controller ): session_id = self.session_id, ) - result = self.http_get( - "/files/thumbnail?file_id=%s" % self.file_id, - ) + path = "/files/thumbnail?file_id=%s" % self.file_id + result = self.http_get( path ) - assert u"access" in result[ u"body" ][ 0 ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_thumbnail_without_access( self ): self.login() @@ -720,11 +723,12 @@ class Test_files( Test_controller ): session_id = self.session_id, ) - result = self.http_get( - "/files/image?file_id=%s" % self.file_id, - ) + path = "/files/image?file_id=%s" % self.file_id + result = self.http_get( path ) - assert u"access" in result[ u"body" ][ 0 ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_image_without_access( self ): self.login() @@ -773,11 +777,12 @@ class Test_files( Test_controller ): assert result.get( u"file_id" ) def test_upload_page_without_login( self ): - result = self.http_get( - "/files/upload_page?notebook_id=%s¬e_id=%s" % ( self.notebook.object_id, self.note.object_id ), - ) + path = "/files/upload_page?notebook_id=%s¬e_id=%s" % ( self.notebook.object_id, self.note.object_id ) + result = self.http_get( path ) - assert u"access" in result.get( u"error" ) + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_upload( self, filename = None ): self.login() @@ -1063,11 +1068,12 @@ class Test_files( Test_controller ): self.upload_thread.start() # report on that file's upload progress - result = self.http_get( - "/files/progress?file_id=%s&filename=%s" % ( self.file_id, self.filename ), - ) + path = "/files/progress?file_id=%s&filename=%s" % ( self.file_id, self.filename ) + result = self.http_get( path ) - assert u"access" in result[ u"body" ][ 0 ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_progress_for_completed_upload( self ): self.login() @@ -1192,11 +1198,12 @@ class Test_files( Test_controller ): session_id = self.session_id, ) - result = self.http_get( - "/files/stats?file_id=%s" % self.file_id, - ) + path = "/files/stats?file_id=%s" % self.file_id + result = self.http_get( path ) - assert u"access" in result[ u"error" ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_stats_without_access( self ): self.login() diff --git a/controller/test/Test_notebooks.py b/controller/test/Test_notebooks.py index 5054eed..d4b5956 100644 --- a/controller/test/Test_notebooks.py +++ b/controller/test/Test_notebooks.py @@ -1,5 +1,6 @@ import cherrypy import cgi +import urllib from nose.tools import raises from urllib import quote from Test_controller import Test_controller @@ -99,11 +100,13 @@ class Test_notebooks( Test_controller ): self.database.save( self.invite, commit = False ) def test_default_without_login( self ): - result = self.http_get( - "/notebooks/%s" % self.notebook.object_id, - ) + path = "/notebooks/%s" % self.notebook.object_id + result = self.http_get( path ) - assert u"access" in result[ u"error" ] + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) + user = self.database.load( User, self.user.object_id ) assert user.storage_bytes == 0 @@ -2503,12 +2506,15 @@ class Test_notebooks( Test_controller ): note3 = Note.create( "55", u"

blah

foo", notebook_id = self.notebook.object_id ) self.database.save( note3 ) + path = "/notebooks/download_html/%s" % self.notebook.object_id result = self.http_get( - "/notebooks/download_html/%s" % self.notebook.object_id, + path, session_id = self.session_id, ) - assert result.get( "error" ) + headers = result.get( "headers" ) + assert headers + assert headers.get( "Location" ) == u"http:///login?after_login=%s" % urllib.quote( path ) def test_download_html_with_unknown_notebook( self ): self.login()