Changeset - 42d2002fda44
[Not reviewed]
0 2 0
Brett Smith - 4 years ago 2020-06-21 15:39:31
brettcsmith@brettcsmith.org
reports: Balance.format(None) uses accounting formats.
2 files changed with 7 insertions and 5 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/core.py
Show inline comments
...
 
@@ -106,258 +106,259 @@ class Balance(Mapping[str, data.Amount]):
 
    ) -> None:
 
        code = amount.currency
 
        try:
 
            current_number = currency_map[code].number
 
        except KeyError:
 
            current_number = Decimal(0)
 
        currency_map[code] = data.Amount(current_number + amount.number, code)
 

	
 
    def _add_other(self,
 
                   currency_map: MutableMapping[str, data.Amount],
 
                   other: Union[data.Amount, 'Balance'],
 
    ) -> None:
 
        if isinstance(other, Balance):
 
            for amount in other.values():
 
                self._add_amount(currency_map, amount)
 
        else:
 
            self._add_amount(currency_map, other)
 

	
 
    def __repr__(self) -> str:
 
        values = [repr(amt) for amt in self.values()]
 
        return f"{type(self).__name__}({values!r})"
 

	
 
    def __str__(self) -> str:
 
        return self.format()
 

	
 
    def __abs__(self: BalanceType) -> BalanceType:
 
        return type(self)(bc_amount.abs(amt) for amt in self.values())
 

	
 
    def __add__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
 
        retval_map = self._currency_map.copy()
 
        self._add_other(retval_map, other)
 
        return type(self)(retval_map.values())
 

	
 
    def __sub__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
 
        return self.__add__(-other)
 

	
 
    def __eq__(self, other: Any) -> bool:
 
        if isinstance(other, Balance):
 
            clean_self = self.clean_copy()
 
            clean_other = other.clean_copy()
 
            return len(clean_self) == len(clean_other) and all(
 
                clean_self[key] == clean_other.get(key) for key in clean_self
 
            )
 
        else:
 
            return super().__eq__(other)
 

	
 
    def __neg__(self: BalanceType) -> BalanceType:
 
        return type(self)(-amt for amt in self.values())
 

	
 
    def __pos__(self: BalanceType) -> BalanceType:
 
        return self
 

	
 
    def __getitem__(self, key: str) -> data.Amount:
 
        return self._currency_map[key]
 

	
 
    def __iter__(self) -> Iterator[str]:
 
        return iter(self._currency_map)
 

	
 
    def __len__(self) -> int:
 
        return len(self._currency_map)
 

	
 
    def _all_amounts(self,
 
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
 
                     operand: DecimalCompat,
 
    ) -> bool:
 
        return all(op_func(amt.number, operand) for amt in self.values())
 

	
 
    def copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType:
 
        if tolerance is None:
 
            tolerance = self.tolerance
 
        return type(self)(self.values(), tolerance)
 

	
 
    def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType:
 
        if tolerance is None:
 
            tolerance = self.tolerance
 
        return type(self)(
 
            (amount for amount in self.values() if abs(amount.number) >= tolerance),
 
            tolerance,
 
        )
 

	
 
    @staticmethod
 
    def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
 
        dec = cast(Decimal, dec)
 
        return abs(dec) < tolerance
 

	
 
    def eq_zero(self) -> bool:
 
        """Returns true if all amounts in the balance == 0, within tolerance."""
 
        return self._all_amounts(self.within_tolerance, self.tolerance)
 

	
 
    is_zero = eq_zero
 

	
 
    def ge_zero(self) -> bool:
 
        """Returns true if all amounts in the balance >= 0, within tolerance."""
 
        op_func = operator.gt if self.tolerance else operator.ge
 
        return self._all_amounts(op_func, -self.tolerance)
 

	
 
    def le_zero(self) -> bool:
 
        """Returns true if all amounts in the balance <= 0, within tolerance."""
 
        op_func = operator.lt if self.tolerance else operator.le
 
        return self._all_amounts(op_func, self.tolerance)
 

	
 
    def format(self,
 
               fmt: Optional[str]='#,##0.00 ¤¤',
 
               sep: str=', ',
 
               empty: str="Zero balance",
 
               zero: Optional[str]=None,
 
               tolerance: Optional[Decimal]=None,
 
    ) -> str:
 
        """Formats the balance as a string with the given parameters
 

	
 
        If the balance is completely empty, return ``empty``.
 
        If the balance is zero (within tolerance) and ``zero`` is specified,
 
        return ``zero``.
 
        Otherwise, return a string with each amount in the balance formatted
 
        as ``fmt``, separated by ``sep``.
 

	
 
        If you set ``fmt`` to None, amounts will be formatted according to the
 
        user's locale. The default format is Beancount's input format.
 
        """
 
        balance = self.clean_copy(tolerance) or self.copy(tolerance)
 
        if not balance:
 
            return empty
 
        elif zero is not None and balance.is_zero():
 
            return zero
 
        else:
 
            amounts = list(balance.values())
 
            amounts.sort(key=lambda amt: (-abs(amt.number), amt.currency))
 
            return sep.join(
 
                babel.numbers.format_currency(amt.number, amt.currency, fmt)
 
                for amt in amounts
 
                babel.numbers.format_currency(
 
                    amt.number, amt.currency, fmt, format_type='accounting',
 
                ) for amt in amounts
 
            )
 

	
 

	
 
class MutableBalance(Balance):
 
    __slots__ = ()
 

	
 
    def __iadd__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType:
 
        self._add_other(self._currency_map, other)
 
        return self
 

	
 
    def __isub__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType:
 
        self._add_other(self._currency_map, -other)
 
        return self
 

	
 

	
 
class RelatedPostings(Sequence[data.Posting]):
 
    """Collect and query related postings
 

	
 
    This class provides common functionality for collecting related postings
 
    and running queries on them: iterating over them, tallying their balance,
 
    etc.
 

	
 
    This class doesn't know anything about how the postings are related. That's
 
    entirely up to the caller.
 

	
 
    A common pattern is to use this class with collections.defaultdict
 
    to organize postings based on some key. See the group_by_meta classmethod
 
    for an example.
 
    """
 
    __slots__ = ('_postings',)
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        self._postings: List[data.Posting]
 
        if _can_own and isinstance(source, list):
 
            self._postings = source
 
        else:
 
            self._postings = list(source)
 

	
 
    @classmethod
 
    def _group_by(cls: Type[RelatedType],
 
                  postings: Iterable[data.Posting],
 
                  key: Callable[[data.Posting], T],
 
    ) -> Iterator[Tuple[T, RelatedType]]:
 
        mapping: Dict[T, List[data.Posting]] = collections.defaultdict(list)
 
        for post in postings:
 
            mapping[key(post)].append(post)
 
        for value, posts in mapping.items():
 
            yield value, cls(posts, _can_own=True)
 

	
 
    @classmethod
 
    def group_by_account(cls: Type[RelatedType],
 
                         postings: Iterable[data.Posting],
 
    ) -> Iterator[Tuple[data.Account, RelatedType]]:
 
        return cls._group_by(postings, operator.attrgetter('account'))
 

	
 
    @classmethod
 
    def group_by_meta(cls: Type[RelatedType],
 
                      postings: Iterable[data.Posting],
 
                      key: MetaKey,
 
                      default: Optional[MetaValue]=None,
 
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
 
        """Relate postings by metadata value
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of post.meta.get(key, default).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same metadata value.
 
        """
 
        def key_func(post: data.Posting) -> Optional[MetaValue]:
 
            return post.meta.get(key, default)
 
        return cls._group_by(postings, key_func)
 

	
 
    @classmethod
 
    def group_by_first_meta_link(
 
            cls: Type[RelatedType],
 
            postings: Iterable[data.Posting],
 
            key: MetaKey,
 
    ) -> Iterator[Tuple[Optional[str], RelatedType]]:
 
        """Relate postings by the first link in metadata
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of
 
        post.meta.first_link(key, None).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same first metadata link.
 
        """
 
        def key_func(post: data.Posting) -> Optional[MetaValue]:
 
            return post.meta.first_link(key, None)
 
        return cls._group_by(postings, key_func)
 

	
 
    def __repr__(self) -> str:
 
        return f'<{type(self).__name__} {self._postings!r}>'
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...
 

	
 
    def __getitem__(self: RelatedType,
 
                    index: Union[int, slice],
 
    ) -> Union[data.Posting, RelatedType]:
 
        if isinstance(index, slice):
 
            return type(self)(self._postings[index], _can_own=True)
 
        else:
 
            return self._postings[index]
 

	
 
    def __len__(self) -> int:
 
        return len(self._postings)
 

	
 
    def all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        return filters.iter_unique(
 
            link for post in self for link in post.meta.report_links(key)
 
        )
 

	
 
    @overload
 
    def first_meta_links(self, key: MetaKey, default: str='') -> Iterator[str]: ...
 

	
 
    @overload
 
    def first_meta_links(self, key: MetaKey, default: None) -> Iterator[Optional[str]]: ...
 

	
 
    def first_meta_links(self,
 
                         key: MetaKey,
 
                         default: Optional[str]='',
tests/test_reports_balance.py
Show inline comments
...
 
@@ -294,172 +294,173 @@ def test_sub_amount(number, currency):
 
    (50, 'USD'),
 
    (-50, 'USD'),
 
    (50000, 'BRL'),
 
    (-4000, 'BRL'),
 
})
 
def test_isub_amount(number, currency):
 
    balance = core.MutableBalance([testutil.Amount(500, 'USD')])
 
    sub_amount = testutil.Amount(number, currency)
 
    balance -= sub_amount
 
    if currency == 'USD':
 
        assert len(balance) == 1
 
        assert balance['USD'] == testutil.Amount(500 - number)
 
    else:
 
        assert len(balance) == 2
 
        assert balance['USD'] == testutil.Amount(500)
 
        assert balance[currency] == -sub_amount
 

	
 
@pytest.mark.parametrize('mapping', [
 
    {},
 
    {'USD': 0},
 
    {'EUR': 10},
 
    {'JPY': 20, 'BRL': 30},
 
    {'EUR': -15},
 
    {'JPY': -25, 'BRL': -35},
 
    {'JPY': 40, 'USD': 0, 'EUR': -50},
 
])
 
def test_add_balance(mapping):
 
    expect_numbers = {'USD': 500, 'BRL': 40000}
 
    start_bal = core.Balance(amounts_from_map(expect_numbers))
 
    for code, number in mapping.items():
 
        expect_numbers[code] = expect_numbers.get(code, 0) + number
 
    add_bal = core.Balance(amounts_from_map(mapping))
 
    actual = start_bal + add_bal
 
    expected = core.Balance(amounts_from_map(expect_numbers))
 
    assert actual == expected
 

	
 
@pytest.mark.parametrize('mapping', [
 
    {},
 
    {'USD': 0},
 
    {'EUR': 10},
 
    {'JPY': 20, 'BRL': 30},
 
    {'EUR': -15},
 
    {'JPY': -25, 'BRL': -35},
 
    {'JPY': 40, 'USD': 0, 'EUR': -50},
 
])
 
def test_iadd_balance(mapping):
 
    expect_numbers = {'USD': 500, 'BRL': 40000}
 
    balance = core.MutableBalance(amounts_from_map(expect_numbers))
 
    for code, number in mapping.items():
 
        expect_numbers[code] = expect_numbers.get(code, 0) + number
 
    balance += core.Balance(amounts_from_map(mapping))
 
    expected = core.Balance(amounts_from_map(expect_numbers))
 
    assert balance == expected
 

	
 
@pytest.mark.parametrize('tolerance', TOLERANCES)
 
def test_copy(tolerance):
 
    eur = testutil.Amount('.003', 'EUR')
 
    source = core.Balance([eur], tolerance)
 
    new = source.copy()
 
    assert source is not new
 
    assert dict(source) == dict(new)
 
    assert new.tolerance == tolerance
 

	
 
@pytest.mark.parametrize('tolerance', TOLERANCES)
 
def test_copy_tolerance_arg(tolerance):
 
    eur = testutil.Amount('.003', 'EUR')
 
    source = core.Balance([eur])
 
    new = source.copy(tolerance)
 
    assert source is not new
 
    assert dict(source) == dict(new)
 
    assert new.tolerance == tolerance
 

	
 
@pytest.mark.parametrize('tolerance', TOLERANCES)
 
def test_clean_copy(tolerance):
 
    usd = testutil.Amount(10)
 
    eur = testutil.Amount('.002', 'EUR')
 
    actual = core.Balance([usd, eur], tolerance).clean_copy()
 
    if tolerance < eur.number:
 
        expected = {usd, eur}
 
    else:
 
        expected = {usd}
 
    assert frozenset(actual.values()) == expected
 
    assert actual.tolerance == tolerance
 

	
 
@pytest.mark.parametrize('tolerance', TOLERANCES)
 
def test_clean_copy_arg(tolerance):
 
    usd = testutil.Amount(10)
 
    eur = testutil.Amount('.002', 'EUR')
 
    actual = core.Balance([usd, eur], 0).clean_copy(tolerance)
 
    if tolerance < eur.number:
 
        expected = {usd, eur}
 
    else:
 
        expected = {usd}
 
    assert frozenset(actual.values()) == expected
 
    assert actual.tolerance == tolerance
 

	
 
@pytest.mark.parametrize('mapping,expected', DEFAULT_STRINGS)
 
def test_str(mapping, expected):
 
    balance = core.Balance(amounts_from_map(mapping))
 
    assert str(balance) == expected
 

	
 
@pytest.mark.parametrize('mapping,expected', DEFAULT_STRINGS)
 
def test_format_defaults(mapping, expected):
 
    balance = core.Balance(amounts_from_map(mapping))
 
    assert balance.format() == expected
 

	
 
@pytest.mark.parametrize('fmt,expected', [
 
    ('¤##0.0', '¥5000, -€1500.00'),
 
    ('#,##0.0¤¤', '5,000JPY, -1,500.00EUR'),
 
    ('¤+##0.0;¤-##0.0', '¥+5000, €-1500.00'),
 
    ('#,##0.0 ¤¤;(#,##0.0 ¤¤)', '5,000 JPY, (1,500.00 EUR)'),
 
])
 
def test_format_fmt(fmt, expected):
 
    amounts = [testutil.Amount(5000, 'JPY'), testutil.Amount(-1500, 'EUR')]
 
    balance = core.Balance(amounts)
 
    assert balance.format(fmt) == expected
 

	
 
@pytest.mark.parametrize('sep', [
 
    '; ',
 
    '—',
 
    '\0',
 
])
 
def test_format_sep(sep):
 
    mapping, expected = DEFAULT_STRINGS[-1]
 
    expected = expected.replace(', ', sep)
 
    balance = core.Balance(amounts_from_map(mapping))
 
    assert balance.format(sep=sep) == expected
 

	
 
def test_format_none():
 
    args = (65000, 'BRL')
 
@pytest.mark.parametrize('number', [65000, -77000])
 
def test_format_none(number):
 
    args = (number, 'BRL')
 
    balance = core.Balance([testutil.Amount(*args)])
 
    expected = babel.numbers.format_currency(*args)
 
    expected = babel.numbers.format_currency(*args, format_type='accounting')
 
    assert balance.format(None) == expected
 

	
 
@pytest.mark.parametrize('empty', [
 
    "N/A",
 
    "Zero",
 
    "ø",
 
])
 
def test_format_empty(empty):
 
    balance = core.Balance()
 
    assert balance.format(empty=empty) == empty
 

	
 
@pytest.mark.parametrize('currency,fmt', itertools.product(
 
    ['USD', 'JPY', 'BRL'],
 
    [None, '¤#,##0.00', '###0.00 ¤¤'],
 
))
 
def test_format_zero_balance_fmt(currency, fmt):
 
    zero_amt = testutil.Amount(0, currency)
 
    nonzero_amt = testutil.Amount(9, currency)
 
    zero_bal = core.Balance([zero_amt])
 
    nonzero_bal = core.Balance([nonzero_amt])
 
    expected = nonzero_bal.format(fmt).replace('9', '0')
 
    assert zero_bal.format(fmt) == expected
 

	
 
@pytest.mark.parametrize('currency,fmt', testutil.combine_values(
 
    ['USD', 'JPY', 'BRL'],
 
    ["N/A", "Zero", "ø"],
 
))
 
def test_format_zero_balance_zero_str(currency, fmt):
 
    zero_amt = testutil.Amount(0, currency)
 
    zero_bal = core.Balance([zero_amt])
 
    assert zero_bal.format(zero=fmt) == fmt
 

	
 
@pytest.mark.parametrize('tolerance', TOLERANCES)
 
def test_format_zero_balance_with_tolerance(tolerance):
 
    chf = testutil.Amount('.005', 'CHF')
 
    actual = core.Balance([chf]).format(zero="ø", tolerance=tolerance)
 
    if tolerance > chf.number:
 
        assert actual == "ø"
 
    else:
 
        assert actual == "0.00 CHF"
0 comments (0 inline, 0 general)