diff --git a/conservancy_beancount/plugin/__init__.py b/conservancy_beancount/plugin/__init__.py
index 4cbf473630d537dba26e872c6c9ebdc75b97282f..3c8842332ca052bb9db5a8ed3be048a13ab55556 100644
--- a/conservancy_beancount/plugin/__init__.py
+++ b/conservancy_beancount/plugin/__init__.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import importlib
+
import beancount.core.data as bc_data
__plugins__ = ['run']
@@ -24,18 +26,61 @@ class HookRegistry:
'Posting',
])
- @classmethod
- def group_by_directive(cls, hooks_seq):
- hooks_map = {key: [] for key in cls.DIRECTIVES}
- for hook in hooks_seq:
- for key in cls.DIRECTIVES & hook.HOOK_GROUPS:
- hooks_map[key].append(hook)
- return hooks_map
+    def __init__(self):
+        # Map each group name to the set of hook classes registered in it.
+        # Every directive type gets an entry up front so group_by_directive
+        # can index them unconditionally.  'all' is seeded as well: add_hook
+        # puts every hook in it and group_by_directive falls back to it for
+        # an empty config, so without this entry an empty registry would
+        # reject the default configuration as "unknown hooks 'all'".
+        self.group_hooks_map = {key: set() for key in self.DIRECTIVES}
+        self.group_hooks_map['all'] = set()
+
+    def add_hook(self, hook_cls):
+        """Register hook_cls with this registry and return it unchanged.
+
+        The class is added to every group named in its HOOK_GROUPS, to the
+        'all' group, and to a group named after the hook itself: the value
+        of the first of HOOK_NAME, METADATA_KEY, or __name__ that the class
+        has.  Returning the class lets this method be used as a decorator.
+
+        Raises ValueError if HOOK_GROUPS names none of the DIRECTIVES, since
+        group_by_directive only ever hands out hooks by directive type.
+        """
+        hook_groups = list(hook_cls.HOOK_GROUPS)
+        # An explicit check rather than assert: assertions are stripped
+        # under `python -O`, which would let such a hook register silently.
+        if not self.DIRECTIVES.intersection(hook_groups):
+            raise ValueError("hook {!r} is not in any directive group {!r}".format(
+                hook_cls, sorted(self.DIRECTIVES)))
+        hook_groups.append('all')
+        for name_attr in ('HOOK_NAME', 'METADATA_KEY', '__name__'):
+            try:
+                hook_name = getattr(hook_cls, name_attr)
+            except AttributeError:
+                continue
+            hook_groups.append(hook_name)
+            break
+        for key in hook_groups:
+            self.group_hooks_map.setdefault(key, set()).add(hook_cls)
+        return hook_cls  # to allow use as a decorator
+
+    def import_hooks(self, mod_name, *hook_names, package=__module__):
+        """Import mod_name and register each named attribute as a hook.
+
+        mod_name may be a relative module name; it is resolved against
+        package, which defaults to the module this class is defined in.
+        """
+        hook_module = importlib.import_module(mod_name, package)
+        for hook_cls in (getattr(hook_module, name) for name in hook_names):
+            self.add_hook(hook_cls)
+
+    def group_by_directive(self, config_str=''):
+        """Instantiate the configured hooks, grouped by directive type.
+
+        config_str is a whitespace-separated list of group names processed
+        left to right: a bare name adds every hook in that group, and a name
+        prefixed with '-' removes them.  An empty (or None) config selects
+        'all', and a config that starts with a removal implicitly starts
+        from 'all'.
+
+        Returns a dict mapping each name in DIRECTIVES to a list of new hook
+        instances, ordered by hook module and qualified name.
+
+        Raises ValueError if config_str names a group that is not registered.
+        """
+        def load_order(hook_cls):
+            # Sets of classes iterate in hash order, which is effectively
+            # arbitrary; sorting by name keeps hook order stable between runs.
+            return (hook_cls.__module__, hook_cls.__qualname__)
+
+        config_str = (config_str or '').strip()
+        if not config_str:
+            config_str = 'all'
+        elif config_str.startswith('-'):
+            config_str = 'all ' + config_str
+        available_hooks = set()
+        for token in config_str.split():
+            if token.startswith('-'):
+                update_available = available_hooks.difference_update
+                key = token[1:]
+            else:
+                update_available = available_hooks.update
+                key = token
+            try:
+                update_set = self.group_hooks_map[key]
+            except KeyError:
+                raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
+            update_available(update_set)
+        return {
+            key: [hook() for hook in sorted(self.group_hooks_map[key] & available_hooks,
+                                            key=load_order)]
+            for key in self.DIRECTIVES
+        }
+
+# Registry that run() uses unless a caller passes its own hook_registry.
+# The built-in hooks are imported from submodules of this package and
+# registered here at import time.
+HOOK_REGISTRY = HookRegistry()
+HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation')
+HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
-def run(entries, options_map, config):
+def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
errors = []
- hooks = HookRegistry.group_by_directive(config)
+ hooks = hook_registry.group_by_directive(config)
for entry in entries:
entry_type = type(entry).__name__
for hook in hooks[entry_type]:
diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py
index 4895c050711ee024480b1f9185f2296683aafc10..f7ecce75e269bb1e66e18e9fd5e0691467e8daf4 100644
--- a/conservancy_beancount/plugin/core.py
+++ b/conservancy_beancount/plugin/core.py
@@ -67,6 +67,7 @@ class MetadataEnum:
class PostingChecker:
+ # HookRegistry groups: checkers run on 'Posting' directives and can be
+ # selected or excluded together via the 'metadata' config name.  Subclasses
+ # inherit these groups unless they override HOOK_GROUPS.
+ HOOK_GROUPS = frozenset(['Posting', 'metadata'])
ACCOUNTS = ('',)
TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
VALUES_ENUM = {}
diff --git a/tests/test_plugin_HookRegistry.py b/tests/test_plugin_HookRegistry.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7fd0b9521b6c461d0726c840e5203f0d2378bd3
--- /dev/null
+++ b/tests/test_plugin_HookRegistry.py
@@ -0,0 +1,53 @@
+"""Test main plugin's HookRegistry"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import pytest
+
+from . import testutil
+
+from conservancy_beancount import plugin
+
+def hook_names(hooks, key):
+    """Return the set of class names of the hook instances in hooks[key]."""
+    names = set()
+    for hook in hooks[key]:
+        names.add(type(hook).__name__)
+    return names
+
+def test_default_registrations():
+    """With no configuration, the built-in Posting hooks are enabled."""
+    posting_names = hook_names(plugin.HOOK_REGISTRY.group_by_directive(), 'Posting')
+    assert len(posting_names) >= 2
+    assert {'MetaExpenseAllocation', 'MetaTaxImplication'} <= posting_names
+
+def test_exclude_single():
+    """Excluding one hook by name leaves the other Posting hooks enabled."""
+    config = '-expenseAllocation'
+    remaining = hook_names(plugin.HOOK_REGISTRY.group_by_directive(config), 'Posting')
+    assert remaining
+    assert 'MetaExpenseAllocation' not in remaining
+
+def test_exclude_group_then_include_single():
+    """A later include re-enables one hook from an excluded group."""
+    config = '-metadata expenseAllocation'
+    enabled = hook_names(plugin.HOOK_REGISTRY.group_by_directive(config), 'Posting')
+    assert 'MetaTaxImplication' not in enabled
+    assert 'MetaExpenseAllocation' in enabled
+
+def test_include_group_then_exclude_single():
+    """A later exclusion removes one hook from an included group."""
+    config = 'metadata -taxImplication'
+    enabled = hook_names(plugin.HOOK_REGISTRY.group_by_directive(config), 'Posting')
+    assert 'MetaTaxImplication' not in enabled
+    assert 'MetaExpenseAllocation' in enabled
+
+def test_unknown_group_name():
+    """Naming a group that was never registered is a configuration error."""
+    registry = plugin.HOOK_REGISTRY
+    with pytest.raises(ValueError):
+        registry.group_by_directive('UnKnownTestGroup')
diff --git a/tests/test_plugin_run.py b/tests/test_plugin_run.py
index 957e280a653e5ddf745c72f76b8e5539c6588338..521193490bd94db6528af9c7f8bfdf1b8bee6d08 100644
--- a/tests/test_plugin_run.py
+++ b/tests/test_plugin_run.py
@@ -21,7 +21,9 @@ from . import testutil
from conservancy_beancount import plugin
CONFIG_MAP = {}
+# A separate registry, passed explicitly to plugin.run() below, so these
+# counter hooks are not added to plugin.HOOK_REGISTRY.
+HOOK_REGISTRY = plugin.HookRegistry()
+@HOOK_REGISTRY.add_hook
class TransactionCounter:
HOOK_GROUPS = frozenset(['Transaction', 'counter'])
@@ -29,6 +31,7 @@ class TransactionCounter:
return ['txn:{}'.format(id(txn))]
+@HOOK_REGISTRY.add_hook
class PostingCounter(TransactionCounter):
HOOK_GROUPS = frozenset(['Posting', 'counter'])
@@ -44,8 +47,6 @@ def map_errors(errors):
return retval
def test_with_multiple_hooks():
- txn_counter = TransactionCounter()
- post_counter = PostingCounter()
in_entries = [
testutil.Transaction(postings=[
('Income:Donations', -25),
@@ -56,14 +57,13 @@ def test_with_multiple_hooks():
('Liabilites:CreditCard', -10),
]),
]
- out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [txn_counter, post_counter])
+ out_entries, errors = plugin.run(in_entries, CONFIG_MAP, '', HOOK_REGISTRY)
assert len(out_entries) == 2
errmap = map_errors(errors)
assert len(errmap.get('txn', '')) == 2
assert len(errmap.get('post', '')) == 4
def test_with_posting_hooks_only():
- post_counter = PostingCounter()
in_entries = [
testutil.Transaction(postings=[
('Income:Donations', -25),
@@ -74,7 +74,7 @@ def test_with_posting_hooks_only():
('Liabilites:CreditCard', -10),
]),
]
- out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [post_counter])
+ out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY)
assert len(out_entries) == 2
errmap = map_errors(errors)
assert len(errmap.get('txn', '')) == 0