diff --git a/tests/test_plugin_run.py b/tests/test_plugin_run.py index 42838298fcb2ad28e7b71f86b10ebefa59e62cf3..f5f910d5d5c9a9bb6b7591019857eb29fde35fe6 100644 --- a/tests/test_plugin_run.py +++ b/tests/test_plugin_run.py @@ -18,14 +18,15 @@ import pytest from . import testutil -from conservancy_beancount import plugin +from conservancy_beancount import plugin, _typing CONFIG_MAP = {} HOOK_REGISTRY = plugin.HookRegistry() @HOOK_REGISTRY.add_hook class TransactionCounter: - HOOK_GROUPS = frozenset(['Transaction', 'counter']) + DIRECTIVE = _typing.Transaction + HOOK_GROUPS = frozenset() def run(self, txn): return ['txn:{}'.format(id(txn))] @@ -33,10 +34,11 @@ class TransactionCounter: @HOOK_REGISTRY.add_hook class PostingCounter(TransactionCounter): - HOOK_GROUPS = frozenset(['Posting', 'counter']) + DIRECTIVE = _typing.Transaction + HOOK_GROUPS = frozenset(['posting']) - def run(self, txn, post, post_index): - return ['post:{}'.format(id(post))] + def run(self, txn): + return ['post:{}'.format(id(post)) for post in txn.postings] def map_errors(errors): @@ -74,7 +76,7 @@ def test_with_posting_hooks_only(): ('Liabilites:CreditCard', -10), ]), ] - out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY) + out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'posting', HOOK_REGISTRY) assert len(out_entries) == 2 errmap = map_errors(errors) assert len(errmap.get('txn', '')) == 0