Files @ 22d5b7e90a1e
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/plugin/__init__.py

Brett Smith
setup: Disallow untyped calls.
"""Beancount plugin entry point for Conservancy"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import importlib

import beancount.core.data as bc_data

from typing import (
    AbstractSet,
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Type,
)
from ..beancount_types import (
    ALL_DIRECTIVES,
    Directive,
)
from .. import config as configmod
from .core import (
    Hook,
    HookName,
)
from ..errors import (
    Error,
)

__plugins__ = ['run']

class HookRegistry:
    """Track hook classes by the group names that can select them.

    Every registered hook class is reachable under ``'all'``, under the
    ``__name__`` of the directive type it declares in ``DIRECTIVE``, and
    under each name listed in its ``HOOK_GROUPS``.  A configuration string
    combines those names to choose which hooks run (see
    ``group_by_directive``).
    """

    def __init__(self) -> None:
        # Pre-seed one empty group per directive type so every directive name
        # can be looked up later, even if no hook handles that directive.
        self.group_name_map: Dict[HookName, Set[Type[Hook]]] = {
            t.__name__: set() for t in ALL_DIRECTIVES
        }
        self.group_name_map['all'] = set()

    def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]:
        """Register ``hook_cls`` under 'all', its directive name, and its groups.

        Returns ``hook_cls`` unchanged so this method can be used as a class
        decorator.  Looking up the directive's group uses plain indexing, so
        a ``DIRECTIVE`` whose name is not already a known group (normally:
        not one of ``ALL_DIRECTIVES``) raises KeyError.
        """
        self.group_name_map['all'].add(hook_cls)
        self.group_name_map[hook_cls.DIRECTIVE.__name__].add(hook_cls)
        for key in hook_cls.HOOK_GROUPS:
            # Hook-defined groups are created on first use.
            self.group_name_map.setdefault(key, set()).add(hook_cls)
        return hook_cls  # to allow use as a decorator

    def import_hooks(self,
                     mod_name: str,
                     *hook_names: str,
                     package: Optional[str]=__module__,  # type:ignore[name-defined]
    ) -> None:
        """Import ``mod_name`` and register each named attribute as a hook.

        ``mod_name`` may be relative; it is resolved against ``package``,
        which defaults to the module this class is defined in.
        """
        module = importlib.import_module(mod_name, package)
        for hook_name in hook_names:
            self.add_hook(getattr(module, hook_name))

    def group_by_directive(self, config_str: str='') -> Iterable[Tuple[HookName, Type[Hook]]]:
        """Yield ``(directive_name, hook_class)`` for every hook selected by config.

        ``config_str`` is a whitespace-separated list of group names, applied
        left to right: a plain name adds that group's hooks, and a name
        prefixed with ``-`` removes them.  An empty string selects ``all``,
        and a string that starts with a removal implicitly starts from
        ``all``.

        Pairs are yielded grouped by directive, in ``ALL_DIRECTIVES`` order.
        Within a directive, hooks are ordered by module and qualified name so
        the result is reproducible from run to run.

        Raises ValueError, when iterated, if a token names an unknown group.
        """
        config_str = config_str.strip()
        if not config_str:
            config_str = 'all'
        elif config_str.startswith('-'):
            config_str = 'all ' + config_str
        available_hooks: Set[Type[Hook]] = set()
        for token in config_str.split():
            if token.startswith('-'):
                update_available = available_hooks.difference_update
                key = token[1:]
            else:
                update_available = available_hooks.update
                key = token
            try:
                update_set = self.group_name_map[key]
            except KeyError:
                raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
            else:
                update_available(update_set)
        for directive in ALL_DIRECTIVES:
            directive_name = directive.__name__
            selected = self.group_name_map[directive_name] & available_hooks
            # Classes hash by identity by default, so raw set iteration order
            # can differ between runs.  Sort so hooks run (and report errors)
            # in a stable order.
            for hook in sorted(
                    selected,
                    key=lambda hook_cls: (hook_cls.__module__, hook_cls.__qualname__),
            ):
                yield directive_name, hook


HOOK_REGISTRY = HookRegistry()
# Built-in hooks, registered in this order.  Each module path is relative and
# is resolved against HookRegistry.import_hooks' default ``package``.
for _hook_spec in (
        ('.meta_expense_allocation', 'MetaExpenseAllocation'),
        ('.meta_tax_implication', 'MetaTaxImplication'),
):
    HOOK_REGISTRY.import_hooks(*_hook_spec)
# Don't leave the loop variable behind as a module attribute.
del _hook_spec

def run(
        entries: List[Directive],
        options_map: Dict[str, Any],
        config: str='',
        hook_registry: HookRegistry=HOOK_REGISTRY,
) -> Tuple[List[Directive], List[Error]]:
    """Beancount plugin entry point: run the configured hooks over ``entries``.

    ``config`` is the plugin's configuration string, interpreted by
    ``hook_registry.group_by_directive`` to select which hooks run.
    ``options_map`` is accepted as part of the plugin call signature but is
    not used here.

    Each selected hook class is instantiated once with a fresh
    ``configmod.Config()``.  A hook whose constructor raises ``Error`` is
    skipped, and that error is reported.  Each entry is then passed to every
    hook instance registered for its directive type, and whatever errors the
    hooks return are collected.

    Returns ``(entries, errors)``; ``entries`` is the same list object that
    was passed in.
    """
    errors: List[Error] = []
    hooks: Dict[HookName, List[Hook]] = {}
    user_config = configmod.Config()
    for key, hook_type in hook_registry.group_by_directive(config):
        try:
            hook = hook_type(user_config)
        except Error as error:
            errors.append(error)
        else:
            hooks.setdefault(key, []).append(hook)
    for entry in entries:
        entry_type = type(entry).__name__
        # A directive type may have no hook instances at all: none are
        # registered for it, config excluded them, or they all failed to
        # initialize.  Such entries simply have nothing to run.
        for hook in hooks.get(entry_type, ()):
            errors.extend(hook.run(entry))
    return entries, errors