Changeset - a41feb94b3e0
[Not reviewed]
0 9 0
Brett Smith - 4 years ago 2020-03-15 19:50:14
brettcsmith@brettcsmith.org
plugin: Transform posting hooks into transaction hooks.

I feel like posting hooks were a case of premature optimization in early
development. This approach reduces the number of special cases in
the code and allows us to more strongly reason about hooks in the
type system.
9 files changed with 181 insertions and 123 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/_typing.py
Show inline comments
...
 
@@ -20,2 +20,3 @@ import datetime
 
import beancount.core.data as bc_data
 
from .plugin import errors
 

	
...
 
@@ -23,2 +24,4 @@ from typing import (
 
    Any,
 
    FrozenSet,
 
    Iterable,
 
    List,
...
 
@@ -32,3 +35,4 @@ from typing import (
 
Account = bc_data.Account
 
HookName = str
 
Error = errors._BaseError
 
ErrorIter = Iterable[Error]
 
MetaKey = str
...
 
@@ -58 +62,6 @@ class Transaction(Directive):
 
    postings: List[Posting]
 

	
 

	
 
ALL_DIRECTIVES: FrozenSet[Type[Directive]] = frozenset([
 
    Transaction,
 
])
conservancy_beancount/plugin/__init__.py
Show inline comments
...
 
@@ -20,2 +20,22 @@ import beancount.core.data as bc_data
 

	
 
from typing import (
 
    AbstractSet,
 
    Any,
 
    Dict,
 
    List,
 
    Mapping,
 
    Set,
 
    Tuple,
 
    Type,
 
)
 
from .._typing import (
 
    ALL_DIRECTIVES,
 
    Directive,
 
    Error,
 
)
 
from .core import (
 
    Hook,
 
    HookName,
 
)
 

	
 
__plugins__ = ['run']
...
 
@@ -23,24 +43,13 @@ __plugins__ = ['run']
 
class HookRegistry:
 
    DIRECTIVES = frozenset([
 
        *(cls.__name__ for cls in bc_data.ALL_DIRECTIVES),
 
        'Posting',
 
    ])
 

	
 
    def __init__(self):
 
        self.group_hooks_map = {key: set() for key in self.DIRECTIVES}
 
    def __init__(self) -> None:
 
        self.group_name_map: Dict[HookName, Set[Type[Hook]]] = {
 
            t.__name__: set() for t in ALL_DIRECTIVES
 
        }
 
        self.group_name_map['all'] = set()
 

	
 
    def add_hook(self, hook_cls):
 
        hook_groups = list(hook_cls.HOOK_GROUPS)
 
        assert self.DIRECTIVES.intersection(hook_groups)
 
        hook_groups.append('all')
 
        for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']:
 
            try:
 
                hook_name = getattr(hook_cls, name_attr)
 
            except AttributeError:
 
                pass
 
            else:
 
                hook_groups.append(hook_name)
 
                break
 
        for key in hook_groups:
 
            self.group_hooks_map.setdefault(key, set()).add(hook_cls)
 
    def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]:
 
        self.group_name_map['all'].add(hook_cls)
 
        self.group_name_map[hook_cls.DIRECTIVE.__name__].add(hook_cls)
 
        for key in hook_cls.HOOK_GROUPS:
 
            self.group_name_map.setdefault(key, set()).add(hook_cls)
 
        return hook_cls  # to allow use as a decorator
...
 
@@ -52,3 +61,3 @@ class HookRegistry:
 

	
 
    def group_by_directive(self, config_str=''):
 
    def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]:
 
        config_str = config_str.strip()
...
 
@@ -58,3 +67,3 @@ class HookRegistry:
 
            config_str = 'all ' + config_str
 
        available_hooks = set()
 
        available_hooks: Set[Type[Hook]] = set()
 
        for token in config_str.split():
...
 
@@ -67,3 +76,3 @@ class HookRegistry:
 
            try:
 
                update_set = self.group_hooks_map[key]
 
                update_set = self.group_name_map[key]
 
            except KeyError:
...
 
@@ -72,4 +81,6 @@ class HookRegistry:
 
                update_available(update_set)
 
        return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks]
 
                for key in self.DIRECTIVES}
 
        return {
 
            t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks]
 
            for t in ALL_DIRECTIVES
 
        }
 

	
...
 
@@ -80,4 +91,9 @@ HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
 

	
 
def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
 
    errors = []
 
def run(
 
        entries: List[Directive],
 
        options_map: Dict[str, Any],
 
        config: str='',
 
        hook_registry: HookRegistry=HOOK_REGISTRY,
 
) -> Tuple[List[Directive], List[Error]]:
 
    errors: List[Error] = []
 
    hooks = hook_registry.group_by_directive(config)
...
 
@@ -87,6 +103,2 @@ def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
 
            errors.extend(hook.run(entry))
 
        if entry_type == 'Transaction':
 
            for index, post in enumerate(entry.postings):
 
                for hook in hooks['Posting']:
 
                    errors.extend(hook.run(entry, post, index))
 
    return entries, errors
conservancy_beancount/plugin/core.py
Show inline comments
...
 
@@ -16,2 +16,3 @@
 

	
 
import abc
 
import datetime
...
 
@@ -22,5 +23,3 @@ from . import errors as errormod
 
from typing import (
 
    AbstractSet,
 
    Any,
 
    ClassVar,
 
    FrozenSet,
 
    Generic,
...
 
@@ -28,8 +27,5 @@ from typing import (
 
    Iterator,
 
    List,
 
    Mapping,
 
    Optional,
 
    Tuple,
 
    TypeVar,
 
    Union,
 
)
...
 
@@ -37,3 +33,5 @@ from .._typing import (
 
    Account,
 
    HookName,
 
    Directive,
 
    Error,
 
    ErrorIter,
 
    LessComparable,
...
 
@@ -44,4 +42,7 @@ from .._typing import (
 
    Transaction,
 
    Type,
 
)
 

	
 
### CONSTANTS
 

	
 
# I expect these will become configurable in the future, which is why I'm
...
 
@@ -53,4 +54,23 @@ DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
 

	
 
CT = TypeVar('CT', bound=LessComparable)
 
### TYPE DEFINITIONS
 

	
 
HookName = str
 

	
 
Entry = TypeVar('Entry', bound=Directive)
 
class Hook(Generic[Entry], metaclass=abc.ABCMeta):
 
    DIRECTIVE: Type[Directive]
 
    HOOK_GROUPS: FrozenSet[HookName] = frozenset()
 

	
 
    @abc.abstractmethod
 
    def run(self, entry: Entry) -> ErrorIter: ...
 

	
 
    def __init_subclass__(cls):
 
        cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
 

	
 

	
 
TransactionHook = Hook[Transaction]
 

	
 
### HELPER CLASSES
 

	
 
CT = TypeVar('CT', bound=LessComparable)
 
class _GenericRange(Generic[CT]):
...
 
@@ -145,20 +165,10 @@ class MetadataEnum:
 

	
 
class PostingChecker:
 
    """Base class to normalize posting metadata from an enum."""
 
    # This class provides basic functionality to filter postings, normalize
 
    # metadata values, and set default values.
 
    # Subclasses should set:
 
    # * METADATA_KEY: A string with the name of the metadata key to normalize.
 
    # * ACCOUNTS: Only check postings that match these account names.
 
    #   Can be a tuple of account prefix strings, or a regexp.
 
    # * VALUES_ENUM: A MetadataEnum with allowed values and aliases.
 
    # Subclasses may wish to override _default_value and _should_check.
 
    # See below.
 

	
 
    METADATA_KEY: ClassVar[MetaKey]
 
    VALUES_ENUM: MetadataEnum
 
    HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata'])
 
    ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',)
 
### HOOK SUBCLASSES
 

	
 
class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
 
    TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
 

	
 
    def __init_subclass__(cls) -> None:
 
        cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting'])
 

	
 
    def _meta_get(self,
...
 
@@ -186,2 +196,30 @@ class PostingChecker:
 

	
 
    def _run_on_txn(self, txn: Transaction) -> bool:
 
        return txn.date in self.TXN_DATE_RANGE
 

	
 
    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
 
        return True
 

	
 
    def run(self, txn: Transaction) -> ErrorIter:
 
        if self._run_on_txn(txn):
 
            for index, post in enumerate(txn.postings):
 
                if self._run_on_post(txn, post):
 
                    yield from self.post_run(txn, post, index)
 

	
 
    @abc.abstractmethod
 
    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
 

	
 

	
 
class _NormalizePostingMetadataHook(_PostingHook):
 
    """Base class to normalize posting metadata from an enum."""
 
    # This class provides basic functionality to filter postings, normalize
 
    # metadata values, and set default values.
 
    METADATA_KEY: MetaKey
 
    VALUES_ENUM: MetadataEnum
 

	
 
    def __init_subclass__(cls) -> None:
 
        super().__init_subclass__()
 
        cls.METADATA_KEY = cls.VALUES_ENUM.key
 
        cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['metadata', cls.METADATA_KEY])
 

	
 
    # If the posting does not specify METADATA_KEY, the hook calls
...
 
@@ -193,19 +231,6 @@ class PostingChecker:
 

	
 
    # The hook calls _should_check on every posting and only checks postings
 
    # when the method returns true. This base method checks the transaction
 
    # date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS.
 
    def _should_check(self, txn: Transaction, post: Posting) -> bool:
 
        ok = txn.date in self.TXN_DATE_RANGE
 
        if isinstance(self.ACCOUNTS, tuple):
 
            ok = ok and post.account.startswith(self.ACCOUNTS)
 
        else:
 
            ok = ok and bool(re.search(self.ACCOUNTS, post.account))
 
        return ok
 

	
 
    def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]:
 
        errors: List[errormod._BaseError] = []
 
        if not self._should_check(txn, post):
 
            return errors
 
    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
 
        source_value = self._meta_get(txn, post, self.METADATA_KEY)
 
        set_value = source_value
 
        error: Optional[Error] = None
 
        if source_value is None:
...
 
@@ -213,4 +238,4 @@ class PostingChecker:
 
                set_value = self._default_value(txn, post)
 
            except errormod._BaseError as error:
 
                errors.append(error)
 
            except errormod._BaseError as error_:
 
                error = error_
 
        else:
...
 
@@ -219,7 +244,8 @@ class PostingChecker:
 
            except KeyError:
 
                errors.append(errormod.InvalidMetadataError(
 
                error = errormod.InvalidMetadataError(
 
                    txn, post, self.METADATA_KEY, source_value,
 
                ))
 
        if not errors:
 
                )
 
        if error is None:
 
            self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
 
        return errors
 
        else:
 
            yield error
conservancy_beancount/plugin/meta_expense_allocation.py
Show inline comments
...
 
@@ -17,7 +17,10 @@
 
from . import core
 
from .._typing import (
 
    MetaValueEnum,
 
    Posting,
 
    Transaction,
 
)
 

	
 
class MetaExpenseAllocation(core.PostingChecker):
 
    ACCOUNTS = ('Expenses:',)
 
    METADATA_KEY = 'expense-allocation'
 
    VALUES_ENUM = core.MetadataEnum(METADATA_KEY, {
 
class MetaExpenseAllocation(core._NormalizePostingMetadataHook):
 
    VALUES_ENUM = core.MetadataEnum('expense-allocation', {
 
        'administration',
...
 
@@ -34,3 +37,6 @@ class MetaExpenseAllocation(core.PostingChecker):
 

	
 
    def _default_value(self, txn, post):
 
    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
 
        return post.account.startswith('Expenses:')
 

	
 
    def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
 
        return self.DEFAULT_VALUES.get(post.account, 'program')
conservancy_beancount/plugin/meta_tax_implication.py
Show inline comments
...
 
@@ -19,2 +19,6 @@ import decimal
 
from . import core
 
from .._typing import (
 
    Posting,
 
    Transaction,
 
)
 

	
...
 
@@ -22,6 +26,4 @@ DEFAULT_STOP_AMOUNT = decimal.Decimal(0)
 

	
 
class MetaTaxImplication(core.PostingChecker):
 
    ACCOUNTS = ('Assets:',)
 
    METADATA_KEY = 'tax-implication'
 
    VALUES_ENUM = core.MetadataEnum(METADATA_KEY, [
 
class MetaTaxImplication(core._NormalizePostingMetadataHook):
 
    VALUES_ENUM = core.MetadataEnum('tax-implication', [
 
        '1099',
...
 
@@ -45,5 +47,6 @@ class MetaTaxImplication(core.PostingChecker):
 

	
 
    def _should_check(self, txn, post):
 
        return (
 
            super()._should_check(txn, post)
 
    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
 
        return bool(
 
            post.account.startswith('Assets:')
 
            and post.units.number
 
            and post.units.number < DEFAULT_STOP_AMOUNT
tests/test_meta_expense_allocation.py
Show inline comments
...
 
@@ -46,3 +46,3 @@ def test_valid_values_on_postings(src_value, set_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -57,3 +57,3 @@ def test_invalid_values_on_postings(src_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert errors
...
 
@@ -67,3 +67,3 @@ def test_valid_values_on_transactions(src_value, set_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -78,3 +78,3 @@ def test_invalid_values_on_transactions(src_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert errors
...
 
@@ -94,3 +94,3 @@ def test_non_expense_accounts_skipped(account):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[0], 0)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -110,3 +110,3 @@ def test_default_values(account, set_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -127,3 +127,3 @@ def test_default_value_set_in_date_range(date, set_value):
 
    checker = meta_expense_allocation.MetaExpenseAllocation()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
tests/test_meta_tax_implication.py
Show inline comments
...
 
@@ -58,3 +58,3 @@ def test_valid_values_on_postings(src_value, set_value):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -69,3 +69,3 @@ def test_invalid_values_on_postings(src_value):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert errors
...
 
@@ -79,3 +79,3 @@ def test_valid_values_on_transactions(src_value, set_value):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -90,3 +90,3 @@ def test_invalid_values_on_transactions(src_value):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert errors
...
 
@@ -104,3 +104,3 @@ def test_non_asset_accounts_skipped(account):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[0], 0)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -113,3 +113,3 @@ def test_asset_credits_skipped():
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert not errors
...
 
@@ -130,3 +130,3 @@ def test_default_value_set_in_date_range(date, need_value):
 
    checker = meta_tax_implication.MetaTaxImplication()
 
    errors = checker.run(txn, txn.postings[-1], -1)
 
    errors = list(checker.run(txn))
 
    assert bool(errors) == bool(need_value)
tests/test_plugin_HookRegistry.py
Show inline comments
...
 
@@ -27,6 +27,6 @@ def test_default_registrations():
 
    hooks = plugin.HOOK_REGISTRY.group_by_directive()
 
    post_hook_names = hook_names(hooks, 'Posting')
 
    assert len(post_hook_names) >= 2
 
    assert 'MetaExpenseAllocation' in post_hook_names
 
    assert 'MetaTaxImplication' in post_hook_names
 
    txn_hook_names = hook_names(hooks, 'Transaction')
 
    assert len(txn_hook_names) >= 2
 
    assert 'MetaExpenseAllocation' in txn_hook_names
 
    assert 'MetaTaxImplication' in txn_hook_names
 

	
...
 
@@ -34,5 +34,5 @@ def test_exclude_single():
 
    hooks = plugin.HOOK_REGISTRY.group_by_directive('-expense-allocation')
 
    post_hook_names = hook_names(hooks, 'Posting')
 
    assert post_hook_names
 
    assert 'MetaExpenseAllocation' not in post_hook_names
 
    txn_hook_names = hook_names(hooks, 'Transaction')
 
    assert txn_hook_names
 
    assert 'MetaExpenseAllocation' not in txn_hook_names
 

	
...
 
@@ -40,5 +40,5 @@ def test_exclude_group_then_include_single():
 
    hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expense-allocation')
 
    post_hook_names = hook_names(hooks, 'Posting')
 
    assert 'MetaExpenseAllocation' in post_hook_names
 
    assert 'MetaTaxImplication' not in post_hook_names
 
    txn_hook_names = hook_names(hooks, 'Transaction')
 
    assert 'MetaExpenseAllocation' in txn_hook_names
 
    assert 'MetaTaxImplication' not in txn_hook_names
 

	
...
 
@@ -46,5 +46,5 @@ def test_include_group_then_exclude_single():
 
    hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -tax-implication')
 
    post_hook_names = hook_names(hooks, 'Posting')
 
    assert 'MetaExpenseAllocation' in post_hook_names
 
    assert 'MetaTaxImplication' not in post_hook_names
 
    txn_hook_names = hook_names(hooks, 'Transaction')
 
    assert 'MetaExpenseAllocation' in txn_hook_names
 
    assert 'MetaTaxImplication' not in txn_hook_names
 

	
tests/test_plugin_run.py
Show inline comments
...
 
@@ -20,3 +20,3 @@ from . import testutil
 

	
 
from conservancy_beancount import plugin
 
from conservancy_beancount import plugin, _typing
 

	
...
 
@@ -27,3 +27,4 @@ HOOK_REGISTRY = plugin.HookRegistry()
 
class TransactionCounter:
 
    HOOK_GROUPS = frozenset(['Transaction', 'counter'])
 
    DIRECTIVE = _typing.Transaction
 
    HOOK_GROUPS = frozenset()
 

	
...
 
@@ -35,6 +36,7 @@ class TransactionCounter:
 
class PostingCounter(TransactionCounter):
 
    HOOK_GROUPS = frozenset(['Posting', 'counter'])
 
    DIRECTIVE = _typing.Transaction
 
    HOOK_GROUPS = frozenset(['posting'])
 

	
 
    def run(self, txn, post, post_index):
 
        return ['post:{}'.format(id(post))]
 
    def run(self, txn):
 
        return ['post:{}'.format(id(post)) for post in txn.postings]
 

	
...
 
@@ -76,3 +78,3 @@ def test_with_posting_hooks_only():
 
    ]
 
    out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY)
 
    out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'posting', HOOK_REGISTRY)
 
    assert len(out_entries) == 2
0 comments (0 inline, 0 general)