Files @ 54a1bc460028
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/reports/core.py

Brett Smith
filters: Add filter_for_rt_id function.
"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import collections

from decimal import Decimal

from .. import data

from typing import (
    overload,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

class Balance(Mapping[str, data.Amount]):
    """Mapping from currency to the balance held in that currency

    Keys are Beancount currency strings. Looking up a key returns an Amount
    built from the number stored for that currency and the key itself.

    The constructor accepts either a mapping of currency strings to Amounts,
    or an iterable of (currency, Amount) pairs. Only the number of each
    Amount is stored, under the key it was supplied with. If the same
    currency appears more than once, the last number wins.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # Normalize both accepted input shapes to (currency, Amount) pairs.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = dict(
            (currency, amount.number) for currency, amount in pairs
        )

    def __len__(self) -> int:
        """Return the number of currencies in this balance"""
        return len(self._currency_map)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the currency strings in this balance"""
        return iter(self._currency_map)

    def __getitem__(self, key: str) -> data.Amount:
        """Return the balance for the given currency as an Amount

        Raises KeyError if the currency is not in this balance.
        """
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._currency_map!r})"

    def is_zero(self) -> bool:
        """Return True if every currency's balance equals zero

        A balance with no currencies at all is also considered zero.
        """
        for number in self._currency_map.values():
            if number == 0:
                continue
            return False
        return True


class MutableBalance(Balance):
    """A Balance that can accumulate amounts in place"""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add the given amount to the balance of its currency

        If the amount's currency is not in this balance yet, it is added
        with the amount's number as its starting balance.
        """
        tallies = self._currency_map
        currency = amount.currency
        if currency in tallies:
            tallies[currency] += amount.number
        else:
            tallies[currency] = amount.number


class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key::

        report = collections.defaultdict(RelatedPostings)
        for txn in transactions:
            for post in Posting.from_txn(txn):
                if should_report(post):
                    key = post_key(post)
                    report[key].add(post)
    """

    def __init__(self) -> None:
        # Postings are kept in the order they were added.
        self._postings: List[data.Posting] = []

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...

    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return the posting at the given integer index

        Slices are declared for the Sequence interface but are not supported:
        passing one raises NotImplementedError. An out-of-range integer
        raises IndexError.
        """
        if isinstance(index, slice):
            raise NotImplementedError("RelatedPostings[slice]")
        else:
            return self._postings[index]

    def __len__(self) -> int:
        """Return the number of postings collected so far"""
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append a posting to the end of this collection"""
        self._postings.append(post)

    def iter_with_balance(self) -> Iterable[Tuple[data.Posting, Balance]]:
        """Iterate over postings along with the running balance

        Yields (posting, balance) pairs in the order the postings were
        added. Each balance is the total of the units of every posting up to
        and including the one it is paired with.

        Every yielded balance is an independent snapshot, so it is safe to
        keep them around (e.g., by building a list from this iterator): later
        postings will not change balances that were already yielded.
        """
        running = MutableBalance()
        for post in self:
            running.add_amount(post.units)
            # Copy the running total rather than yielding it directly.
            # Otherwise every pair would share one object that keeps changing
            # as iteration proceeds.
            yield post, Balance(running)

    def balance(self) -> Balance:
        """Return the total balance of all postings in this collection

        The result is always a plain Balance, whether or not there are any
        postings. With no postings, the balance has no currencies.
        """
        total = MutableBalance()
        for post in self:
            total.add_amount(post.units)
        return Balance(total)