Files @ 999ca2c5e1fa
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/reports/core.py

Brett Smith
rtutil: Add RT.txn_with_urls() method.
"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import collections

from decimal import Decimal

from .. import data

from typing import (
    overload,
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)
from ..beancount_types import (
    MetaKey,
    MetaValue,
)

class Balance(Mapping[str, data.Amount]):
    """A collection of amounts mapped by currency

    Each key is a Beancount currency string, and each value represents the
    balance in that currency.

    Internally only the numeric part of each amount is kept; a ``data.Amount``
    is rebuilt from that number and its currency key on every lookup.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # Accept either a mapping or an iterable of (currency, amount) pairs.
        pairs = source.items() if isinstance(source, Mapping) else source
        currency_map = {}
        for currency, amount in pairs:
            # A later pair for the same currency replaces an earlier one.
            currency_map[currency] = amount.number
        self._currency_map = currency_map

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return "{}({!r})".format(cls_name, self._currency_map)

    def __getitem__(self, key: str) -> data.Amount:
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map.keys())

    def __len__(self) -> int:
        return len(self._currency_map)

    def is_zero(self) -> bool:
        """Return True when every currency's balance is zero (or there are none)."""
        return not any(number != 0 for number in self._currency_map.values())


class MutableBalance(Balance):
    """A Balance whose per-currency totals can be updated in place."""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount`` to the running total for its currency.

        The first amount seen in a currency is stored as-is; later amounts
        in that currency are added to the stored number.
        """
        currency = amount.currency
        if currency in self._currency_map:
            self._currency_map[currency] += amount.number
        else:
            self._currency_map[currency] = amount.number


class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key. See the group_by_meta classmethod
    for an example.

    Instances support integer indexing and slicing; a slice returns a new
    plain list of postings.
    """

    def __init__(self, source: Iterable[data.Posting]=()) -> None:
        # Copy into a private list so later changes to the caller's iterable
        # don't affect this collection.
        self._postings: List[data.Posting] = list(source)

    @classmethod
    def group_by_meta(cls,
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Mapping[Optional[MetaValue], 'RelatedPostings']:
        """Relate postings by metadata value

        This method takes an iterable of postings and returns a mapping.
        The keys of the mapping are the values of post.meta.get(key, default).
        The values are RelatedPostings instances that contain all the postings
        that had that same metadata value.
        """
        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
        for post in postings:
            retval[post.meta.get(key, default)].add(post)
        # Turn off auto-creation so callers looking up a value with no
        # postings get KeyError instead of silently inserting an empty group.
        retval.default_factory = None
        return retval

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...

    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return the posting at ``index``, or a list of postings for a slice."""
        # The backing list handles both int and slice indexes, including
        # negative indexes and raising IndexError when out of range.
        return self._postings[index]

    def __iter__(self) -> Iterator[data.Posting]:
        # Iterate the backing list directly instead of relying on the
        # Sequence mixin's index-probing fallback.
        return iter(self._postings)

    def __len__(self) -> int:
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append ``post`` to this collection."""
        self._postings.append(post)

    def clear(self) -> None:
        """Remove all postings from this collection."""
        self._postings.clear()

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Yield each posting with the running balance through that posting.

        The running balance accumulates ``post.units`` in posting order.
        NOTE: the SAME balance object is yielded every time and updated in
        place. To keep a snapshot of an intermediate balance, copy it,
        e.g. ``Balance(balance)``.
        """
        balance = MutableBalance()
        for post in self:
            balance.add_amount(post.units)
            yield post, balance

    def balance(self) -> Balance:
        """Return the total balance of all postings.

        Returns an empty Balance when there are no postings.
        """
        # Start from an empty Balance so an empty collection needs no special
        # handling; otherwise this ends bound to the final running balance.
        last_balance: Balance = Balance()
        for _, last_balance in self.iter_with_balance():
            pass
        return last_balance

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the distinct values of metadata ``key`` across all postings.

        Postings without ``key`` contribute ``default`` to the result.
        """
        return {post.meta.get(key, default) for post in self}