diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index a00e38d78f7e8808d07260c3362fc2a4a2ee67e5..dbf602da6b177c590b209fc3d759b00e90168bb7 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -133,6 +133,9 @@ class Balance(Mapping[str, data.Amount]): self._add_other(retval_map, other) return type(self)(retval_map.values()) + def __sub__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType: + return self.__add__(-other) + def __eq__(self, other: Any) -> bool: if isinstance(other, Balance): clean_self = self.clean_copy() @@ -228,6 +231,10 @@ class MutableBalance(Balance): self._add_other(self._currency_map, other) return self + def __isub__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType: + self._add_other(self._currency_map, -other) + return self + class RelatedPostings(Sequence[data.Posting]): """Collect and query related postings diff --git a/tests/test_reports_balance.py b/tests/test_reports_balance.py index be5b265c218638653775caed09a1677a909b7db0..24c435e91cbc6fa91c0364717acf8a55a5dba8e8 100644 --- a/tests/test_reports_balance.py +++ b/tests/test_reports_balance.py @@ -270,6 +270,44 @@ def test_iadd_amount(number, currency): assert balance['USD'] == testutil.Amount(500) assert balance[currency] == add_amount +@pytest.mark.parametrize('number,currency', { + (50, 'USD'), + (-50, 'USD'), + (50000, 'BRL'), + (-4000, 'BRL'), +}) +def test_sub_amount(number, currency): + start_amount = testutil.Amount(500, 'USD') + start_bal = core.Balance([start_amount]) + sub_amount = testutil.Amount(number, currency) + actual = start_bal - sub_amount + if currency == 'USD': + assert len(actual) == 1 + assert actual['USD'] == testutil.Amount(500 - number) + else: + assert len(actual) == 2 + assert actual['USD'] == start_amount + assert actual[currency] == -sub_amount + assert start_bal == {'USD': start_amount} + +@pytest.mark.parametrize('number,currency', { + (50, 'USD'), + (-50, 'USD'), + (50000, 'BRL'), + (-4000, 'BRL'), +}) +def test_isub_amount(number, currency): + balance = core.MutableBalance([testutil.Amount(500, 'USD')]) + sub_amount = testutil.Amount(number, currency) + balance -= sub_amount + if currency == 'USD': + assert len(balance) == 1 + assert balance['USD'] == testutil.Amount(500 - number) + else: + assert len(balance) == 2 + assert balance['USD'] == testutil.Amount(500) + assert balance[currency] == -sub_amount + @pytest.mark.parametrize('mapping', [ {}, {'USD': 0},