diff --git a/conservancy_beancount/errors.py b/conservancy_beancount/errors.py index 397c2de932658c7162552d1cf51efbb4a3de6c25..6256097aaa17c02a66f4b294d2fb373d2ffc33ad 100644 --- a/conservancy_beancount/errors.py +++ b/conservancy_beancount/errors.py @@ -9,6 +9,7 @@ import beancount.core.data as bc_data from typing import ( Any, + ClassVar, Iterable, Optional, ) @@ -127,3 +128,41 @@ class InvalidMetadataError(Error): super().__init__(msg, txn, source) self.key = key self.value = value + + +class RewriteRuleError(ValueError): + item_type: ClassVar[Optional[str]] = None + + def __init__(self, + message: str, + filename: Optional[str]=None, + rulenum: Optional[int]=None, + source: Optional[Any]=None, + ) -> None: + super().__init__(message, filename, rulenum, source) + self.message = message + self.filename = filename + self.rulenum = rulenum + self.source = source + + def __str__(self) -> str: + msg_parts = [self.message] + if self.filename is not None: + msg_parts.append(f"in {self.filename}") + if self.rulenum is not None: + msg_parts.append(f"rule #{self.rulenum}") + if self.item_type is not None: + msg_parts.append(self.item_type) + return ' '.join(msg_parts) + + +class RewriteRuleActionError(RewriteRuleError): + item_type = 'action' + + +class RewriteRuleConditionError(RewriteRuleError): + item_type = 'condition' + + +class RewriteRuleLoadError(RewriteRuleError): pass +class RewriteRuleValidationError(RewriteRuleError): pass diff --git a/conservancy_beancount/reports/rewrite.py b/conservancy_beancount/reports/rewrite.py index f6bc6c4696f7a9b842a2e6480ec4294d0d0db71f..6266ffb2b87327e421d44ceb476fe918f6b26307 100644 --- a/conservancy_beancount/reports/rewrite.py +++ b/conservancy_beancount/reports/rewrite.py @@ -149,6 +149,7 @@ import operator as opmod import re from typing import ( + Any, Callable, Dict, Generic, @@ -177,6 +178,7 @@ from pathlib import Path import yaml from .. import data +from .. import errors Decimal = decimal.Decimal T = TypeVar('T') @@ -199,14 +201,14 @@ logger = logging.getLogger('conservancy_beancount.reports.rewrite') class _Registry(Generic[T]): def __init__(self, - description: str, + error_cls: Type[errors.RewriteRuleError], parser: Union[str, Pattern], default: Type[T], *others: Tuple[str, Type[T]], ) -> None: if isinstance(parser, str): parser = re.compile(parser) - self.description = description + self.error_cls = error_cls self.parser = parser self.default = default self.registry: Mapping[str, Type[T]] = dict(others) @@ -214,7 +216,7 @@ class _Registry(Generic[T]): def parse(self, s: str) -> T: match = self.parser.match(s) if match is None: - raise ValueError(f"could not parse {self.description} {s!r}") + raise self.error_cls(f"parse error: {s!r}") subject = match.group(1) operator = match.group(2) operand = s[match.end():].strip() @@ -228,19 +230,20 @@ class _Registry(Generic[T]): try: retclass = self.registry[subject] except KeyError: - raise ValueError(f"unknown subject in {self.description} {subject!r}") from None + raise self.error_cls(f"unknown subject: {subject!r}") from None else: return retclass(operator, operand) # type:ignore[call-arg] class Tester(Generic[T], metaclass=abc.ABCMeta): + ERROR = errors.RewriteRuleConditionError OPS: Mapping[str, TestCallable] = CMP_OPS def __init__(self, operator: str, operand: str) -> None: try: self.op_func = self.OPS[operator] except KeyError: - raise ValueError(f"unsupported operator {operator!r}") from None + raise self.ERROR(f"unsupported operator: {operator!r}") from None self.operand = self.parse_operand(operand) @staticmethod @@ -268,7 +271,7 @@ class AccountTest(Tester[str]): if data.Account.is_account(f'{operand}:RootsOK'): return operand else: - raise ValueError(f"invalid account name {operand!r}") + raise errors.RewriteRuleConditionError(f"invalid account name: {operand!r}") def post_get(self, post: data.Posting) -> str: return post.account @@ -283,7 +286,10 @@ class AccountTest(Tester[str]): class DateTest(Tester[datetime.date]): @staticmethod def parse_operand(operand: str) -> datetime.date: - return datetime.datetime.strptime(operand, '%Y-%m-%d').date() + try: + return datetime.datetime.strptime(operand, '%Y-%m-%d').date() + except ValueError as error: + raise errors.RewriteRuleConditionError(error.args[0]) def post_get(self, post: data.Posting) -> datetime.date: return post.meta.date @@ -329,14 +335,14 @@ class NumberTest(Tester[Decimal]): try: return Decimal(operand) except decimal.DecimalException: - raise ValueError(f"could not parse decimal {operand!r}") + raise errors.RewriteRuleConditionError(f"decimal parse error: {operand!r}") def post_get(self, post: data.Posting) -> Decimal: return post.units.number TestRegistry: _Registry[Tester] = _Registry( - 'condition', + Tester.ERROR, '^{}{}'.format( SUBJECT_PAT, r'({}|in)'.format('|'.join(re.escape(s) for s in Tester.OPS)), @@ -353,6 +359,7 @@ class Setter(Generic[T], metaclass=abc.ABCMeta): r'', )) _regtype = 'setter' + ERROR = errors.RewriteRuleActionError @abc.abstractmethod def __call__(self, post: data.Posting) -> Tuple[str, T]: ... @@ -361,8 +368,13 @@ class Setter(Generic[T], metaclass=abc.ABCMeta): class AccountSet(Setter[data.Account]): def __init__(self, operator: str, value: str) -> None: if operator != '=': - raise ValueError(f"unsupported operator for account {operator!r}") - self.value = data.Account(AccountTest.parse_operand(value)) + raise self.ERROR(f"unsupported operator for account: {operator!r}") + try: + self.value = data.Account(AccountTest.parse_operand(value)) + except errors.RewriteRuleConditionError as error: + new_error = errors.RewriteRuleActionError(*error.args) + new_error.__cause__ = error.__cause__ + raise new_error def __call__(self, post: data.Posting) -> Tuple[str, data.Account]: return ('account', self.value) @@ -371,7 +383,7 @@ class AccountSet(Setter[data.Account]): class MetadataSet(Setter[str]): def __init__(self, key: str, operator: str, value: str) -> None: if operator != '=': - raise ValueError(f"unsupported operator for metadata {operator!r}") + raise self.ERROR(f"unsupported operator for metadata: {operator!r}") self.key = key self.value = value @@ -382,8 +394,13 @@ class MetadataSet(Setter[str]): class NumberSet(Setter[data.Amount]): def __init__(self, operator: str, value: str) -> None: if operator != '*=': - raise ValueError(f"unsupported operator for number {operator!r}") - self.value = NumberTest.parse_operand(value) + raise self.ERROR(f"unsupported operator for number: {operator!r}") + try: + self.value = NumberTest.parse_operand(value) + except errors.RewriteRuleConditionError as error: + new_error = errors.RewriteRuleActionError(*error.args) + new_error.__cause__ = error.__cause__ + raise new_error def __call__(self, post: data.Posting) -> Tuple[str, data.Amount]: number = post.units.number * self.value @@ -391,7 +408,7 @@ class NumberSet(Setter[data.Amount]): SetRegistry: _Registry[Setter] = _Registry( - 'action', + Setter.ERROR, rf'^{SUBJECT_PAT}([-+/*]?=)', MetadataSet, ('.account', AccountSet), @@ -413,6 +430,8 @@ class _RootAccount(enum.Enum): class RewriteRule: + ERROR = errors.RewriteRuleValidationError + def __init__(self, source: Mapping[str, List[str]]) -> None: self.new_meta: List[Sequence[MetadataSet]] = [] self.rewrites: List[Sequence[Setter]] = [] @@ -427,7 +446,7 @@ class RewriteRule: if isinstance(setter, MetadataSet): new_meta.append(setter) elif any(isinstance(t, type(setter)) for t in rewrites): - raise ValueError(f"rule conflicts with earlier action: {rule_s!r}") + raise self.ERROR(f"rule conflicts with earlier action: {rule_s!r}") else: rewrites.append(setter) self.new_meta.append(new_meta) @@ -438,7 +457,7 @@ class RewriteRule: except AttributeError: if_ok = False if not if_ok: - raise ValueError("no `if` condition in rule") from None + raise self.ERROR("no `if` condition in rule") from None account_conditions: Set[_RootAccount] = set() for test in self.tests: @@ -460,7 +479,7 @@ class RewriteRule: if isinstance(rule, AccountSet): new_root = _RootAccount.from_account(rule.value) if new_root is not account_condition: - raise ValueError( + raise self.ERROR( f"cannot assign {new_root} account " f"when `if` checks for {account_condition}", ) @@ -469,9 +488,9 @@ class RewriteRule: number_reallocation += rewrite_number if not number_reallocation: - raise ValueError("no rewrite actions in rule") + raise self.ERROR("no rewrite actions") elif number_reallocation != 1: - raise ValueError(f"rule multiplies number by {number_reallocation}") + raise self.ERROR(f"rule multiplies number by {number_reallocation}") def match(self, post: data.Posting) -> bool: return all(test(post) for test in self.tests) @@ -500,23 +519,51 @@ class RewriteRuleset: yield post @classmethod - def from_yaml(cls, source: Union[str, IO, Path]) -> 'RewriteRuleset': + def _iter_yaml(cls, source: Iterable[Any], name: str) -> Iterator[RewriteRule]: + for number, item in enumerate(source, 1): + try: + if not isinstance(item, Mapping): + raise errors.RewriteRuleLoadError(f"item is not a rule hash") + for key, value in item.items(): + if not isinstance(value, list): + raise errors.RewriteRuleLoadError( + f"YAML item {number} {key!r} value is not a list", + ) + elif not all(isinstance(s, str) for s in value): + raise errors.RewriteRuleLoadError( + f"YAML item {number} {key!r} value is not all strings", + ) + yield RewriteRule(item) + except errors.RewriteRuleError as error: + error.filename = name + error.rulenum = number + error.source = item + raise + + @classmethod + def from_yaml( + cls, + source: Union[str, IO, Path], + name: Optional[str]=None, + ) -> 'RewriteRuleset': if isinstance(source, Path): with source.open() as source_file: - return cls.from_yaml(source_file) - doc = yaml.safe_load(source) - if not isinstance(doc, list): - raise ValueError("YAML root element is not a list") - for number, item in enumerate(doc, 1): - if not isinstance(item, Mapping): - raise ValueError(f"YAML item {number} is not a rule hash") - for key, value in item.items(): - if not isinstance(value, list): - raise ValueError(f"YAML item {number} {key!r} value is not a list") - elif not all(isinstance(s, str) for s in value): - raise ValueError(f"YAML item {number} {key!r} value is not all strings") + return cls.from_yaml(source_file, str(source)) + if name is None: + if isinstance(source, str): + name = '' + else: + name = getattr(source, 'name', '') try: - logger.debug("loaded %s rewrite rules from YAML", number) - except NameError: - logger.warning("YAML source is empty; no rewrite rules loaded") - return cls(RewriteRule(src) for src in doc) + doc = yaml.safe_load(source) + except yaml.error.YAMLError as error: + raise errors.RewriteRuleLoadError(str(error), name) + if not isinstance(doc, list): + raise errors.RewriteRuleLoadError("YAML root element is not a list", name) + retval = cls(cls._iter_yaml(doc, name)) + logger.log( + logging.DEBUG if retval.rules else logging.WARNING, + "loaded %s rewrite rules from %s", + len(retval.rules), name, + ) + return retval diff --git a/tests/test_reports_rewrite.py b/tests/test_reports_rewrite.py index 50235c48925ab5eedd13ad48e074faec0147368d..a915d77d0075d881e5cad7a325b8b5aa3e202780 100644 --- a/tests/test_reports_rewrite.py +++ b/tests/test_reports_rewrite.py @@ -15,7 +15,7 @@ import yaml from . import testutil -from conservancy_beancount import data +from conservancy_beancount import data, errors from conservancy_beancount.reports import rewrite CMP_OPS = frozenset('< <= == != >= >'.split()) @@ -131,7 +131,7 @@ def test_parse_good_condition(subject, operator, operand): '.units == 5', # Bad subject (unknown) ]) def test_parse_bad_condition(cond_s): - with pytest.raises(ValueError): + with pytest.raises(errors.RewriteRuleConditionError): rewrite.TestRegistry.parse(cond_s) @pytest.mark.parametrize('value', ['Equity:Other', 'Income:Other']) @@ -195,7 +195,7 @@ def test_parse_good_set(subject, operator, operand): 'testkey *= 3', # Bad operator ]) def test_parse_bad_set(set_s): - with pytest.raises(ValueError): + with pytest.raises(errors.RewriteRuleActionError): rewrite.SetRegistry.parse(set_s) def test_good_rewrite_rule(): @@ -271,17 +271,12 @@ def test_valid_rewrite_rule(source): {}, {'if': ['.account in Equity']}, {'a': ['.account = Income:Other'], 'b': ['.account = Expenses:Other']}, - # Condition/assignment mixup - {'if': ['.account = Equity:Other'], 'then': ['equity-type = other']}, - {'if': ['.account == Equity:Other'], 'then': ['equity-type != other']}, # Cross-category account assignment {'if': ['.date >= 2020-01-01'], 'then': ['.account = Assets:Cash']}, {'if': ['.account in Equity'], 'then': ['.account = Assets:Cash']}, # Number reallocation != 1 {'if': ['.date >= 2020-01-01'], 'then': ['.number *= .5']}, {'if': ['.date >= 2020-01-01'], 'a': ['k1=v1'], 'b': ['k2=v2']}, - # Date assignment - {'if': ['.date == 2020-01-01'], 'then': ['.date = 2020-02-02']}, # Redundant assignments {'if': ['.account in Income'], 'then': ['.account = Income:Other', '.account = Income:Other']}, @@ -290,7 +285,7 @@ def test_valid_rewrite_rule(source): 'b': ['.number *= .5']}, ]) def test_invalid_rewrite_rule(source): - with pytest.raises(ValueError): + with pytest.raises(errors.RewriteRuleValidationError): rewrite.RewriteRule(source) def test_rewrite_ruleset(): @@ -333,7 +328,7 @@ def test_ruleset_from_yaml_str(): def test_bad_ruleset_yaml_path(): yaml_path = testutil.test_path('repository/Projects/project-data.yml') - with pytest.raises(ValueError): + with pytest.raises(errors.RewriteRuleLoadError): rewrite.RewriteRuleset.from_yaml(yaml_path) @pytest.mark.parametrize('source', [ @@ -344,13 +339,13 @@ def test_bad_ruleset_yaml_path(): None, {}, 'string', - [{}, 'a'], - [{}, ['b']], + ['a'], + [['b']], # Rules have wrong type [{'if': '.account in Equity', 'add': ['testkey = value']}], [{'if': ['.account in Equity'], 'add': 'testkey = value'}], ]) def test_bad_ruleset_yaml_str(source): yaml_doc = yaml.safe_dump(source) - with pytest.raises(ValueError): + with pytest.raises(errors.RewriteRuleLoadError): rewrite.RewriteRuleset.from_yaml(yaml_doc)