From 069939b2d3a5fdccb632e487042d94f33b268c20 2020-06-03 22:53:17
From: Brett Smith
Date: 2020-06-03 22:53:17
Subject: [PATCH] reports: Balance classes support addition.

---
Reviewer notes (commentary only; nothing in this section is part of the
change itself).

conservancy_beancount/reports/core.py
  * Adds the BalanceType TypeVar, bound to 'Balance', used to annotate
    `self` and the return value of the new operators.
  * Balance._add_amount(currency_map, amount): adds amount.number to
    currency_map[amount.currency]; on KeyError it creates the entry.
  * Balance._add_other(currency_map, other): when `other` is a Balance,
    folds each of its values in through _add_amount; otherwise passes
    `other` to _add_amount as a single amount.
  * Balance.__add__: copies self._currency_map, applies _add_other to the
    copy, and builds a new type(self) from the resulting (code, Amount)
    pairs. Only the copy is written to.
  * MutableBalance.add_amount is removed. MutableBalance.__iadd__ replaces
    it: it applies _add_other to self._currency_map and returns self.
  * RelatedPostings: iter_with_balance and one further method (its `def`
    line is outside the hunk) switch from balance.add_amount(x) to
    `balance += x`.

tests/test_reports_balance.py
  * Adds test_add_amount, test_iadd_amount, test_add_balance and
    test_iadd_balance, covering `+` on Balance and `+=` on MutableBalance
    with both Amount and Balance operands.

NOTE(review): add_amount is deleted without a compatibility alias. Only
the call sites shown below are updated here -- confirm no other callers
remain in the tree.
NOTE(review): _add_other assumes any non-Balance operand exposes
.currency and .number. Other operand types would presumably surface as
AttributeError rather than NotImplemented -- confirm that is intended.
NOTE(review): test_add_amount and test_iadd_amount pass a set literal to
parametrize, so the order of the generated cases may differ between
runs -- consider a list, as the two *_balance tests use.
NOTE(review): line breaks and indentation in this copy were reconstructed
from a whitespace-collapsed source. Hunk line counts agree with the @@
headers, but exact whitespace should be checked against the repository
before applying.

diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py
index eb32eb97d4c904f3e87872b4776428f285ef694e..85cc5ff7c3c7a92ea733101d5180cc082242d2c8 100644
--- a/conservancy_beancount/reports/core.py
+++ b/conservancy_beancount/reports/core.py
@@ -49,6 +49,7 @@ from ..beancount_types import (
 )
 
 DecimalCompat = data.DecimalCompat
+BalanceType = TypeVar('BalanceType', bound='Balance')
 RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
 
 class Balance(Mapping[str, data.Amount]):
@@ -69,6 +70,25 @@ class Balance(Mapping[str, data.Amount]):
             currency: amount.number for currency, amount in source
         }
 
+    def _add_amount(self,
+                    currency_map: MutableMapping[str, Decimal],
+                    amount: data.Amount,
+    ) -> None:
+        try:
+            currency_map[amount.currency] += amount.number
+        except KeyError:
+            currency_map[amount.currency] = amount.number
+
+    def _add_other(self,
+                   currency_map: MutableMapping[str, Decimal],
+                   other: Union[data.Amount, 'Balance'],
+    ) -> None:
+        if isinstance(other, Balance):
+            for amount in other.values():
+                self._add_amount(currency_map, amount)
+        else:
+            self._add_amount(currency_map, other)
+
     def __repr__(self) -> str:
         return f"{type(self).__name__}({self._currency_map!r})"
 
@@ -80,6 +100,12 @@ class Balance(Mapping[str, data.Amount]):
             (key, bc_amount.abs(amt)) for key, amt in self.items()
         )
 
+    def __add__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
+        retval_map = self._currency_map.copy()
+        self._add_other(retval_map, other)
+        return type(self)((code, data.Amount(number, code))
+                          for code, number in retval_map.items())
+
     def __eq__(self, other: Any) -> bool:
         if (self.is_zero()
             and isinstance(other, Balance)
@@ -149,11 +175,9 @@ class Balance(Mapping[str, data.Amount]):
 class MutableBalance(Balance):
     __slots__ = ()
 
-    def add_amount(self, amount: data.Amount) -> None:
-        try:
-            self._currency_map[amount.currency] += amount.number
-        except KeyError:
-            self._currency_map[amount.currency] = amount.number
+    def __iadd__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType:
+        self._add_other(self._currency_map, other)
+        return self
 
 
 class RelatedPostings(Sequence[data.Posting]):
@@ -234,7 +258,7 @@ class RelatedPostings(Sequence[data.Posting]):
     def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
         balance = MutableBalance()
         for post in self:
-            balance.add_amount(post.units)
+            balance += post.units
             yield post, balance
 
     def balance(self) -> Balance:
@@ -249,10 +273,10 @@ class RelatedPostings(Sequence[data.Posting]):
         balance = MutableBalance()
         for post in self:
             if post.cost is None:
-                balance.add_amount(post.units)
+                balance += post.units
             else:
                 number = post.units.number * post.cost.number
-                balance.add_amount(data.Amount(number, post.cost.currency))
+                balance += data.Amount(number, post.cost.currency)
         return balance
 
     def meta_values(self,
diff --git a/tests/test_reports_balance.py b/tests/test_reports_balance.py
index 876c04bac72224fd53bc200ec0f8cae63790bcf7..d2866919669d796ddd00e084300370bc35752279 100644
--- a/tests/test_reports_balance.py
+++ b/tests/test_reports_balance.py
@@ -179,6 +179,84 @@ def test_eq(kwargs1, kwargs2, expected):
     actual = bal1 == bal2
     assert actual == expected
 
+@pytest.mark.parametrize('number,currency', {
+    (50, 'USD'),
+    (-50, 'USD'),
+    (50000, 'BRL'),
+    (-4000, 'BRL'),
+})
+def test_add_amount(number, currency):
+    start_amounts = testutil.balance_map(USD=500)
+    start_bal = core.Balance(start_amounts)
+    add_amount = testutil.Amount(number, currency)
+    actual = start_bal + add_amount
+    if currency == 'USD':
+        assert len(actual) == 1
+        assert actual['USD'] == testutil.Amount(500 + number)
+    else:
+        assert len(actual) == 2
+        assert actual['USD'] == testutil.Amount(500)
+        assert actual[currency] == add_amount
+    assert start_bal == start_amounts
+
+@pytest.mark.parametrize('number,currency', {
+    (50, 'USD'),
+    (-50, 'USD'),
+    (50000, 'BRL'),
+    (-4000, 'BRL'),
+})
+def test_iadd_amount(number, currency):
+    start_amounts = testutil.balance_map(USD=500)
+    balance = core.MutableBalance(start_amounts)
+    add_amount = testutil.Amount(number, currency)
+    balance += add_amount
+    if currency == 'USD':
+        assert len(balance) == 1
+        assert balance['USD'] == testutil.Amount(500 + number)
+    else:
+        assert len(balance) == 2
+        assert balance['USD'] == testutil.Amount(500)
+        assert balance[currency] == add_amount
+
+@pytest.mark.parametrize('balance_map_kwargs', [
+    {},
+    {'USD': 0},
+    {'EUR': 10},
+    {'JPY': 20, 'BRL': 30},
+    {'EUR': -15},
+    {'JPY': -25, 'BRL': -35},
+    {'JPY': 40, 'USD': 0, 'EUR': -50},
+])
+def test_add_balance(balance_map_kwargs):
+    start_numbers = {'USD': 500, 'BRL': 40000}
+    start_bal = core.Balance(testutil.balance_map(**start_numbers))
+    expect_numbers = start_numbers.copy()
+    for code, number in balance_map_kwargs.items():
+        expect_numbers[code] = expect_numbers.get(code, 0) + number
+    add_bal = core.Balance(testutil.balance_map(**balance_map_kwargs))
+    actual = start_bal + add_bal
+    expected = core.Balance(testutil.balance_map(**expect_numbers))
+    assert actual == expected
+
+@pytest.mark.parametrize('balance_map_kwargs', [
+    {},
+    {'USD': 0},
+    {'EUR': 10},
+    {'JPY': 20, 'BRL': 30},
+    {'EUR': -15},
+    {'JPY': -25, 'BRL': -35},
+    {'JPY': 40, 'USD': 0, 'EUR': -50},
+])
+def test_iadd_balance(balance_map_kwargs):
+    start_numbers = {'USD': 500, 'BRL': 40000}
+    balance = core.MutableBalance(testutil.balance_map(**start_numbers))
+    expect_numbers = start_numbers.copy()
+    for code, number in balance_map_kwargs.items():
+        expect_numbers[code] = expect_numbers.get(code, 0) + number
+    balance += core.Balance(testutil.balance_map(**balance_map_kwargs))
+    expected = core.Balance(testutil.balance_map(**expect_numbers))
+    assert balance == expected
+
 @pytest.mark.parametrize('balance_map_kwargs,expected', DEFAULT_STRINGS)
 def test_str(balance_map_kwargs, expected):
     amounts = testutil.balance_map(**balance_map_kwargs)