diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index cf5a39a7ca440f6073cfa68165533c37ceccc5ab..6b2845c70f7b3fcf29aef1af8c603fee4ab62ba0 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -132,10 +132,12 @@ class Balance(Mapping[str, data.Amount]): return type(self)(retval_map.values()) def __eq__(self, other: Any) -> bool: - if (self.is_zero() - and isinstance(other, Balance) - and other.is_zero()): - return True + if isinstance(other, Balance): + clean_self = self.clean_copy() + clean_other = other.clean_copy() + return len(clean_self) == len(clean_other) and all( + clean_self[key] == clean_other.get(key) for key in clean_self + ) else: return super().__eq__(other) @@ -160,6 +162,17 @@ class Balance(Mapping[str, data.Amount]): ) -> bool: return all(op_func(amt.number, operand) for amt in self.values()) + def copy(self: BalanceType) -> BalanceType: + return type(self)(self.values()) + + def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType: + if tolerance is None: + tolerance = self.tolerance + return type(self)( + amount for amount in self.values() + if abs(amount.number) >= tolerance + ) + @staticmethod def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool: dec = cast(Decimal, dec) diff --git a/tests/test_reports_balance.py b/tests/test_reports_balance.py index 22271857aadd74662da6ad4c8d0b375e987356a1..706531e63689f3e5196803cb7f9aaa949e98f3c1 100644 --- a/tests/test_reports_balance.py +++ b/tests/test_reports_balance.py @@ -34,6 +34,8 @@ DEFAULT_STRINGS = [ ({'JPY': '-5500.00', 'BRL': '-8500.00'}, "-8,500.00 BRL, -5,500 JPY"), ] +TOLERANCES = [Decimal(n) for n in ['.1', '.01', '.001', 0]] + def amounts_from_map(currency_map): for code, number in currency_map.items(): yield testutil.Amount(number, code) @@ -219,6 +221,15 @@ def test_eq(map1, map2, expected): actual = bal1 == bal2 assert actual == expected +@pytest.mark.parametrize('tolerance', TOLERANCES) +def test_eq_considers_tolerance(tolerance): + tolerance = Decimal(tolerance) + mapping = {'EUR': 100, 'USD': '.002'} + bal1 = core.Balance(amounts_from_map(mapping)) + mapping['USD'] = '.004' + bal2 = core.Balance(amounts_from_map(mapping), tolerance) + assert (bal1 == bal2) == (tolerance > Decimal('.002')) + @pytest.mark.parametrize('number,currency', { (50, 'USD'), (-50, 'USD'), @@ -294,6 +305,34 @@ def test_iadd_balance(mapping): expected = core.Balance(amounts_from_map(expect_numbers)) assert balance == expected +def test_copy(): + amounts = frozenset(amounts_from_map({'USD': 10, 'EUR': '.001'})) + # Use a ridiculous tolerance to test it doesn't matter. + actual = core.Balance(amounts, 100).copy() + assert frozenset(actual.values()) == amounts + +@pytest.mark.parametrize('tolerance', TOLERANCES) +def test_clean_copy(tolerance): + usd = testutil.Amount(10) + eur = testutil.Amount('.002', 'EUR') + actual = core.Balance([usd, eur], tolerance).clean_copy() + if tolerance < eur.number: + expected = {usd, eur} + else: + expected = {usd} + assert frozenset(actual.values()) == expected + +@pytest.mark.parametrize('tolerance', TOLERANCES) +def test_clean_copy_arg(tolerance): + usd = testutil.Amount(10) + eur = testutil.Amount('.002', 'EUR') + actual = core.Balance([usd, eur], 0).clean_copy(tolerance) + if tolerance < eur.number: + expected = {usd, eur} + else: + expected = {usd} + assert frozenset(actual.values()) == expected + @pytest.mark.parametrize('mapping,expected', DEFAULT_STRINGS) def test_str(mapping, expected): balance = core.Balance(amounts_from_map(mapping))