diff --git a/conservancy_beancount/_typing.py b/conservancy_beancount/_typing.py index fe22bf8b906bc7d49fda033eaaceeea03e75ab02..6bae7cc336833f7d3516b471f4e987d987aa7a8e 100644 --- a/conservancy_beancount/_typing.py +++ b/conservancy_beancount/_typing.py @@ -18,9 +18,12 @@ import abc import datetime import beancount.core.data as bc_data +from .plugin import errors from typing import ( Any, + FrozenSet, + Iterable, List, NamedTuple, Optional, @@ -30,7 +33,8 @@ from typing import ( ) Account = bc_data.Account -HookName = str +Error = errors._BaseError +ErrorIter = Iterable[Error] MetaKey = str MetaValue = Any MetaValueEnum = str @@ -56,3 +60,8 @@ class Transaction(Directive): tags: Set links: Set postings: List[Posting] + + +ALL_DIRECTIVES: FrozenSet[Type[Directive]] = frozenset([ + Transaction, +]) diff --git a/conservancy_beancount/plugin/__init__.py b/conservancy_beancount/plugin/__init__.py index cbcb8809baf55950de55b46b422a37523c96628d..b4370f973cd976a731aa1e41061dd8dccdf3cd21 100644 --- a/conservancy_beancount/plugin/__init__.py +++ b/conservancy_beancount/plugin/__init__.py @@ -18,31 +18,40 @@ import importlib import beancount.core.data as bc_data +from typing import ( + AbstractSet, + Any, + Dict, + List, + Mapping, + Set, + Tuple, + Type, +) +from .._typing import ( + ALL_DIRECTIVES, + Directive, + Error, +) +from .core import ( + Hook, + HookName, +) + __plugins__ = ['run'] class HookRegistry: - DIRECTIVES = frozenset([ - *(cls.__name__ for cls in bc_data.ALL_DIRECTIVES), - 'Posting', - ]) - - def __init__(self): - self.group_hooks_map = {key: set() for key in self.DIRECTIVES} + def __init__(self) -> None: + self.group_name_map: Dict[HookName, Set[Type[Hook]]] = { + t.__name__: set() for t in ALL_DIRECTIVES + } + self.group_name_map['all'] = set() - def add_hook(self, hook_cls): - hook_groups = list(hook_cls.HOOK_GROUPS) - assert self.DIRECTIVES.intersection(hook_groups) - hook_groups.append('all') - for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']: - try: - hook_name = getattr(hook_cls, name_attr) - except AttributeError: - pass - else: - hook_groups.append(hook_name) - break - for key in hook_groups: - self.group_hooks_map.setdefault(key, set()).add(hook_cls) + def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]: + self.group_name_map['all'].add(hook_cls) + self.group_name_map[hook_cls.DIRECTIVE.__name__].add(hook_cls) + for key in hook_cls.HOOK_GROUPS: + self.group_name_map.setdefault(key, set()).add(hook_cls) return hook_cls # to allow use as a decorator def import_hooks(self, mod_name, *hook_names, package=__module__): @@ -50,13 +59,13 @@ class HookRegistry: for hook_name in hook_names: self.add_hook(getattr(module, hook_name)) - def group_by_directive(self, config_str=''): + def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]: config_str = config_str.strip() if not config_str: config_str = 'all' elif config_str.startswith('-'): config_str = 'all ' + config_str - available_hooks = set() + available_hooks: Set[Type[Hook]] = set() for token in config_str.split(): if token.startswith('-'): update_available = available_hooks.difference_update @@ -65,29 +74,32 @@ class HookRegistry: update_available = available_hooks.update key = token try: - update_set = self.group_hooks_map[key] + update_set = self.group_name_map[key] except KeyError: raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None else: update_available(update_set) - return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks] - for key in self.DIRECTIVES} + return { + t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks] + for t in ALL_DIRECTIVES + } HOOK_REGISTRY = HookRegistry() HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation') HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication') -def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY): - errors = [] +def run( + entries: List[Directive], + options_map: Dict[str, Any], + config: str='', + hook_registry: HookRegistry=HOOK_REGISTRY, +) -> Tuple[List[Directive], List[Error]]: + errors: List[Error] = [] hooks = hook_registry.group_by_directive(config) for entry in entries: entry_type = type(entry).__name__ for hook in hooks[entry_type]: errors.extend(hook.run(entry)) - if entry_type == 'Transaction': - for index, post in enumerate(entry.postings): - for hook in hooks['Posting']: - errors.extend(hook.run(entry, post, index)) return entries, errors diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py index 98b70f451d59e77e749f33db79e6857d83f21472..41b8dff4e2faff30c1084c0da5c2620dc27f2185 100644 --- a/conservancy_beancount/plugin/core.py +++ b/conservancy_beancount/plugin/core.py @@ -14,36 +14,37 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import abc import datetime import re from . import errors as errormod from typing import ( - AbstractSet, - Any, - ClassVar, + FrozenSet, Generic, Iterable, Iterator, - List, Mapping, Optional, - Tuple, TypeVar, - Union, ) from .._typing import ( Account, - HookName, + Directive, + Error, + ErrorIter, LessComparable, MetaKey, MetaValue, MetaValueEnum, Posting, Transaction, + Type, ) +### CONSTANTS + # I expect these will become configurable in the future, which is why I'm # keeping them outside of a class, but for now constants will do. DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1) @@ -51,8 +52,27 @@ DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1) # dates past the far end of the range. DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1) -CT = TypeVar('CT', bound=LessComparable) +### TYPE DEFINITIONS + +HookName = str + +Entry = TypeVar('Entry', bound=Directive) +class Hook(Generic[Entry], metaclass=abc.ABCMeta): + DIRECTIVE: Type[Directive] + HOOK_GROUPS: FrozenSet[HookName] = frozenset() + + @abc.abstractmethod + def run(self, entry: Entry) -> ErrorIter: ... + + def __init_subclass__(cls): + cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0] + + +TransactionHook = Hook[Transaction] +### HELPER CLASSES + +CT = TypeVar('CT', bound=LessComparable) class _GenericRange(Generic[CT]): """Convenience class to check whether a value is within a range. @@ -143,24 +163,14 @@ class MetadataEnum: return self[default_key] -class PostingChecker: - """Base class to normalize posting metadata from an enum.""" - # This class provides basic functionality to filter postings, normalize - # metadata values, and set default values. - # Subclasses should set: - # * METADATA_KEY: A string with the name of the metadata key to normalize. - # * ACCOUNTS: Only check postings that match these account names. - # Can be a tuple of account prefix strings, or a regexp. - # * VALUES_ENUM: A MetadataEnum with allowed values and aliases. - # Subclasses may wish to override _default_value and _should_check. - # See below. - - METADATA_KEY: ClassVar[MetaKey] - VALUES_ENUM: MetadataEnum - HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata']) - ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',) +### HOOK SUBCLASSES + +class _PostingHook(TransactionHook, metaclass=abc.ABCMeta): TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE) + def __init_subclass__(cls) -> None: + cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting']) + def _meta_get(self, txn: Transaction, post: Posting, @@ -184,6 +194,34 @@ class PostingChecker: else: post.meta[key] = value + def _run_on_txn(self, txn: Transaction) -> bool: + return txn.date in self.TXN_DATE_RANGE + + def _run_on_post(self, txn: Transaction, post: Posting) -> bool: + return True + + def run(self, txn: Transaction) -> ErrorIter: + if self._run_on_txn(txn): + for index, post in enumerate(txn.postings): + if self._run_on_post(txn, post): + yield from self.post_run(txn, post, index) + + @abc.abstractmethod + def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ... + + +class _NormalizePostingMetadataHook(_PostingHook): + """Base class to normalize posting metadata from an enum.""" + # This class provides basic functionality to filter postings, normalize + # metadata values, and set default values. + METADATA_KEY: MetaKey + VALUES_ENUM: MetadataEnum + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + cls.METADATA_KEY = cls.VALUES_ENUM.key + cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['metadata', cls.METADATA_KEY]) + # If the posting does not specify METADATA_KEY, the hook calls # _default_value to get a default. This method should either return # a value string from METADATA_ENUM, or else raise InvalidMetadataError. @@ -191,35 +229,23 @@ class PostingChecker: def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY) - # The hook calls _should_check on every posting and only checks postings - # when the method returns true. This base method checks the transaction - # date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS. - def _should_check(self, txn: Transaction, post: Posting) -> bool: - ok = txn.date in self.TXN_DATE_RANGE - if isinstance(self.ACCOUNTS, tuple): - ok = ok and post.account.startswith(self.ACCOUNTS) - else: - ok = ok and bool(re.search(self.ACCOUNTS, post.account)) - return ok - - def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]: - errors: List[errormod._BaseError] = [] - if not self._should_check(txn, post): - return errors + def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: source_value = self._meta_get(txn, post, self.METADATA_KEY) set_value = source_value + error: Optional[Error] = None if source_value is None: try: set_value = self._default_value(txn, post) - except errormod._BaseError as error: - errors.append(error) + except errormod._BaseError as error_: + error = error_ else: try: set_value = self.VALUES_ENUM[source_value] except KeyError: - errors.append(errormod.InvalidMetadataError( + error = errormod.InvalidMetadataError( txn, post, self.METADATA_KEY, source_value, - )) - if not errors: + ) + if error is None: self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value) - return errors + else: + yield error diff --git a/conservancy_beancount/plugin/meta_expense_allocation.py b/conservancy_beancount/plugin/meta_expense_allocation.py index 1f09cce60e631d7ca0a13b015b20149427eaee8e..74ede4de04f706003062cf393e6fecd416581377 100644 --- a/conservancy_beancount/plugin/meta_expense_allocation.py +++ b/conservancy_beancount/plugin/meta_expense_allocation.py @@ -15,11 +15,14 @@ # along with this program. If not, see . from . import core +from .._typing import ( + MetaValueEnum, + Posting, + Transaction, +) -class MetaExpenseAllocation(core.PostingChecker): - ACCOUNTS = ('Expenses:',) - METADATA_KEY = 'expense-allocation' - VALUES_ENUM = core.MetadataEnum(METADATA_KEY, { +class MetaExpenseAllocation(core._NormalizePostingMetadataHook): + VALUES_ENUM = core.MetadataEnum('expense-allocation', { 'administration', 'fundraising', 'program', @@ -32,5 +35,8 @@ class MetaExpenseAllocation(core.PostingChecker): 'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'], } - def _default_value(self, txn, post): + def _run_on_post(self, txn: Transaction, post: Posting) -> bool: + return post.account.startswith('Expenses:') + + def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: return self.DEFAULT_VALUES.get(post.account, 'program') diff --git a/conservancy_beancount/plugin/meta_tax_implication.py b/conservancy_beancount/plugin/meta_tax_implication.py index c683e2db7903ee817e018bf465c83340da1fddf9..044eecae01668001d3d5517e34edf28d2b3e14ed 100644 --- a/conservancy_beancount/plugin/meta_tax_implication.py +++ b/conservancy_beancount/plugin/meta_tax_implication.py @@ -17,13 +17,15 @@ import decimal from . import core +from .._typing import ( + Posting, + Transaction, +) DEFAULT_STOP_AMOUNT = decimal.Decimal(0) -class MetaTaxImplication(core.PostingChecker): - ACCOUNTS = ('Assets:',) - METADATA_KEY = 'tax-implication' - VALUES_ENUM = core.MetadataEnum(METADATA_KEY, [ +class MetaTaxImplication(core._NormalizePostingMetadataHook): + VALUES_ENUM = core.MetadataEnum('tax-implication', [ '1099', 'Accountant-Advises-No-1099', 'Bank-Transfer', @@ -43,8 +45,9 @@ class MetaTaxImplication(core.PostingChecker): 'W2', ], {}) - def _should_check(self, txn, post): - return ( - super()._should_check(txn, post) + def _run_on_post(self, txn: Transaction, post: Posting) -> bool: + return bool( + post.account.startswith('Assets:') + and post.units.number and post.units.number < DEFAULT_STOP_AMOUNT ) diff --git a/tests/test_meta_expense_allocation.py b/tests/test_meta_expense_allocation.py index f1e1a9dd4dd094d067a28f5b920271205ccf02d1..5995cdbef8fe312ee68378b9302f363779d304d8 100644 --- a/tests/test_meta_expense_allocation.py +++ b/tests/test_meta_expense_allocation.py @@ -44,7 +44,7 @@ def test_valid_values_on_postings(src_value, set_value): ('Expenses:General', 25, {TEST_KEY: src_value}), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert txn.postings[-1].meta.get(TEST_KEY) == set_value @@ -55,7 +55,7 @@ def test_invalid_values_on_postings(src_value): ('Expenses:General', 25, {TEST_KEY: src_value}), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert errors @pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items()) @@ -65,7 +65,7 @@ def test_valid_values_on_transactions(src_value, set_value): ('Expenses:General', 25), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert txn.postings[-1].meta.get(TEST_KEY) == set_value @@ -76,7 +76,7 @@ def test_invalid_values_on_transactions(src_value): ('Expenses:General', 25), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert errors @pytest.mark.parametrize('account', [ @@ -92,7 +92,7 @@ def test_non_expense_accounts_skipped(account): ('Expenses:General', 25, {TEST_KEY: 'program'}), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[0], 0) + errors = list(checker.run(txn)) assert not errors @pytest.mark.parametrize('account,set_value', [ @@ -108,7 +108,7 @@ def test_default_values(account, set_value): (account, 25), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert txn.postings[-1].meta[TEST_KEY] == set_value @@ -125,7 +125,7 @@ def test_default_value_set_in_date_range(date, set_value): ('Expenses:General', 25), ]) checker = meta_expense_allocation.MetaExpenseAllocation() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors got_value = (txn.postings[-1].meta or {}).get(TEST_KEY) assert bool(got_value) == bool(set_value) diff --git a/tests/test_meta_tax_implication.py b/tests/test_meta_tax_implication.py index 94aa70a91a8dfbb8bc4f24d198c7a2189e02d716..522b71ac56d32d2bbda3911dab5806d5ff668dc6 100644 --- a/tests/test_meta_tax_implication.py +++ b/tests/test_meta_tax_implication.py @@ -56,7 +56,7 @@ def test_valid_values_on_postings(src_value, set_value): ('Assets:Cash', -25, {TEST_KEY: src_value}), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert txn.postings[-1].meta.get(TEST_KEY) == set_value @@ -67,7 +67,7 @@ def test_invalid_values_on_postings(src_value): ('Assets:Cash', -25, {TEST_KEY: src_value}), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert errors @pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items()) @@ -77,7 +77,7 @@ def test_valid_values_on_transactions(src_value, set_value): ('Assets:Cash', -25), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert txn.postings[-1].meta.get(TEST_KEY) == set_value @@ -88,7 +88,7 @@ def test_invalid_values_on_transactions(src_value): ('Assets:Cash', -25), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert errors @pytest.mark.parametrize('account', [ @@ -102,7 +102,7 @@ def test_non_asset_accounts_skipped(account): ('Assets:Cash', -25, {TEST_KEY: 'USA-Corporation'}), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[0], 0) + errors = list(checker.run(txn)) assert not errors def test_asset_credits_skipped(): @@ -111,7 +111,7 @@ def test_asset_credits_skipped(): ('Assets:Cash', 25), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert not errors assert not txn.postings[-1].meta @@ -128,5 +128,5 @@ def test_default_value_set_in_date_range(date, need_value): ('Assets:Cash', -25), ]) checker = meta_tax_implication.MetaTaxImplication() - errors = checker.run(txn, txn.postings[-1], -1) + errors = list(checker.run(txn)) assert bool(errors) == bool(need_value) diff --git a/tests/test_plugin_HookRegistry.py b/tests/test_plugin_HookRegistry.py index 6d899873a19af00d09e4a776c3a353f18653ff40..6f357e864d105994a939e51c5e088440683681d0 100644 --- a/tests/test_plugin_HookRegistry.py +++ b/tests/test_plugin_HookRegistry.py @@ -25,28 +25,28 @@ def hook_names(hooks, key): def test_default_registrations(): hooks = plugin.HOOK_REGISTRY.group_by_directive() - post_hook_names = hook_names(hooks, 'Posting') - assert len(post_hook_names) >= 2 - assert 'MetaExpenseAllocation' in post_hook_names - assert 'MetaTaxImplication' in post_hook_names + txn_hook_names = hook_names(hooks, 'Transaction') + assert len(txn_hook_names) >= 2 + assert 'MetaExpenseAllocation' in txn_hook_names + assert 'MetaTaxImplication' in txn_hook_names def test_exclude_single(): hooks = plugin.HOOK_REGISTRY.group_by_directive('-expense-allocation') - post_hook_names = hook_names(hooks, 'Posting') - assert post_hook_names - assert 'MetaExpenseAllocation' not in post_hook_names + txn_hook_names = hook_names(hooks, 'Transaction') + assert txn_hook_names + assert 'MetaExpenseAllocation' not in txn_hook_names def test_exclude_group_then_include_single(): hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expense-allocation') - post_hook_names = hook_names(hooks, 'Posting') - assert 'MetaExpenseAllocation' in post_hook_names - assert 'MetaTaxImplication' not in post_hook_names + txn_hook_names = hook_names(hooks, 'Transaction') + assert 'MetaExpenseAllocation' in txn_hook_names + assert 'MetaTaxImplication' not in txn_hook_names def test_include_group_then_exclude_single(): hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -tax-implication') - post_hook_names = hook_names(hooks, 'Posting') - assert 'MetaExpenseAllocation' in post_hook_names - assert 'MetaTaxImplication' not in post_hook_names + txn_hook_names = hook_names(hooks, 'Transaction') + assert 'MetaExpenseAllocation' in txn_hook_names + assert 'MetaTaxImplication' not in txn_hook_names def test_unknown_group_name(): with pytest.raises(ValueError): diff --git a/tests/test_plugin_run.py b/tests/test_plugin_run.py index 42838298fcb2ad28e7b71f86b10ebefa59e62cf3..f5f910d5d5c9a9bb6b7591019857eb29fde35fe6 100644 --- a/tests/test_plugin_run.py +++ b/tests/test_plugin_run.py @@ -18,14 +18,15 @@ import pytest from . import testutil -from conservancy_beancount import plugin +from conservancy_beancount import plugin, _typing CONFIG_MAP = {} HOOK_REGISTRY = plugin.HookRegistry() @HOOK_REGISTRY.add_hook class TransactionCounter: - HOOK_GROUPS = frozenset(['Transaction', 'counter']) + DIRECTIVE = _typing.Transaction + HOOK_GROUPS = frozenset() def run(self, txn): return ['txn:{}'.format(id(txn))] @@ -33,10 +34,11 @@ class TransactionCounter: @HOOK_REGISTRY.add_hook class PostingCounter(TransactionCounter): - HOOK_GROUPS = frozenset(['Posting', 'counter']) + DIRECTIVE = _typing.Transaction + HOOK_GROUPS = frozenset(['posting']) - def run(self, txn, post, post_index): - return ['post:{}'.format(id(post))] + def run(self, txn): + return ['post:{}'.format(id(post)) for post in txn.postings] def map_errors(errors): @@ -74,7 +76,7 @@ def test_with_posting_hooks_only(): ('Liabilites:CreditCard', -10), ]), ] - out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY) + out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'posting', HOOK_REGISTRY) assert len(out_entries) == 2 errmap = map_errors(errors) assert len(errmap.get('txn', '')) == 0