Factor out filtering of database configuration to function with tests (#228).

This commit is contained in:
Dan Helfman 2019-11-12 10:39:36 -08:00
parent 2a771161e7
commit a3e939f34b
3 changed files with 62 additions and 12 deletions

View File

@ -285,18 +285,9 @@ def run_actions(
)
# Map the restore names or detected dumps to the corresponding database configurations.
# TODO: Need to filter restore_names by database type? Maybe take a database --type argument to disambiguate.
restore_databases = {
hook_name: list(
dump.get_database_configurations(
hooks.get(hook_name),
restore_names
or dump.get_database_names_from_dumps(dump_patterns['hook_name']),
)
)
for hook_name in dump.DATABASE_HOOK_NAMES
if hook_name in hooks
}
restore_databases = dump.get_per_hook_database_configurations(
hooks, restore_names, dump_patterns
)
# Finally, restore the databases and cleanup the dumps.
dispatch.call_hooks(

View File

@ -100,3 +100,32 @@ def get_database_configurations(databases, names):
name
)
)
def get_per_hook_database_configurations(hooks, names, dump_patterns):
'''
Given the hooks configuration dict as per the configuration schema, a sequence of database
names to restore, and a dict from database hook name to glob patterns for matching dumps,
filter down the configuration for just the named databases.
If there are no named databases given, then find the corresponding database dumps on disk and
use the database names from their filenames. Additionally, if a database configuration is named
"all", project out that configuration for each named database.
Return the results as a dict from database hook name to a sequence of database configuration
dicts for that database type.
Raise ValueError if one of the database names cannot be matched to a database in borgmatic's
database configuration.
'''
# TODO: Need to filter names by database type? Maybe take a database --type argument to disambiguate.
return {
hook_name: list(
get_database_configurations(
hooks.get(hook_name),
names or get_database_names_from_dumps(dump_patterns[hook_name]),
)
)
for hook_name in DATABASE_HOOK_NAMES
if hook_name in hooks
}

View File

@ -99,3 +99,33 @@ def test_get_database_configurations_with_unknown_database_name_raises():
with pytest.raises(ValueError):
list(module.get_database_configurations(databases, ('foo', 'bar')))
def test_get_per_hook_database_configurations_partitions_by_hook():
hooks = {'postgresql_databases': [flexmock()]}
names = ('foo', 'bar')
dump_patterns = flexmock()
expected_config = {'postgresql_databases': [flexmock()]}
flexmock(module).should_receive('get_database_configurations').with_args(
hooks['postgresql_databases'], names
).and_return(expected_config['postgresql_databases'])
config = module.get_per_hook_database_configurations(hooks, names, dump_patterns)
assert config == expected_config
def test_get_per_hook_database_configurations_defaults_to_detected_database_names():
hooks = {'postgresql_databases': [flexmock()]}
names = ()
detected_names = flexmock()
dump_patterns = {'postgresql_databases': [flexmock()]}
expected_config = {'postgresql_databases': [flexmock()]}
flexmock(module).should_receive('get_database_names_from_dumps').and_return(detected_names)
flexmock(module).should_receive('get_database_configurations').with_args(
hooks['postgresql_databases'], detected_names
).and_return(expected_config['postgresql_databases'])
config = module.get_per_hook_database_configurations(hooks, names, dump_patterns)
assert config == expected_config