diff --git a/borgmatic/commands/borgmatic.py b/borgmatic/commands/borgmatic.py index 40c7065cd..8266254c4 100644 --- a/borgmatic/commands/borgmatic.py +++ b/borgmatic/commands/borgmatic.py @@ -285,18 +285,9 @@ def run_actions( ) # Map the restore names or detected dumps to the corresponding database configurations. - # TODO: Need to filter restore_names by database type? Maybe take a database --type argument to disambiguate. - restore_databases = { - hook_name: list( - dump.get_database_configurations( - hooks.get(hook_name), - restore_names - or dump.get_database_names_from_dumps(dump_patterns['hook_name']), - ) - ) - for hook_name in dump.DATABASE_HOOK_NAMES - if hook_name in hooks - } + restore_databases = dump.get_per_hook_database_configurations( + hooks, restore_names, dump_patterns + ) # Finally, restore the databases and cleanup the dumps. dispatch.call_hooks( diff --git a/borgmatic/hooks/dump.py b/borgmatic/hooks/dump.py index cfaf0be6a..c72b45d4d 100644 --- a/borgmatic/hooks/dump.py +++ b/borgmatic/hooks/dump.py @@ -100,3 +100,32 @@ def get_database_configurations(databases, names): name ) ) + + +def get_per_hook_database_configurations(hooks, names, dump_patterns): + ''' + Given the hooks configuration dict as per the configuration schema, a sequence of database + names to restore, and a dict from database hook name to glob patterns for matching dumps, + filter down the configuration for just the named databases. + + If there are no named databases given, then find the corresponding database dumps on disk and + use the database names from their filenames. Additionally, if a database configuration is named + "all", project out that configuration for each named database. + + Return the results as a dict from database hook name to a sequence of database configuration + dicts for that database type. + + Raise ValueError if one of the database names cannot be matched to a database in borgmatic's + database configuration. + ''' + # TODO: Need to filter names by database type? Maybe take a database --type argument to disambiguate. + return { + hook_name: list( + get_database_configurations( + hooks.get(hook_name), + names or get_database_names_from_dumps(dump_patterns[hook_name]), + ) + ) + for hook_name in DATABASE_HOOK_NAMES + if hook_name in hooks + } diff --git a/tests/unit/hooks/test_dump.py b/tests/unit/hooks/test_dump.py index 11b9e94d1..0a54e2abf 100644 --- a/tests/unit/hooks/test_dump.py +++ b/tests/unit/hooks/test_dump.py @@ -99,3 +99,33 @@ def test_get_database_configurations_with_unknown_database_name_raises(): with pytest.raises(ValueError): list(module.get_database_configurations(databases, ('foo', 'bar'))) + + +def test_get_per_hook_database_configurations_partitions_by_hook(): + hooks = {'postgresql_databases': [flexmock()]} + names = ('foo', 'bar') + dump_patterns = flexmock() + expected_config = {'postgresql_databases': [flexmock()]} + flexmock(module).should_receive('get_database_configurations').with_args( + hooks['postgresql_databases'], names + ).and_return(expected_config['postgresql_databases']) + + config = module.get_per_hook_database_configurations(hooks, names, dump_patterns) + + assert config == expected_config + + +def test_get_per_hook_database_configurations_defaults_to_detected_database_names(): + hooks = {'postgresql_databases': [flexmock()]} + names = () + detected_names = flexmock() + dump_patterns = {'postgresql_databases': [flexmock()]} + expected_config = {'postgresql_databases': [flexmock()]} + flexmock(module).should_receive('get_database_names_from_dumps').and_return(detected_names) + flexmock(module).should_receive('get_database_configurations').with_args( + hooks['postgresql_databases'], detected_names + ).and_return(expected_config['postgresql_databases']) + + config = module.get_per_hook_database_configurations(hooks, names, dump_patterns) + + assert config == expected_config