Files @ a41feb94b3e0
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/plugin/core.py

Brett Smith
plugin: Transform posting hooks into transaction hooks.

I feel like posting hooks were a case of premature optimization in early
development. This approach reduces the number of special cases in
the code and allows us to reason more strongly about hooks in the
type system.
"""Base classes for plugin checks"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import abc
import datetime
import re

from . import errors as errormod

from typing import (
    FrozenSet,
    Generic,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    TypeVar,
)
from .._typing import (
    Account,
    Directive,
    Error,
    ErrorIter,
    LessComparable,
    MetaKey,
    MetaValue,
    MetaValueEnum,
    Posting,
    Transaction,
    Type,
)

### CONSTANTS

# These live at module level rather than inside a class because I expect
# them to become configurable in the future. For now constants will do.
DEFAULT_START_DATE: datetime.date = datetime.date(year=2020, month=3, day=1)
# Stopping on the first day of datetime.MAXYEAR leaves a little room
# afterward, so it's easy to test dates past the far end of the range.
DEFAULT_STOP_DATE: datetime.date = datetime.date(
    year=datetime.MAXYEAR,
    month=1,
    day=1,
)

### TYPE DEFINITIONS

# Alias used to annotate hook group names (see Hook.HOOK_GROUPS).
HookName = str

# Type variable for the kind of directive a hook is written against.
Entry = TypeVar('Entry', bound=Directive)
class Hook(Generic[Entry], metaclass=abc.ABCMeta):
    """Abstract base class for plugin checks that run on one directive type.

    Subclass a parameterized form, such as `Hook[Transaction]`, and implement
    `run`. The directive type is read from that parameterization and stored
    on the subclass as `DIRECTIVE`.
    """
    # The directive type this hook handles. __init_subclass__ fills this in
    # for each subclass.
    DIRECTIVE: Type[Directive]
    # Names of the groups this hook belongs to. Empty unless a subclass
    # says otherwise.
    HOOK_GROUPS: FrozenSet[HookName] = frozenset()

    @abc.abstractmethod
    def run(self, entry: Entry) -> ErrorIter:
        """Check `entry` and report any problems through the returned ErrorIter."""
        ...

    def __init_subclass__(cls) -> None:
        """Record the directive type the new subclass was parameterized with."""
        # Chain up so typing.Generic gets to do its own per-subclass setup
        # (such as recomputing __parameters__) and so that this class
        # cooperates with other bases that define __init_subclass__.
        super().__init_subclass__()
        # __orig_bases__ holds the bases as they were written in the class
        # statement, e.g. `Hook[Transaction]`. This assumes the first listed
        # base is the parameterized Hook; its first type argument is the
        # directive type.
        cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]


# Hook parameterized for Transaction entries.
TransactionHook = Hook[Transaction]

### HELPER CLASSES

CT = TypeVar('CT', bound=LessComparable)
class _GenericRange(Generic[CT]):
    """Half-open range that answers membership tests with `in`.

    `value in some_range` is true exactly when `start <= value < stop`.
    We have several user-configurable ranges, so wrapping the comparison in
    an object keeps every check consistent and gives subclasses a single
    place to override it.
    """

    def __init__(self, start: CT, stop: CT) -> None:
        """Store the inclusive lower bound and the exclusive upper bound."""
        self.start, self.stop = start, stop

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.start!r}, {self.stop!r})"

    def __contains__(self, item: CT) -> bool:
        """Return true if `item` is at least `start` and less than `stop`."""
        return self.start <= item and item < self.stop


class MetadataEnum:
    """Map acceptable metadata values to their normalized forms.

    When a piece of metadata uses a set of allowed values, use this class to
    define them. You can also specify aliases that hooks will normalize to
    the primary values.
    """

    def __init__(self,
                 key: MetaKey,
                 standard_values: Iterable[MetaValueEnum],
                 aliases_map: Mapping[MetaValueEnum, MetaValueEnum],
    ) -> None:
        """Specify allowed values and aliases for this metadata.

        Arguments:

        * key: The name of the metadata key that uses this enum.
        * standard_values: An iterable of strings that enumerate the standard
          values for this metadata. It is consumed exactly once, so one-shot
          iterators and generators are fine.
        * aliases_map: A mapping of strings to strings. The keys are
          additional allowed metadata values. The values are standard values
          that each key will evaluate to. The code asserts that all values are
          in standard_values.
        """
        self.key = key
        self._stdvalues = frozenset(standard_values)
        self._aliases = dict(aliases_map)
        # Every standard value is an alias for itself. Build this from the
        # stored frozenset, not from standard_values: that argument may be an
        # iterator that the frozenset() call above already exhausted.
        self._aliases.update((value, value) for value in self._stdvalues)
        assert self._stdvalues == set(self._aliases.values()), \
            "every value in aliases_map must be one of standard_values"

    def __repr__(self) -> str:
        return "{}<{}>".format(type(self).__name__, self.key)

    def __contains__(self, key: MetaValueEnum) -> bool:
        """Returns true if `key` is a standard value or alias."""
        return key in self._aliases

    def __getitem__(self, key: MetaValueEnum) -> MetaValueEnum:
        """Return the standard value for `key`.

        Raises KeyError if `key` is not a known value or alias.
        """
        return self._aliases[key]

    def __iter__(self) -> Iterator[MetaValueEnum]:
        """Iterate over standard values."""
        return iter(self._stdvalues)

    def get(self,
            key: MetaValueEnum,
            default_key: Optional[MetaValueEnum]=None,
    ) -> Optional[MetaValueEnum]:
        """Return self[key], or a default fallback if that doesn't exist.

        default_key is another key to look up, *not* a default value to return.
        This helps ensure you always get a standard value. If default_key is
        None, the fallback result is None. If default_key is given but is not
        itself a known value or alias, the lookup raises KeyError.
        """
        try:
            return self[key]
        except KeyError:
            if default_key is None:
                return None
            return self[default_key]


### HOOK SUBCLASSES

class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
    """Base class for transaction hooks that check one posting at a time.

    `run` walks the postings of a transaction and hands each one to
    `post_run`, which subclasses must implement. Subclasses can narrow what
    gets checked by overriding `TXN_DATE_RANGE`, `_run_on_txn`, or
    `_run_on_post`.
    """
    # By default, only transactions dated inside this range are checked.
    TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)

    def __init_subclass__(cls) -> None:
        # Tag every subclass as a member of the 'posting' hook group.
        cls.HOOK_GROUPS = cls.HOOK_GROUPS | {'posting'}

    def _meta_get(self,
                  txn: Transaction,
                  post: Posting,
                  key: MetaKey,
                  default: MetaValue=None,
    ) -> MetaValue:
        """Look up `key` on the posting, falling back to the transaction.

        Returns `default` when neither the posting's metadata nor the
        transaction's metadata has the key.
        """
        post_meta = post.meta
        if post_meta and key in post_meta:
            return post_meta[key]
        return txn.meta.get(key, default)

    def _meta_set(self,
                  txn: Transaction,
                  post: Posting,
                  post_index: int,
                  key: MetaKey,
                  value: MetaValue,
    ) -> None:
        """Set `key` to `value` in the posting's metadata.

        A posting that already has a metadata dict is updated in place.
        A posting whose metadata is None is replaced at
        `txn.postings[post_index]` with a new Posting built from its first
        five fields plus a fresh one-entry metadata dict.
        """
        if post.meta is not None:
            post.meta[key] = value
            return
        replacement = Posting(*post[:5], {key: value})
        txn.postings[post_index] = replacement

    def _run_on_txn(self, txn: Transaction) -> bool:
        """Return true if this hook should check `txn` at all."""
        return txn.date in self.TXN_DATE_RANGE

    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
        """Return true if this hook should check `post`. Always true here."""
        return True

    def run(self, txn: Transaction) -> ErrorIter:
        """Yield the errors `post_run` reports for each selected posting."""
        if not self._run_on_txn(txn):
            return
        for position, posting in enumerate(txn.postings):
            if self._run_on_post(txn, posting):
                yield from self.post_run(txn, posting, position)

    @abc.abstractmethod
    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
        """Check a single posting; `post_index` is its position in `txn.postings`."""
        ...


class _NormalizePostingMetadataHook(_PostingHook):
    """Base class to normalize posting metadata from an enum."""
    # This class provides basic functionality to filter postings, normalize
    # metadata values, and set default values.
    # Subclasses define VALUES_ENUM; METADATA_KEY is copied from its key.
    METADATA_KEY: MetaKey
    VALUES_ENUM: MetadataEnum

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        cls.METADATA_KEY = cls.VALUES_ENUM.key
        # Besides the parent's groups, the hook joins the generic 'metadata'
        # group and a group named after the key it normalizes.
        cls.HOOK_GROUPS = cls.HOOK_GROUPS | {'metadata', cls.METADATA_KEY}

    # When a posting has no value for METADATA_KEY, post_run asks
    # _default_value for one. Implementations should either return a value
    # string from VALUES_ENUM, or else raise InvalidMetadataError.
    # This base implementation does the latter.
    def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
        raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)

    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
        """Normalize METADATA_KEY on one posting, or yield an error.

        The current value is looked up on the posting, then the transaction.
        A found value is translated through VALUES_ENUM; a missing value is
        supplied by `_default_value`. On success the result is written to the
        posting's metadata. Otherwise the error is yielded and nothing is
        written.
        """
        meta_key = self.METADATA_KEY
        found = self._meta_get(txn, post, meta_key)
        normalized = found
        problem: Optional[Error] = None
        if found is not None:
            try:
                normalized = self.VALUES_ENUM[found]
            except KeyError:
                # The value is neither a standard value nor a known alias.
                problem = errormod.InvalidMetadataError(
                    txn, post, meta_key, found,
                )
        else:
            try:
                normalized = self._default_value(txn, post)
            except errormod._BaseError as default_error:
                problem = default_error
        if problem is not None:
            yield problem
            return
        self._meta_set(txn, post, post_index, meta_key, normalized)