Changeset - bd0d607032eb
[Not reviewed]
0 4 0
Brett Smith - 4 years ago 2020-04-28 20:35:15
brettcsmith@brettcsmith.org
typing: Annotate Iterators more specifically.
4 files changed with 6 insertions and 6 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/data.py
Show inline comments
...
 
@@ -264,71 +264,71 @@ class Posting(BasePosting):
 
    * The `units` field is our Amount object (which simply declares that the
 
      number is always a Decimal—see that docstring for details)
 
    * The `meta` field is a PostingMeta object
 
    """
 
    __slots__ = ()
 

	
 
    account: Account
 
    units: Amount
 
    # mypy correctly complains that our MutableMapping is not compatible
 
    # with Beancount's meta type declaration of Optional[Dict]. IMO
 
    # Beancount's type declaration is a smidge too specific: I think its type
 
    # declaration should also use MutableMapping, because it would be very
 
    # unusual for code to specifically require a Dict over that.
 
    # If it did, this declaration would pass without issue.
 
    meta: PostingMeta  # type:ignore[assignment]
 

	
 
    @classmethod
 
    def from_beancount(cls,
 
                       txn: Transaction,
 
                       index: int,
 
                       post: Optional[BasePosting]=None,
 
    ) -> 'Posting':
 
        if post is None:
 
            post = txn.postings[index]
 
        return cls(
 
            Account(post.account),
 
            *post[1:5],
 
            # see rationale above about Posting.meta
 
            PostingMeta(txn, index, post), # type:ignore[arg-type]
 
        )
 

	
 
    @classmethod
 
    def from_txn(cls, txn: Transaction) -> Iterable['Posting']:
 
    def from_txn(cls, txn: Transaction) -> Iterator['Posting']:
 
        """Yield an enhanced Posting object for every posting in the transaction"""
 
        for index, post in enumerate(txn.postings):
 
            yield cls.from_beancount(txn, index, post)
 

	
 
    @classmethod
 
    def from_entries(cls, entries: Iterable[Directive]) -> Iterable['Posting']:
 
    def from_entries(cls, entries: Iterable[Directive]) -> Iterator['Posting']:
 
        """Yield an enhanced Posting object for every posting in these entries"""
 
        for entry in entries:
 
            # Because Beancount's own Transaction class isn't type-checkable,
 
            # we can't statically check this. Might as well rely on duck
 
            # typing while we're at it: just try to yield postings from
 
            # everything, and ignore entries that lack a postings attribute.
 
            try:
 
                yield from cls.from_txn(entry)  # type:ignore[arg-type]
 
            except AttributeError:
 
                pass
 

	
 

	
 
_KT = TypeVar('_KT', bound=Hashable)
 
_VT = TypeVar('_VT')
 
class _SizedDict(collections.OrderedDict, MutableMapping[_KT, _VT]):
 
    def __init__(self, maxsize: int=128) -> None:
 
        self.maxsize = maxsize
 
        super().__init__()
 

	
 
    def __setitem__(self, key: _KT, value: _VT) -> None:
 
        super().__setitem__(key, value)
 
        for _ in range(self.maxsize, len(self)):
 
            self.popitem(last=False)
 

	
 

	
 
def balance_of(txn: Transaction,
 
               *preds: Callable[[Account], Optional[bool]],
 
) -> Amount:
 
    """Return the balance of specified postings in a transaction.
 

	
 
    Given a transaction and a series of account predicates, balance_of
 
    returns the balance of the amounts of all postings with accounts that
conservancy_beancount/plugin/__init__.py
Show inline comments
 
"""Beancount plugin entry point for Conservancy"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import importlib
 

	
 
import beancount.core.data as bc_data
 

	
 
from typing import (
 
    AbstractSet,
 
    Any,
 
    Dict,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Optional,
 
    Set,
 
    Tuple,
 
    Type,
 
)
 
from ..beancount_types import (
 
    ALL_DIRECTIVES,
 
    Directive,
 
    Entries,
 
    Errors,
 
    OptionsMap,
 
)
 
from .. import config as configmod
 
from .core import (
 
    Hook,
 
    HookName,
 
)
 
from ..errors import (
 
    Error,
 
)
 

	
 
__plugins__ = ['run']
 

	
 
class HookRegistry:
 
    INCLUDED_HOOKS: Dict[str, Optional[List[str]]] = {
 
        '.meta_approval': None,
 
        '.meta_entity': None,
 
        '.meta_expense_allocation': None,
 
        '.meta_income_type': None,
 
        '.meta_invoice': None,
 
        '.meta_payable_documentation': None,
...
 
@@ -64,65 +64,65 @@ class HookRegistry:
 
        '.meta_tax_implication': None,
 
    }
 

	
 
    def __init__(self) -> None:
 
        self.group_name_map: Dict[HookName, Set[Type[Hook]]] = {
 
            t.__name__: set() for t in ALL_DIRECTIVES
 
        }
 
        self.group_name_map['all'] = set()
 

	
 
    def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]:
 
        self.group_name_map['all'].add(hook_cls)
 
        self.group_name_map[hook_cls.DIRECTIVE.__name__].add(hook_cls)
 
        for key in hook_cls.HOOK_GROUPS:
 
            self.group_name_map.setdefault(key, set()).add(hook_cls)
 
        return hook_cls  # to allow use as a decorator
 

	
 
    def import_hooks(self,
 
                     mod_name: str,
 
                     *hook_names: str,
 
                     package: Optional[str]=None,
 
    ) -> None:
 
        if not hook_names:
 
            _, _, hook_name = mod_name.rpartition('.')
 
            hook_names = (hook_name.title().replace('_', ''),)
 
        module = importlib.import_module(mod_name, package)
 
        for hook_name in hook_names:
 
            self.add_hook(getattr(module, hook_name))
 

	
 
    def load_included_hooks(self) -> None:
 
        for mod_name, hook_names in self.INCLUDED_HOOKS.items():
 
            self.import_hooks(mod_name, *(hook_names or []), package=self.__module__)
 

	
 
    def group_by_directive(self, config_str: str='') -> Iterable[Tuple[HookName, Type[Hook]]]:
 
    def group_by_directive(self, config_str: str='') -> Iterator[Tuple[HookName, Type[Hook]]]:
 
        config_str = config_str.strip()
 
        if not config_str:
 
            config_str = 'all'
 
        elif config_str.startswith('-'):
 
            config_str = 'all ' + config_str
 
        available_hooks: Set[Type[Hook]] = set()
 
        for token in config_str.split():
 
            if token.startswith('-'):
 
                update_available = available_hooks.difference_update
 
                key = token[1:]
 
            else:
 
                update_available = available_hooks.update
 
                key = token
 
            try:
 
                update_set = self.group_name_map[key]
 
            except KeyError:
 
                raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
 
            else:
 
                update_available(update_set)
 
        for directive in ALL_DIRECTIVES:
 
            key = directive.__name__
 
            for hook in self.group_name_map[key] & available_hooks:
 
                yield key, hook
 

	
 

	
 
def run(
 
        entries: Entries,
 
        options_map: OptionsMap,
 
        config: str='',
 
        hook_registry: Optional[HookRegistry]=None,
 
) -> Tuple[Entries, Errors]:
 
    if hook_registry is None:
conservancy_beancount/plugin/core.py
Show inline comments
...
 
@@ -230,46 +230,46 @@ class _NormalizePostingMetadataHook(_PostingHook):
 
            try:
 
                set_value = self._default_value(txn, post)
 
            except errormod.Error as error_:
 
                error = error_
 
        else:
 
            try:
 
                set_value = self.VALUES_ENUM[source_value]
 
            except KeyError:
 
                error = errormod.InvalidMetadataError(
 
                    txn, self.METADATA_KEY, source_value, post,
 
                )
 
        if error is None:
 
            post.meta[self.METADATA_KEY] = set_value
 
        else:
 
            yield error
 

	
 

	
 
class _RequireLinksPostingMetadataHook(_PostingHook):
 
    """Base class to require that posting metadata include links"""
 
    # This base class confirms that a posting's metadata has one or more links
 
    # under one of the metadata keys listed in CHECKED_METADATA.
 
    # Most subclasses only need to define CHECKED_METADATA and _run_on_post.
 
    CHECKED_METADATA: Sequence[MetaKey]
 

	
 
    def __init_subclass__(cls) -> None:
 
        super().__init_subclass__()
 
        cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(cls.CHECKED_METADATA).union('metadata')
 

	
 
    def _check_metadata(self,
 
                        txn: Transaction,
 
                        post: data.Posting,
 
                        keys: Sequence[MetaKey],
 
    ) -> Iterable[errormod.InvalidMetadataError]:
 
    ) -> Iterator[errormod.InvalidMetadataError]:
 
        have_docs = False
 
        for key in keys:
 
            try:
 
                links = post.meta.get_links(key)
 
            except TypeError as error:
 
                yield errormod.InvalidMetadataError(txn, key, post.meta[key], post)
 
            else:
 
                have_docs = have_docs or any(links)
 
        if not have_docs:
 
            yield errormod.InvalidMetadataError(txn, '/'.join(keys), None, post)
 

	
 
    def post_run(self, txn: Transaction, post: data.Posting) -> errormod.Iter:
 
        return self._check_metadata(txn, post, self.CHECKED_METADATA)
conservancy_beancount/reports/core.py
Show inline comments
...
 
@@ -114,51 +114,51 @@ class RelatedPostings(Sequence[data.Posting]):
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same metadata value.
 
        """
 
        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
 
        for post in postings:
 
            retval[post.meta.get(key, default)].add(post)
 
        retval.default_factory = None
 
        return retval
 

	
 
    @overload
 
    def __getitem__(self, index: int) -> data.Posting: ...
 

	
 
    @overload
 
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...
 

	
 
    def __getitem__(self,
 
                    index: Union[int, slice],
 
    ) -> Union[data.Posting, Sequence[data.Posting]]:
 
        if isinstance(index, slice):
 
            raise NotImplementedError("RelatedPostings[slice]")
 
        else:
 
            return self._postings[index]
 

	
 
    def __len__(self) -> int:
 
        return len(self._postings)
 

	
 
    def add(self, post: data.Posting) -> None:
 
        self._postings.append(post)
 

	
 
    def clear(self) -> None:
 
        self._postings.clear()
 

	
 
    def iter_with_balance(self) -> Iterable[Tuple[data.Posting, Balance]]:
 
    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
 
        balance = MutableBalance()
 
        for post in self:
 
            balance.add_amount(post.units)
 
            yield post, balance
 

	
 
    def balance(self) -> Balance:
 
        for _, balance in self.iter_with_balance():
 
            pass
 
        try:
 
            return balance
 
        except NameError:
 
            return Balance()
 

	
 
    def meta_values(self,
 
                    key: MetaKey,
 
                    default: Optional[MetaValue]=None,
 
    ) -> Set[Optional[MetaValue]]:
 
        return {post.meta.get(key, default) for post in self}
0 comments (0 inline, 0 general)