diff --git a/conservancy_beancount/reports/rewrite.py b/conservancy_beancount/reports/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..40cfda5731d495abb1633d5a9a3a6d292c58cfb4 --- /dev/null +++ b/conservancy_beancount/reports/rewrite.py @@ -0,0 +1,375 @@ +"""rewrite.py - Post rewriting for financial reports""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. + +import abc +import datetime +import decimal +import enum +import logging +import operator as opmod +import re + +from typing import ( + Callable, + Dict, + Generic, + IO, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from ..beancount_types import ( + Meta, + MetaKey, + MetaValue, +) + +from pathlib import Path + +import yaml + +from .. import data + +Decimal = decimal.Decimal +T = TypeVar('T') +TestCallable = Callable[[T, T], bool] + +CMP_OPS: Mapping[str, TestCallable] = { + '==': opmod.eq, + '>=': opmod.ge, + '>': opmod.gt, + '<=': opmod.le, + '<': opmod.lt, + '!=': opmod.ne, +} + +# First half of this regexp is pseudo-attribute access. +# Second half is metadata keys, per the Beancount syntax docs. 
# Group 1 captures the subject of a rule; any whitespace after it is consumed.
SUBJECT_PAT = r'((?:\.\w+)+|[a-z][-\w]*)\b\s*'

logger = logging.getLogger('conservancy_beancount.reports.rewrite')


class _Registry(Generic[T]):
    """Parse rule strings into instances of registered classes.

    ``parser`` must be a regular expression (or its source string) with two
    groups: the rule's subject, then its operator. Whatever follows the match
    is the operand.
    """

    def __init__(self,
                 description: str,
                 parser: Union[str, Pattern],
                 default: Type[T],
                 *others: Tuple[str, Type[T]],
    ) -> None:
        """Set up the registry.

        :param description: Noun used in error messages (e.g. ``'condition'``).
        :param parser: Regexp, or regexp source, that splits a rule string.
        :param default: Class used for subjects that do not start with ``.``.
        :param others: ``(subject, class)`` pairs for ``.``-prefixed subjects.
        """
        if isinstance(parser, str):
            parser = re.compile(parser)
        self.description = description
        self.parser = parser
        self.default = default
        self.registry: Mapping[str, Type[T]] = dict(others)

    def parse(self, s: str) -> T:
        """Build the object described by the rule string ``s``.

        :raises ValueError: If ``s`` does not match the parser, or names a
            ``.``-prefixed subject that is not registered. The selected
            class's constructor may raise its own errors too.
        """
        match = self.parser.match(s)
        if match is None:
            raise ValueError(f"could not parse {self.description} {s!r}")
        subject = match.group(1)
        operator = match.group(2)
        operand = s[match.end():].strip()
        if not subject.startswith('.'):
            # FIXME: To avoid this type ignore, I would have to define a common
            # superclass for Tester and Setter that provides a useful signature
            # for __init__, including the versions that deal with Metadata,
            # and then use that as the bound for our type variable.
            # Not a priority right now.
            return self.default(subject, operator, operand)  # type:ignore[call-arg]
        try:
            retclass = self.registry[subject]
        except KeyError:
            raise ValueError(f"unknown subject in {self.description} {subject!r}") from None
        else:
            return retclass(operator, operand)  # type:ignore[call-arg]


class Tester(Generic[T], metaclass=abc.ABCMeta):
    """Abstract callable that compares one value from a posting to an operand.

    Calling an instance returns ``op(post_get(post), operand)``, where ``op``
    is the comparison function that ``OPS`` maps the operator string to.
    """

    OPS: Mapping[str, TestCallable] = CMP_OPS

    def __init__(self, operator: str, operand: str) -> None:
        """Look up ``operator`` in ``OPS`` and parse ``operand``.

        :raises ValueError: If ``operator`` is not a key of ``OPS``.
        """
        try:
            self.op_func = self.OPS[operator]
        except KeyError:
            raise ValueError(f"unsupported operator {operator!r}") from None
        self.operand = self.parse_operand(operand)

    @staticmethod
    @abc.abstractmethod
    def parse_operand(operand: str) -> T:
        """Convert the operand text from a rule into a comparable value."""
        ...

    @abc.abstractmethod
    def post_get(self, post: data.Posting) -> T:
        """Return the value from ``post`` that this tester compares."""
        ...

    def __call__(self, post: data.Posting) -> bool:
        return self.op_func(self.post_get(post), self.operand)


class AccountTest(Tester[str]):
    """Test a posting's account.

    Besides the comparison operators in ``OPS``, this supports ``in``: the
    operand is split on whitespace into account names, and the test passes
    when ``post.account.is_under(*names)`` returns something other than None.
    """

    def __init__(self, operator: str, operand: str) -> None:
        if operator == 'in':
            # An `in` test only ever has `under_args`; it never gets
            # `op_func` or `operand`. Other code tells the two kinds of
            # test apart by whether `under_args` exists.
            self.under_args = operand.split()
            for name in self.under_args:
                # Called for validation only; the result is not needed.
                self.parse_operand(name)
        else:
            super().__init__(operator, operand)

    @staticmethod
    def parse_operand(operand: str) -> str:
        """Return ``operand`` unchanged if it is acceptable as an account name.

        :raises ValueError: If ``data.Account.is_account`` rejects the name.
        """
        # A component is appended before validating; presumably so that a
        # bare root such as `Assets` is accepted too (going by the literal
        # suffix) -- confirm against data.Account.is_account.
        if data.Account.is_account(f'{operand}:RootsOK'):
            return operand
        else:
            raise ValueError(f"invalid account name {operand!r}")

    def post_get(self, post: data.Posting) -> str:
        return post.account

    def __call__(self, post: data.Posting) -> bool:
        # Check for `under_args` explicitly rather than wrapping the whole
        # call in `except AttributeError`: that would also swallow an
        # AttributeError raised by (or inside) `is_under` and then fail with
        # a misleading error about `op_func`.
        under_args = getattr(self, 'under_args', None)
        if under_args is None:
            return super().__call__(post)
        return post.account.is_under(*under_args) is not None


class DateTest(Tester[datetime.date]):
    """Test ``post.meta.date`` against a ``YYYY-MM-DD`` operand."""

    @staticmethod
    def parse_operand(operand: str) -> datetime.date:
        # strptime raises ValueError itself when the text is not a valid date.
        return datetime.datetime.strptime(operand, '%Y-%m-%d').date()

    def post_get(self, post: data.Posting) -> datetime.date:
        return post.meta.date


class MetadataTest(Tester[Optional[MetaValue]]):
    """Test the metadata value stored under ``key`` against an operand string.

    Unlike the other testers, the constructor takes the metadata key as an
    extra first argument.
    """

    def __init__(self, key: MetaKey, operator: str, operand: str) -> None:
        super().__init__(operator, operand)
        self.key = key

    @staticmethod
    def parse_operand(operand: str) -> str:
        # Metadata operands are compared as the literal text from the rule.
        return operand

    def post_get(self, post: data.Posting) -> Optional[MetaValue]:
        # None when the posting has no value under this key.
        return post.meta.get(self.key)


class NumberTest(Tester[Decimal]):
    """Test ``post.units.number`` against a decimal operand."""

    @staticmethod
    def parse_operand(operand: str) -> Decimal:
        """Parse ``operand`` as a ``Decimal``.

        :raises ValueError: If ``operand`` is not a valid decimal literal.
        """
        try:
            return Decimal(operand)
        except decimal.DecimalException:
            # `from None` matches the other re-raises in this module and
            # keeps the decimal module's internals out of the traceback.
            raise ValueError(f"could not parse decimal {operand!r}") from None

    def post_get(self, post: data.Posting) -> Decimal:
        return post.units.number


# Conditions look like `SUBJECT OP OPERAND`, where OP is one of the
# comparison operators in Tester.OPS or the word `in`.
TestRegistry: _Registry[Tester] = _Registry(
    'condition',
    '^{}{}'.format(
        SUBJECT_PAT,
        r'({}|in)'.format('|'.join(re.escape(s) for s in Tester.OPS)),
    ),
    MetadataTest,
    ('.account', AccountTest),
    ('.date', DateTest),
    ('.number', NumberTest),
)


class Setter(Generic[T], metaclass=abc.ABCMeta):
    """Abstract callable that computes one replacement field for a posting.

    Calling an instance with a posting returns a ``(name, value)`` pair.
    """

    # NOTE(review): nothing in the code visible here reads these two
    # attributes; they may be vestigial -- confirm before relying on them.
    _regparser = re.compile(r'^{}{}'.format(
        SUBJECT_PAT,
        r'',
    ))
    _regtype = 'setter'

    @abc.abstractmethod
    def __call__(self, post: data.Posting) -> Tuple[str, T]:
        """Return the ``(name, value)`` pair to apply to ``post``."""
        ...


class AccountSet(Setter[data.Account]):
    """Replace a posting's account with a fixed account. Only ``=`` is supported."""

    def __init__(self, operator: str, value: str) -> None:
        if operator != '=':
            raise ValueError(f"unsupported operator for account {operator!r}")
        # Use AccountTest's validation so conditions and actions agree on
        # what counts as an account name.
        self.value = data.Account(AccountTest.parse_operand(value))

    def __call__(self, post: data.Posting) -> Tuple[str, data.Account]:
        return ('account', self.value)


class MetadataSet(Setter[str]):
    """Set the metadata ``key`` to a fixed string. Only ``=`` is supported."""

    def __init__(self, key: str, operator: str, value: str) -> None:
        if operator != '=':
            raise ValueError(f"unsupported operator for metadata {operator!r}")
        self.key = key
        self.value = value

    def __call__(self, post: data.Posting) -> Tuple[str, str]:
        return (self.key, self.value)


class NumberSet(Setter[data.Amount]):
    """Multiply a posting's number by a fixed decimal. Only ``*=`` is supported."""

    def __init__(self, operator: str, value: str) -> None:
        if operator != '*=':
            raise ValueError(f"unsupported operator for number {operator!r}")
        self.value = NumberTest.parse_operand(value)

    def __call__(self, post: data.Posting) -> Tuple[str, data.Amount]:
        number = post.units.number * self.value
        return ('units', post.units._replace(number=number))


# Actions look like `SUBJECT OP VALUE`, where OP is `=` optionally prefixed
# by one of `-`, `+`, `/`, `*`. Each setter rejects operators it cannot apply.
SetRegistry: _Registry[Setter] = _Registry(
    'action',
    rf'^{SUBJECT_PAT}([-+/*]?=)',
    MetadataSet,
    ('.account', AccountSet),
    ('.number', NumberSet),
)


class _RootAccount(enum.Enum):
    """Coarse classification of an account by its first component."""

    Assets = 'Assets'
    Liabilities = 'Liabilities'
    Equity = 'Equity'

    @classmethod
    def from_account(cls, name: str) -> '_RootAccount':
        """Classify ``name`` by the text before its first ``:``.

        Any root that is not a member name is classified as ``Equity``.
        """
        root, _, _ = name.partition(':')
        try:
            return cls[root]
        except KeyError:
            return cls.Equity


class RewriteRule:
    """One rewrite rule: conditions plus one or more sets of actions.

    ``source`` maps the key ``'if'`` to a list of condition strings. Every
    other key maps to a list of action strings; each such list produces one
    output posting when the rule is applied.
    """

    def __init__(self, source: Mapping[str, List[str]]) -> None:
        """Parse and validate ``source``.

        :raises ValueError: If a condition or action fails to parse; if two
            non-metadata actions of the same type appear in one action list;
            if there is no ``if`` condition; if an account action would move
            a posting to a different root than the conditions select; or if
            the number multipliers of the action lists do not sum to 1.
        """
        # Start empty so a rule with no `if` key is caught by a plain
        # emptiness check rather than an AttributeError.
        self.tests: List[Tester] = []
        self.new_meta: List[Sequence[MetadataSet]] = []
        self.rewrites: List[Sequence[Setter]] = []
        for key, rules in source.items():
            if key == 'if':
                self.tests = [TestRegistry.parse(rule) for rule in rules]
            else:
                new_meta: List[MetadataSet] = []
                rewrites: List[Setter] = []
                for rule_s in rules:
                    setter = SetRegistry.parse(rule_s)
                    if isinstance(setter, MetadataSet):
                        new_meta.append(setter)
                    elif any(isinstance(t, type(setter)) for t in rewrites):
                        raise ValueError(f"rule conflicts with earlier action: {rule_s!r}")
                    else:
                        rewrites.append(setter)
                self.new_meta.append(new_meta)
                self.rewrites.append(rewrites)

        if not self.tests:
            raise ValueError("no `if` condition in rule")

        # Collect the root classification of every account the conditions
        # name. Account actions are only allowed when there is exactly one.
        account_conditions: Set[_RootAccount] = set()
        for test in self.tests:
            if isinstance(test, AccountTest):
                # `in` tests carry `under_args`; comparison tests carry
                # `operand` instead.
                operands = getattr(test, 'under_args', None)
                if operands is None:
                    operands = [test.operand]
                account_conditions.update(_RootAccount.from_account(s) for s in operands)
        if len(account_conditions) == 1:
            account_condition: Optional[_RootAccount] = account_conditions.pop()
        else:
            account_condition = None

        # Each action list contributes its number multiplier, or 1 if it has
        # none. The multipliers must sum to 1.
        number_reallocation = Decimal()
        for rewrite in self.rewrites:
            rewrite_number = Decimal(1)
            for rule in rewrite:
                if isinstance(rule, AccountSet):
                    new_root = _RootAccount.from_account(rule.value)
                    if new_root is not account_condition:
                        raise ValueError(
                            f"cannot assign {new_root} account "
                            f"when `if` checks for {account_condition}",
                        )
                elif isinstance(rule, NumberSet):
                    rewrite_number = rule.value
            number_reallocation += rewrite_number

        # A zero sum is what a rule with no action lists produces.
        if not number_reallocation:
            raise ValueError("no rewrite actions in rule")
        elif number_reallocation != 1:
            raise ValueError(f"rule multiplies number by {number_reallocation}")

    def match(self, post: data.Posting) -> bool:
        """Return True if ``post`` passes every condition of this rule."""
        return all(test(post) for test in self.tests)

    def rewrite(self, post: data.Posting) -> Iterator[data.Posting]:
        """Yield one rewritten copy of ``post`` per action list.

        This does not check the conditions; call ``match`` first.
        """
        for rewrite, new_meta in zip(self.rewrites, self.new_meta):
            kwargs = dict(setter(post) for setter in rewrite)
            if new_meta:
                # Update a detached copy of the metadata, not post.meta itself.
                meta = post.meta.detached()
                meta.update(meta_setter(post) for meta_setter in new_meta)
                kwargs['meta'] = meta
            yield post._replace(**kwargs)


class RewriteRuleset:
    """An ordered collection of rewrite rules, applied first-match-wins."""

    def __init__(self, rules: Iterable[RewriteRule]) -> None:
        self.rules = list(rules)

    def rewrite(self, posts: Iterable[data.Posting]) -> Iterator[data.Posting]:
        """Yield each posting rewritten by the first rule that matches it.

        Postings that match no rule are yielded unchanged.
        """
        for post in posts:
            for rule in self.rules:
                if rule.match(post):
                    yield from rule.rewrite(post)
                    break
            else:
                yield post

    @classmethod
    def from_yaml(cls, source: Union[str, IO, Path]) -> 'RewriteRuleset':
        """Load a ruleset from YAML.

        :param source: A ``Path`` to open and read, or anything
            ``yaml.safe_load`` accepts directly.
        :raises ValueError: If the document is not a list of mappings whose
            values are all lists of strings, or if any rule is invalid.
        """
        if isinstance(source, Path):
            with source.open() as source_file:
                return cls.from_yaml(source_file)
        doc = yaml.safe_load(source)
        if not isinstance(doc, list):
            raise ValueError("YAML root element is not a list")
        for number, item in enumerate(doc, 1):
            if not isinstance(item, Mapping):
                raise ValueError(f"YAML item {number} is not a rule hash")
            for key, value in item.items():
                if not isinstance(value, list):
                    raise ValueError(f"YAML item {number} {key!r} value is not a list")
                elif not all(isinstance(s, str) for s in value):
                    raise ValueError(f"YAML item {number} {key!r} value is not all strings")
        # Test the list directly rather than relying on the loop variable
        # being unbound after an empty loop.
        if doc:
            logger.debug("loaded %s rewrite rules from YAML", len(doc))
        else:
            logger.warning("YAML source is empty; no rewrite rules loaded")
        return cls(RewriteRule(src) for src in doc)