diff --git a/conservancy_beancount/plugin/__init__.py b/conservancy_beancount/plugin/__init__.py index 4cbf473630d537dba26e872c6c9ebdc75b97282f..3c8842332ca052bb9db5a8ed3be048a13ab55556 100644 --- a/conservancy_beancount/plugin/__init__.py +++ b/conservancy_beancount/plugin/__init__.py @@ -14,6 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import importlib + import beancount.core.data as bc_data __plugins__ = ['run'] @@ -24,18 +26,61 @@ class HookRegistry: 'Posting', ]) - @classmethod - def group_by_directive(cls, hooks_seq): - hooks_map = {key: [] for key in cls.DIRECTIVES} - for hook in hooks_seq: - for key in cls.DIRECTIVES & hook.HOOK_GROUPS: - hooks_map[key].append(hook) - return hooks_map + def __init__(self): + self.group_hooks_map = {key: set() for key in self.DIRECTIVES} + + def add_hook(self, hook_cls): + hook_groups = list(hook_cls.HOOK_GROUPS) + assert self.DIRECTIVES.intersection(hook_groups) + hook_groups.append('all') + for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']: + try: + hook_name = getattr(hook_cls, name_attr) + except AttributeError: + pass + else: + hook_groups.append(hook_name) + break + for key in hook_groups: + self.group_hooks_map.setdefault(key, set()).add(hook_cls) + return hook_cls # to allow use as a decorator + + def import_hooks(self, mod_name, *hook_names, package=__module__): + module = importlib.import_module(mod_name, package) + for hook_name in hook_names: + self.add_hook(getattr(module, hook_name)) + + def group_by_directive(self, config_str=''): + config_str = config_str.strip() + if not config_str: + config_str = 'all' + elif config_str.startswith('-'): + config_str = 'all ' + config_str + available_hooks = set() + for token in config_str.split(): + if token.startswith('-'): + update_available = available_hooks.difference_update + key = token[1:] + else: + update_available = available_hooks.update + key = token + try: + update_set = self.group_hooks_map[key] + except KeyError: + raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None + else: + update_available(update_set) + return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks] + for key in self.DIRECTIVES} + +HOOK_REGISTRY = HookRegistry() +HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation') +HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication') -def run(entries, options_map, config): +def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY): errors = [] - hooks = HookRegistry.group_by_directive(config) + hooks = hook_registry.group_by_directive(config) for entry in entries: entry_type = type(entry).__name__ for hook in hooks[entry_type]: diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py index 4895c050711ee024480b1f9185f2296683aafc10..f7ecce75e269bb1e66e18e9fd5e0691467e8daf4 100644 --- a/conservancy_beancount/plugin/core.py +++ b/conservancy_beancount/plugin/core.py @@ -67,6 +67,7 @@ class MetadataEnum: class PostingChecker: + HOOK_GROUPS = frozenset(['Posting', 'metadata']) ACCOUNTS = ('',) TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE) VALUES_ENUM = {} diff --git a/tests/test_plugin_HookRegistry.py b/tests/test_plugin_HookRegistry.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fd0b9521b6c461d0726c840e5203f0d2378bd3 --- /dev/null +++ b/tests/test_plugin_HookRegistry.py @@ -0,0 +1,53 @@ +"""Test main plugin's HookRegistry""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import pytest + +from . import testutil + +from conservancy_beancount import plugin + +def hook_names(hooks, key): + return {type(hook).__name__ for hook in hooks[key]} + +def test_default_registrations(): + hooks = plugin.HOOK_REGISTRY.group_by_directive() + post_hook_names = hook_names(hooks, 'Posting') + assert len(post_hook_names) >= 2 + assert 'MetaExpenseAllocation' in post_hook_names + assert 'MetaTaxImplication' in post_hook_names + +def test_exclude_single(): + hooks = plugin.HOOK_REGISTRY.group_by_directive('-expenseAllocation') + post_hook_names = hook_names(hooks, 'Posting') + assert post_hook_names + assert 'MetaExpenseAllocation' not in post_hook_names + +def test_exclude_group_then_include_single(): + hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expenseAllocation') + post_hook_names = hook_names(hooks, 'Posting') + assert 'MetaExpenseAllocation' in post_hook_names + assert 'MetaTaxImplication' not in post_hook_names + +def test_include_group_then_exclude_single(): + hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -taxImplication') + post_hook_names = hook_names(hooks, 'Posting') + assert 'MetaExpenseAllocation' in post_hook_names + assert 'MetaTaxImplication' not in post_hook_names + +def test_unknown_group_name(): + with pytest.raises(ValueError): + plugin.HOOK_REGISTRY.group_by_directive('UnKnownTestGroup') diff --git a/tests/test_plugin_run.py b/tests/test_plugin_run.py index 957e280a653e5ddf745c72f76b8e5539c6588338..521193490bd94db6528af9c7f8bfdf1b8bee6d08 100644 --- a/tests/test_plugin_run.py +++ b/tests/test_plugin_run.py @@ -21,7 +21,9 @@ from . import testutil from conservancy_beancount import plugin CONFIG_MAP = {} +HOOK_REGISTRY = plugin.HookRegistry() +@HOOK_REGISTRY.add_hook class TransactionCounter: HOOK_GROUPS = frozenset(['Transaction', 'counter']) @@ -29,6 +31,7 @@ class TransactionCounter: return ['txn:{}'.format(id(txn))] +@HOOK_REGISTRY.add_hook class PostingCounter(TransactionCounter): HOOK_GROUPS = frozenset(['Posting', 'counter']) @@ -44,8 +47,6 @@ def map_errors(errors): return retval def test_with_multiple_hooks(): - txn_counter = TransactionCounter() - post_counter = PostingCounter() in_entries = [ testutil.Transaction(postings=[ ('Income:Donations', -25), @@ -56,14 +57,13 @@ def test_with_multiple_hooks(): ('Liabilites:CreditCard', -10), ]), ] - out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [txn_counter, post_counter]) + out_entries, errors = plugin.run(in_entries, CONFIG_MAP, '', HOOK_REGISTRY) assert len(out_entries) == 2 errmap = map_errors(errors) assert len(errmap.get('txn', '')) == 2 assert len(errmap.get('post', '')) == 4 def test_with_posting_hooks_only(): - post_counter = PostingCounter() in_entries = [ testutil.Transaction(postings=[ ('Income:Donations', -25), @@ -74,7 +74,7 @@ def test_with_posting_hooks_only(): ('Liabilites:CreditCard', -10), ]), ] - out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [post_counter]) + out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY) assert len(out_entries) == 2 errmap = map_errors(errors) assert len(errmap.get('txn', '')) == 0