Changeset - f508df06c162
[Not reviewed]
0 3 0
Brett Smith - 3 years ago 2021-02-19 23:54:36
brettcsmith@brettcsmith.org
rewrite: Richer exceptions for error reporting.
3 files changed with 131 insertions and 50 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/errors.py
Show inline comments
...
 
@@ -9,6 +9,7 @@ import beancount.core.data as bc_data
 

	
 
from typing import (
 
    Any,
 
    ClassVar,
 
    Iterable,
 
    Optional,
 
)
...
 
@@ -127,3 +128,41 @@ class InvalidMetadataError(Error):
 
        super().__init__(msg, txn, source)
 
        self.key = key
 
        self.value = value
 

	
 

	
 
class RewriteRuleError(ValueError):
 
    item_type: ClassVar[Optional[str]] = None
 

	
 
    def __init__(self,
 
                 message: str,
 
                 filename: Optional[str]=None,
 
                 rulenum: Optional[int]=None,
 
                 source: Optional[Any]=None,
 
    ) -> None:
 
        super().__init__(message, filename, rulenum, source)
 
        self.message = message
 
        self.filename = filename
 
        self.rulenum = rulenum
 
        self.source = source
 

	
 
    def __str__(self) -> str:
 
        msg_parts = [self.message]
 
        if self.filename is not None:
 
            msg_parts.append(f"in {self.filename}")
 
        if self.rulenum is not None:
 
            msg_parts.append(f"rule #{self.rulenum}")
 
            if self.item_type is not None:
 
                msg_parts.append(self.item_type)
 
        return ' '.join(msg_parts)
 

	
 

	
 
class RewriteRuleActionError(RewriteRuleError):
 
    item_type = 'action'
 

	
 

	
 
class RewriteRuleConditionError(RewriteRuleError):
 
    item_type = 'condition'
 

	
 

	
 
class RewriteRuleLoadError(RewriteRuleError): pass
 
class RewriteRuleValidationError(RewriteRuleError): pass
conservancy_beancount/reports/rewrite.py
Show inline comments
...
 
@@ -149,6 +149,7 @@ import operator as opmod
 
import re
 

	
 
from typing import (
 
    Any,
 
    Callable,
 
    Dict,
 
    Generic,
...
 
@@ -177,6 +178,7 @@ from pathlib import Path
 
import yaml
 

	
 
from .. import data
 
from .. import errors
 

	
 
Decimal = decimal.Decimal
 
T = TypeVar('T')
...
 
@@ -199,14 +201,14 @@ logger = logging.getLogger('conservancy_beancount.reports.rewrite')
 

	
 
class _Registry(Generic[T]):
 
    def __init__(self,
 
                 description: str,
 
                 error_cls: Type[errors.RewriteRuleError],
 
                 parser: Union[str, Pattern],
 
                 default: Type[T],
 
                 *others: Tuple[str, Type[T]],
 
    ) -> None:
 
        if isinstance(parser, str):
 
            parser = re.compile(parser)
 
        self.description = description
 
        self.error_cls = error_cls
 
        self.parser = parser
 
        self.default = default
 
        self.registry: Mapping[str, Type[T]] = dict(others)
...
 
@@ -214,7 +216,7 @@ class _Registry(Generic[T]):
 
    def parse(self, s: str) -> T:
 
        match = self.parser.match(s)
 
        if match is None:
 
            raise ValueError(f"could not parse {self.description} {s!r}")
 
            raise self.error_cls(f"parse error: {s!r}")
 
        subject = match.group(1)
 
        operator = match.group(2)
 
        operand = s[match.end():].strip()
...
 
@@ -228,19 +230,20 @@ class _Registry(Generic[T]):
 
        try:
 
            retclass = self.registry[subject]
 
        except KeyError:
 
            raise ValueError(f"unknown subject in {self.description} {subject!r}") from None
 
            raise self.error_cls(f"unknown subject: {subject!r}") from None
 
        else:
 
            return retclass(operator, operand)  # type:ignore[call-arg]
 

	
 

	
 
class Tester(Generic[T], metaclass=abc.ABCMeta):
 
    ERROR = errors.RewriteRuleConditionError
 
    OPS: Mapping[str, TestCallable] = CMP_OPS
 

	
 
    def __init__(self, operator: str, operand: str) -> None:
 
        try:
 
            self.op_func = self.OPS[operator]
 
        except KeyError:
 
            raise ValueError(f"unsupported operator {operator!r}") from None
 
            raise self.ERROR(f"unsupported operator: {operator!r}") from None
 
        self.operand = self.parse_operand(operand)
 

	
 
    @staticmethod
...
 
@@ -268,7 +271,7 @@ class AccountTest(Tester[str]):
 
        if data.Account.is_account(f'{operand}:RootsOK'):
 
            return operand
 
        else:
 
            raise ValueError(f"invalid account name {operand!r}")
 
            raise errors.RewriteRuleConditionError(f"invalid account name: {operand!r}")
 

	
 
    def post_get(self, post: data.Posting) -> str:
 
        return post.account
...
 
@@ -283,7 +286,10 @@ class AccountTest(Tester[str]):
 
class DateTest(Tester[datetime.date]):
 
    @staticmethod
 
    def parse_operand(operand: str) -> datetime.date:
 
        return datetime.datetime.strptime(operand, '%Y-%m-%d').date()
 
        try:
 
            return datetime.datetime.strptime(operand, '%Y-%m-%d').date()
 
        except ValueError as error:
 
            raise errors.RewriteRuleConditionError(error.args[0])
 

	
 
    def post_get(self, post: data.Posting) -> datetime.date:
 
        return post.meta.date
...
 
@@ -329,14 +335,14 @@ class NumberTest(Tester[Decimal]):
 
        try:
 
            return Decimal(operand)
 
        except decimal.DecimalException:
 
            raise ValueError(f"could not parse decimal {operand!r}")
 
            raise errors.RewriteRuleConditionError(f"decimal parse error: {operand!r}")
 

	
 
    def post_get(self, post: data.Posting) -> Decimal:
 
        return post.units.number
 

	
 

	
 
TestRegistry: _Registry[Tester] = _Registry(
 
    'condition',
 
    Tester.ERROR,
 
    '^{}{}'.format(
 
        SUBJECT_PAT,
 
        r'({}|in)'.format('|'.join(re.escape(s) for s in Tester.OPS)),
...
 
@@ -353,6 +359,7 @@ class Setter(Generic[T], metaclass=abc.ABCMeta):
 
        r'',
 
    ))
 
    _regtype = 'setter'
 
    ERROR = errors.RewriteRuleActionError
 

	
 
    @abc.abstractmethod
 
    def __call__(self, post: data.Posting) -> Tuple[str, T]: ...
...
 
@@ -361,8 +368,13 @@ class Setter(Generic[T], metaclass=abc.ABCMeta):
 
class AccountSet(Setter[data.Account]):
 
    def __init__(self, operator: str, value: str) -> None:
 
        if operator != '=':
 
            raise ValueError(f"unsupported operator for account {operator!r}")
 
        self.value = data.Account(AccountTest.parse_operand(value))
 
            raise self.ERROR(f"unsupported operator for account: {operator!r}")
 
        try:
 
            self.value = data.Account(AccountTest.parse_operand(value))
 
        except errors.RewriteRuleConditionError as error:
 
            new_error = errors.RewriteRuleActionError(*error.args)
 
            new_error.__cause__ = error.__cause__
 
            raise new_error
 

	
 
    def __call__(self, post: data.Posting) -> Tuple[str, data.Account]:
 
        return ('account', self.value)
...
 
@@ -371,7 +383,7 @@ class AccountSet(Setter[data.Account]):
 
class MetadataSet(Setter[str]):
 
    def __init__(self, key: str, operator: str, value: str) -> None:
 
        if operator != '=':
 
            raise ValueError(f"unsupported operator for metadata {operator!r}")
 
            raise self.ERROR(f"unsupported operator for metadata: {operator!r}")
 
        self.key = key
 
        self.value = value
 

	
...
 
@@ -382,8 +394,13 @@ class MetadataSet(Setter[str]):
 
class NumberSet(Setter[data.Amount]):
 
    def __init__(self, operator: str, value: str) -> None:
 
        if operator != '*=':
 
            raise ValueError(f"unsupported operator for number {operator!r}")
 
        self.value = NumberTest.parse_operand(value)
 
            raise self.ERROR(f"unsupported operator for number: {operator!r}")
 
        try:
 
            self.value = NumberTest.parse_operand(value)
 
        except errors.RewriteRuleConditionError as error:
 
            new_error = errors.RewriteRuleActionError(*error.args)
 
            new_error.__cause__ = error.__cause__
 
            raise new_error
 

	
 
    def __call__(self, post: data.Posting) -> Tuple[str, data.Amount]:
 
        number = post.units.number * self.value
...
 
@@ -391,7 +408,7 @@ class NumberSet(Setter[data.Amount]):
 

	
 

	
 
SetRegistry: _Registry[Setter] = _Registry(
 
    'action',
 
    Setter.ERROR,
 
    rf'^{SUBJECT_PAT}([-+/*]?=)',
 
    MetadataSet,
 
    ('.account', AccountSet),
...
 
@@ -413,6 +430,8 @@ class _RootAccount(enum.Enum):
 

	
 

	
 
class RewriteRule:
 
    ERROR = errors.RewriteRuleValidationError
 

	
 
    def __init__(self, source: Mapping[str, List[str]]) -> None:
 
        self.new_meta: List[Sequence[MetadataSet]] = []
 
        self.rewrites: List[Sequence[Setter]] = []
...
 
@@ -427,7 +446,7 @@ class RewriteRule:
 
                    if isinstance(setter, MetadataSet):
 
                        new_meta.append(setter)
 
                    elif any(isinstance(t, type(setter)) for t in rewrites):
 
                        raise ValueError(f"rule conflicts with earlier action: {rule_s!r}")
 
                        raise self.ERROR(f"rule conflicts with earlier action: {rule_s!r}")
 
                    else:
 
                        rewrites.append(setter)
 
                self.new_meta.append(new_meta)
...
 
@@ -438,7 +457,7 @@ class RewriteRule:
 
        except AttributeError:
 
            if_ok = False
 
        if not if_ok:
 
            raise ValueError("no `if` condition in rule") from None
 
            raise self.ERROR("no `if` condition in rule") from None
 

	
 
        account_conditions: Set[_RootAccount] = set()
 
        for test in self.tests:
...
 
@@ -460,7 +479,7 @@ class RewriteRule:
 
                if isinstance(rule, AccountSet):
 
                    new_root = _RootAccount.from_account(rule.value)
 
                    if new_root is not account_condition:
 
                        raise ValueError(
 
                        raise self.ERROR(
 
                            f"cannot assign {new_root} account "
 
                            f"when `if` checks for {account_condition}",
 
                        )
...
 
@@ -469,9 +488,9 @@ class RewriteRule:
 
            number_reallocation += rewrite_number
 

	
 
        if not number_reallocation:
 
            raise ValueError("no rewrite actions in rule")
 
            raise self.ERROR("no rewrite actions")
 
        elif number_reallocation != 1:
 
            raise ValueError(f"rule multiplies number by {number_reallocation}")
 
            raise self.ERROR(f"rule multiplies number by {number_reallocation}")
 

	
 
    def match(self, post: data.Posting) -> bool:
 
        return all(test(post) for test in self.tests)
...
 
@@ -500,23 +519,51 @@ class RewriteRuleset:
 
                yield post
 

	
 
    @classmethod
 
    def from_yaml(cls, source: Union[str, IO, Path]) -> 'RewriteRuleset':
 
    def _iter_yaml(cls, source: Iterable[Any], name: str) -> Iterator[RewriteRule]:
 
        for number, item in enumerate(source, 1):
 
            try:
 
                if not isinstance(item, Mapping):
 
                    raise errors.RewriteRuleLoadError(f"item is not a rule hash")
 
                for key, value in item.items():
 
                    if not isinstance(value, list):
 
                        raise errors.RewriteRuleLoadError(
 
                            f"YAML item {number} {key!r} value is not a list",
 
                        )
 
                    elif not all(isinstance(s, str) for s in value):
 
                        raise errors.RewriteRuleLoadError(
 
                            f"YAML item {number} {key!r} value is not all strings",
 
                        )
 
                yield RewriteRule(item)
 
            except errors.RewriteRuleError as error:
 
                error.filename = name
 
                error.rulenum = number
 
                error.source = item
 
                raise
 

	
 
    @classmethod
 
    def from_yaml(
 
            cls,
 
            source: Union[str, IO, Path],
 
            name: Optional[str]=None,
 
    ) -> 'RewriteRuleset':
 
        if isinstance(source, Path):
 
            with source.open() as source_file:
 
                return cls.from_yaml(source_file)
 
        doc = yaml.safe_load(source)
 
        if not isinstance(doc, list):
 
            raise ValueError("YAML root element is not a list")
 
        for number, item in enumerate(doc, 1):
 
            if not isinstance(item, Mapping):
 
                raise ValueError(f"YAML item {number} is not a rule hash")
 
            for key, value in item.items():
 
                if not isinstance(value, list):
 
                    raise ValueError(f"YAML item {number} {key!r} value is not a list")
 
                elif not all(isinstance(s, str) for s in value):
 
                    raise ValueError(f"YAML item {number} {key!r} value is not all strings")
 
                return cls.from_yaml(source_file, str(source))
 
        if name is None:
 
            if isinstance(source, str):
 
                name = '<string>'
 
            else:
 
                name = getattr(source, 'name', '<file>')
 
        try:
 
            logger.debug("loaded %s rewrite rules from YAML", number)
 
        except NameError:
 
            logger.warning("YAML source is empty; no rewrite rules loaded")
 
        return cls(RewriteRule(src) for src in doc)
 
            doc = yaml.safe_load(source)
 
        except yaml.error.YAMLError as error:
 
            raise errors.RewriteRuleLoadError(str(error), name)
 
        if not isinstance(doc, list):
 
            raise errors.RewriteRuleLoadError("YAML root element is not a list", name)
 
        retval = cls(cls._iter_yaml(doc, name))
 
        logger.log(
 
            logging.DEBUG if retval.rules else logging.WARNING,
 
            "loaded %s rewrite rules from %s",
 
            len(retval.rules), name,
 
        )
 
        return retval
tests/test_reports_rewrite.py
Show inline comments
...
 
@@ -15,7 +15,7 @@ import yaml
 

	
 
from . import testutil
 

	
 
from conservancy_beancount import data
 
from conservancy_beancount import data, errors
 
from conservancy_beancount.reports import rewrite
 

	
 
CMP_OPS = frozenset('< <= == != >= >'.split())
...
 
@@ -131,7 +131,7 @@ def test_parse_good_condition(subject, operator, operand):
 
    '.units == 5',  # Bad subject (unknown)
 
])
 
def test_parse_bad_condition(cond_s):
 
    with pytest.raises(ValueError):
 
    with pytest.raises(errors.RewriteRuleConditionError):
 
        rewrite.TestRegistry.parse(cond_s)
 

	
 
@pytest.mark.parametrize('value', ['Equity:Other', 'Income:Other'])
...
 
@@ -195,7 +195,7 @@ def test_parse_good_set(subject, operator, operand):
 
    'testkey *= 3',  # Bad operator
 
])
 
def test_parse_bad_set(set_s):
 
    with pytest.raises(ValueError):
 
    with pytest.raises(errors.RewriteRuleActionError):
 
        rewrite.SetRegistry.parse(set_s)
 

	
 
def test_good_rewrite_rule():
...
 
@@ -271,17 +271,12 @@ def test_valid_rewrite_rule(source):
 
    {},
 
    {'if': ['.account in Equity']},
 
    {'a': ['.account = Income:Other'], 'b': ['.account = Expenses:Other']},
 
    # Condition/assignment mixup
 
    {'if': ['.account = Equity:Other'], 'then': ['equity-type = other']},
 
    {'if': ['.account == Equity:Other'], 'then': ['equity-type != other']},
 
    # Cross-category account assignment
 
    {'if': ['.date >= 2020-01-01'], 'then': ['.account = Assets:Cash']},
 
    {'if': ['.account in Equity'], 'then': ['.account = Assets:Cash']},
 
    # Number reallocation != 1
 
    {'if': ['.date >= 2020-01-01'], 'then': ['.number *= .5']},
 
    {'if': ['.date >= 2020-01-01'], 'a': ['k1=v1'], 'b': ['k2=v2']},
 
    # Date assignment
 
    {'if': ['.date == 2020-01-01'], 'then': ['.date = 2020-02-02']},
 
    # Redundant assignments
 
    {'if': ['.account in Income'],
 
     'then': ['.account = Income:Other', '.account = Income:Other']},
...
 
@@ -290,7 +285,7 @@ def test_valid_rewrite_rule(source):
 
     'b': ['.number *= .5']},
 
])
 
def test_invalid_rewrite_rule(source):
 
    with pytest.raises(ValueError):
 
    with pytest.raises(errors.RewriteRuleValidationError):
 
        rewrite.RewriteRule(source)
 

	
 
def test_rewrite_ruleset():
...
 
@@ -333,7 +328,7 @@ def test_ruleset_from_yaml_str():
 

	
 
def test_bad_ruleset_yaml_path():
 
    yaml_path = testutil.test_path('repository/Projects/project-data.yml')
 
    with pytest.raises(ValueError):
 
    with pytest.raises(errors.RewriteRuleLoadError):
 
        rewrite.RewriteRuleset.from_yaml(yaml_path)
 

	
 
@pytest.mark.parametrize('source', [
...
 
@@ -344,13 +339,13 @@ def test_bad_ruleset_yaml_path():
 
    None,
 
    {},
 
    'string',
 
    [{}, 'a'],
 
    [{}, ['b']],
 
    ['a'],
 
    [['b']],
 
    # Rules have wrong type
 
    [{'if': '.account in Equity', 'add': ['testkey = value']}],
 
    [{'if': ['.account in Equity'], 'add': 'testkey = value'}],
 
])
 
def test_bad_ruleset_yaml_str(source):
 
    yaml_doc = yaml.safe_dump(source)
 
    with pytest.raises(ValueError):
 
    with pytest.raises(errors.RewriteRuleLoadError):
 
        rewrite.RewriteRuleset.from_yaml(yaml_doc)
0 comments (0 inline, 0 general)