Files @ 58b02b6f33c2
Branch filter:

Location: NPO-Accounting/conservancy_beancount/conservancy_beancount/reports/core.py

Brett Smith
accrual: Move more functionality into AccrualPostings.
"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import collections
import operator

from decimal import Decimal

import babel.numbers  # type:ignore[import]

from .. import data

from typing import (
    overload,
    Any,
    Callable,
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from ..beancount_types import (
    MetaKey,
    MetaValue,
)

# Local alias for the data module's DecimalCompat type, used in the
# annotations of Balance's comparison helpers.
DecimalCompat = data.DecimalCompat
# Type variable bound to RelatedPostings, so that the group_by_meta
# classmethod and slicing can be annotated as returning the same (sub)class
# they were invoked on.
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')

class Balance(Mapping[str, data.Amount]):
    """Read-only mapping from currency to the amount held in that currency

    Keys are Beancount currency strings. Looking up a key returns a
    data.Amount built from the number stored for that currency.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # Accept either a mapping or an iterable of (currency, amount) pairs.
        # Only each amount's number is stored; the currency is the key.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = {code: amt.number for code, amt in pairs}

    def __repr__(self) -> str:
        return '{}({!r})'.format(type(self).__name__, self._currency_map)

    def __str__(self) -> str:
        return self.format()

    def __eq__(self, other: Any) -> bool:
        # Two zero balances are equal no matter which currencies they list.
        if self.is_zero() and isinstance(other, Balance) and other.is_zero():
            return True
        # Everything else falls back to ordinary Mapping equality.
        return super().__eq__(other)

    def __neg__(self) -> 'Balance':
        negated = ((currency, -amount) for currency, amount in self.items())
        return type(self)(negated)

    def __getitem__(self, key: str) -> data.Amount:
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map)

    def __len__(self) -> int:
        return len(self._currency_map)

    def _all_amounts(self,
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
                     operand: DecimalCompat,
    ) -> bool:
        """Return True if op_func(number, operand) holds for every currency

        An empty balance trivially satisfies any comparison.
        """
        for number in self._currency_map.values():
            if not op_func(number, operand):
                return False
        return True

    def eq_zero(self) -> bool:
        """Return True if every amount in the balance == 0."""
        return self._all_amounts(operator.eq, 0)

    is_zero = eq_zero

    def ge_zero(self) -> bool:
        """Return True if every amount in the balance >= 0."""
        return self._all_amounts(operator.ge, 0)

    def le_zero(self) -> bool:
        """Return True if every amount in the balance <= 0."""
        return self._all_amounts(operator.le, 0)

    def format(self,
               fmt: Optional[str]='#,#00.00 ¤¤',
               sep: str=', ',
               empty: str="Zero balance",
    ) -> str:
        """Render the balance as a single string

        Amounts whose number is zero are left out. If nothing remains,
        ``empty`` is returned. Otherwise each remaining amount is formatted
        with ``fmt``, ordered from largest to smallest absolute value, and
        the pieces are joined with ``sep``.

        Pass ``fmt=None`` to format amounts according to the user's locale.
        The default format is Beancount's input format.
        """
        nonzero = sorted(
            (amount for amount in self.values() if amount.number),
            key=lambda amount: abs(amount.number),
            reverse=True,
        )
        if not nonzero:
            return empty
        rendered = (
            babel.numbers.format_currency(amount.number, amount.currency, fmt)
            for amount in nonzero
        )
        return sep.join(rendered)


class MutableBalance(Balance):
    """A Balance whose per-currency totals can be increased in place"""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount`` to the running total for its currency

        A currency that has no entry yet starts out at ``amount.number``.
        """
        totals = self._currency_map
        currency = amount.currency
        if currency in totals:
            totals[currency] += amount.number
        else:
            totals[currency] = amount.number


class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to group postings by some key and build one instance
    per group. See the group_by_meta classmethod for an example.
    """
    __slots__ = ('_postings',)

    def __init__(self,
                 source: Iterable[data.Posting]=(),
                 *,
                 _can_own: bool=False,
    ) -> None:
        """Build the collection from ``source``

        ``_can_own`` is a private optimization flag: when it is true and
        ``source`` is already a list, that list is stored directly instead
        of being copied. Only pass it for a list nothing else will mutate.
        """
        self._postings: List[data.Posting]
        if _can_own and isinstance(source, list):
            self._postings = source
        else:
            self._postings = list(source)

    @classmethod
    def group_by_meta(cls: Type[RelatedType],
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
        """Relate postings by metadata value

        This method takes an iterable of postings and yields 2-tuples of
        ``(value, related)``. Each ``value`` is a distinct result of
        post.meta.get(key, default), and ``related`` is an instance of this
        class that contains all the postings that had that same metadata
        value. Pass the result to ``dict()`` if you need a mapping.

        All of ``postings`` is consumed before the first pair is yielded.
        """
        mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list)
        for post in postings:
            mapping[post.meta.get(key, default)].append(post)
        for value, posts in mapping.items():
            # Each list was built here and is not shared, so the new
            # instance can take ownership of it without copying.
            yield value, cls(posts, _can_own=True)

    def __repr__(self) -> str:
        return f'<{type(self).__name__} {self._postings!r}>'

    @overload
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...

    def __getitem__(self: RelatedType,
                    index: Union[int, slice],
    ) -> Union[data.Posting, RelatedType]:
        """Return one posting for an int index, or a new instance for a slice"""
        if isinstance(index, slice):
            # Slicing a list produces a fresh list, so it is safe to own.
            return type(self)(self._postings[index], _can_own=True)
        else:
            return self._postings[index]

    def __len__(self) -> int:
        return len(self._postings)

    def all_meta_links(self, key: MetaKey) -> Set[str]:
        """Return the union of post.meta.get_links(key) over all postings

        Postings for which get_links raises TypeError are skipped.
        """
        retval: Set[str] = set()
        for post in self:
            try:
                retval.update(post.meta.get_links(key))
            except TypeError:
                # Deliberate best-effort: presumably the metadata value is
                # not usable as links (verify against get_links).
                pass
        return retval

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Yield each posting with the running balance of units so far

        The balance includes the posting it is yielded with.

        NOTE: The same balance object is updated in place and yielded on
        every step. If you need to keep an intermediate balance after
        advancing the iterator, make a copy (e.g., ``Balance(balance)``).
        """
        balance = MutableBalance()
        for post in self:
            balance.add_amount(post.units)
            yield post, balance

    def balance(self) -> Balance:
        """Return the total balance of all postings' units

        Returns an empty Balance when there are no postings.
        """
        # Start from an empty Balance so an empty collection needs no special
        # handling. Every iteration rebinds the name to the running balance,
        # so after the loop it refers to the final total.
        retval: Balance = Balance()
        for _, retval in self.iter_with_balance():
            pass
        return retval

    def balance_at_cost(self) -> Balance:
        """Return the total balance with costed postings valued at cost

        A posting without a cost contributes its units unchanged. A posting
        with a cost contributes its units number multiplied by the cost
        number, in the cost's currency.
        """
        balance = MutableBalance()
        for post in self:
            if post.cost is None:
                balance.add_amount(post.units)
            else:
                number = post.units.number * post.cost.number
                balance.add_amount(data.Amount(number, post.cost.currency))
        return balance

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the set of distinct post.meta.get(key, default) values"""
        return {post.meta.get(key, default) for post in self}