Files @ d145e2273483
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/plugin/core.py

Brett Smith
test_plugin_run: Simplify testing strategy.

Avoid keeping state in the hook classes/instances.
"""Base classes for plugin checks"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import datetime
import re

from . import errors as errormod

# Default bounds used for PostingChecker.TXN_DATE_RANGE. _GenericRange treats
# them as a half-open interval: the start date is included, the stop date is
# excluded.
DEFAULT_START_DATE = datetime.date(year=2020, month=3, day=1)
# January 1 of datetime.MAXYEAR, the largest year datetime.date supports.
DEFAULT_STOP_DATE = datetime.date(year=datetime.MAXYEAR, month=1, day=1)

class _GenericRange:
    """Half-open interval ``[start, stop)`` that supports the ``in`` operator.

    The bounds can be any values that compare with ``<=`` and ``<`` against
    the items being tested; this module builds ranges of dates.
    """

    def __init__(self, start, stop):
        # start is the inclusive lower bound; stop is the exclusive upper bound.
        self.start, self.stop = start, stop

    def __repr__(self):
        # Use the runtime class name so subclasses are labelled correctly.
        return "{}({!r}, {!r})".format(
            type(self).__name__, self.start, self.stop,
        )

    def __contains__(self, item):
        """Report whether ``start <= item < stop``."""
        at_or_after_start = self.start <= item
        return at_or_after_start and item < self.stop


class MetadataEnum:
    """Map metadata values, including aliases, to a set of standard values.

    Lookups (``enum[value]``, ``value in enum``, ``enum.get(value)``) accept
    either a standard value or one of its aliases and resolve to the standard
    value. Iterating the enum yields only the standard values.
    """

    def __init__(self, key, standard_values, aliases_map=None):
        """Build the lookup table.

        Args:
            key: Name of the metadata key this enum describes. In this class
                it is only stored and shown in ``repr()``.
            standard_values: Iterable of the canonical values. It is consumed
                exactly once, so one-shot iterators and generators work.
            aliases_map: Optional mapping (or iterable of pairs) from alias
                to standard value. Every target must be a standard value.
        """
        self.key = key
        self._stdvalues = frozenset(standard_values)
        self._aliases = {} if aliases_map is None else dict(aliases_map)
        # Every standard value resolves to itself. Iterate the frozenset
        # rather than standard_values again: the argument may be an iterator
        # that the frozenset() call above has already exhausted.
        self._aliases.update((v, v) for v in self._stdvalues)
        # Sanity check that no alias points at a non-standard value. Note
        # that this check is skipped when Python runs with -O.
        assert self._stdvalues == set(self._aliases.values())

    def __repr__(self):
        return "{}<{}>".format(type(self).__name__, self.key)

    def __contains__(self, key):
        """Return True if ``key`` is a standard value or a known alias."""
        return key in self._aliases

    def __getitem__(self, key):
        """Return the standard value for ``key``.

        Raises:
            KeyError: ``key`` is neither a standard value nor an alias.
        """
        return self._aliases[key]

    def __iter__(self):
        """Iterate over the standard values (aliases are not included)."""
        return iter(self._stdvalues)

    def get(self, key, default_key=None):
        """Return the standard value for ``key``, falling back to a default.

        Unlike ``dict.get``, the fallback is itself a key to look up, not a
        literal return value. When ``key`` is unknown and ``default_key`` is
        None, None is returned.

        Raises:
            KeyError: ``key`` is unknown and ``default_key`` is given but is
                not a standard value or alias either.
        """
        try:
            return self[key]
        except KeyError:
            if default_key is None:
                return None
            else:
                return self[default_key]


class PostingChecker:
    """Base class for checks that validate one metadata key on a posting.

    ``run()`` looks up ``METADATA_KEY`` on the posting (falling back to the
    transaction), normalizes the value through ``VALUES_ENUM``, and writes
    the result back to the posting's metadata.

    ``METADATA_KEY`` is read by this class but not defined here, so
    subclasses must provide it. Subclasses can also override the class
    attributes below and ``_default_value()``.
    """

    # Accounts this checker applies to: either a tuple of account-name
    # prefixes (tested with str.startswith) or a regular expression (tested
    # with re.search). The default prefix '' matches every account.
    ACCOUNTS = ('',)
    # Only transactions whose date falls in this range are checked.
    TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
    # Lookup table from accepted source values to the value that gets stored.
    # Anything supporting ``[]`` and raising KeyError on unknown values works,
    # such as a MetadataEnum. The empty default rejects every value.
    VALUES_ENUM = {}

    def _meta_get(self, txn, post, key, default=None):
        """Return metadata ``key`` from the posting, else the transaction.

        The transaction's metadata is consulted when the posting lacks the
        key (KeyError) or has no metadata mapping at all (TypeError from
        subscripting None). ``default`` applies only to that fallback.
        """
        try:
            return post.meta[key]
        except (KeyError, TypeError):
            return txn.meta.get(key, default)

    def _meta_set(self, post, key, value):
        """Store ``value`` under ``key`` in the posting's metadata.

        A metadata dict is created first if the posting has none.
        """
        # NOTE(review): this assigns to post.meta, so it assumes the posting
        # object allows attribute assignment -- confirm against callers.
        if post.meta is None:
            post.meta = {}
        post.meta[key] = value

    def _default_value(self, txn, post):
        """Return the value to use when the metadata key is absent.

        The base implementation always raises InvalidMetadataError. run()
        records exceptions derived from errormod._BaseError as check errors
        (presumably InvalidMetadataError is one), which makes the key
        required unless a subclass overrides this method.
        """
        raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)

    def _should_check(self, txn, post):
        """Return True if this checker applies to the given posting.

        The transaction date must fall inside TXN_DATE_RANGE and the
        posting's account must match ACCOUNTS.
        """
        if txn.date not in self.TXN_DATE_RANGE:
            return False
        if isinstance(self.ACCOUNTS, tuple):
            return post.account.startswith(self.ACCOUNTS)
        # Convert the match object (or None) to a plain bool.
        return re.search(self.ACCOUNTS, post.account) is not None

    def run(self, txn, post):
        """Check one posting and return a list of the errors found.

        An empty list means the posting was either out of scope for this
        checker or passed. On success the normalized (or default) value is
        written to the posting's metadata; on failure the posting is left
        untouched.
        """
        errors = []
        if not self._should_check(txn, post):
            return errors
        source_value = self._meta_get(txn, post, self.METADATA_KEY)
        set_value = source_value
        if source_value is None:
            # Key missing (or explicitly None): ask for a default, which may
            # instead report that the key is required.
            try:
                set_value = self._default_value(txn, post)
            except errormod._BaseError as error:
                errors.append(error)
        else:
            # Key present: map it to its standard form, rejecting values the
            # enum does not know.
            try:
                set_value = self.VALUES_ENUM[source_value]
            except KeyError:
                errors.append(errormod.InvalidMetadataError(
                    txn, post, self.METADATA_KEY, source_value,
                ))
        if not errors:
            self._meta_set(post, self.METADATA_KEY, set_value)
        return errors