diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index 72ade4db951079379bae5b8ccb5977cbc558a4d7..2c5aaf5096ad2f205e97e29f81276bd155c3aa47 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -80,10 +80,17 @@ class Balance(Mapping[str, data.Amount]): Each key is a Beancount currency string, and each value represents the balance in that currency. """ - __slots__ = ('_currency_map',) + __slots__ = ('_currency_map', 'tolerance') + TOLERANCE = Decimal('0.01') - def __init__(self, source: Iterable[data.Amount]=()) -> None: + def __init__(self, + source: Iterable[data.Amount]=(), + tolerance: Optional[Decimal]=None, + ) -> None: + if tolerance is None: + tolerance = self.TOLERANCE self._currency_map = {amount.currency: amount for amount in source} + self.tolerance = tolerance def _add_amount(self, currency_map: MutableMapping[str, data.Amount], @@ -147,19 +154,24 @@ class Balance(Mapping[str, data.Amount]): ) -> bool: return all(op_func(amt.number, operand) for amt in self.values()) + @staticmethod + def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool: + dec = cast(Decimal, dec) + return abs(dec) < tolerance + def eq_zero(self) -> bool: - """Returns true if all amounts in the balance == 0.""" - return self._all_amounts(operator.eq, 0) + """Returns true if all amounts in the balance == 0, within tolerance.""" + return self._all_amounts(self.within_tolerance, self.tolerance) is_zero = eq_zero def ge_zero(self) -> bool: - """Returns true if all amounts in the balance >= 0.""" - return self._all_amounts(operator.ge, 0) + """Returns true if all amounts in the balance >= 0, within tolerance.""" + return self._all_amounts(operator.ge, -self.tolerance) def le_zero(self) -> bool: - """Returns true if all amounts in the balance <= 0.""" - return self._all_amounts(operator.le, 0) + """Returns true if all amounts in the balance <= 0, within tolerance.""" + return self._all_amounts(operator.le, self.tolerance) def format(self, fmt: Optional[str]='#,#00.00 ¤¤', diff --git a/tests/test_reports_balance.py b/tests/test_reports_balance.py index d827018f52d459839657bda2acf7a5b7bd93b473..8ab0a587b9cd9eb7089af14fcfd544200009a73d 100644 --- a/tests/test_reports_balance.py +++ b/tests/test_reports_balance.py @@ -92,6 +92,8 @@ def test_mixed_balance(): ({'JPY': 10}, False), ({'JPY': 10, 'BRL': 0}, False), ({'JPY': 10, 'BRL': 20}, False), + ({'USD': '0.00015'}, True), + ({'EUR': '-0.00052'}, True), ]) def test_eq_zero(mapping, expected): balance = core.Balance(amounts_from_map(mapping)) @@ -108,6 +110,8 @@ def test_eq_zero(mapping, expected): ({'JPY': 10}, True), ({'JPY': 10, 'BRL': 0}, True), ({'JPY': 10, 'BRL': 20}, True), + ({'USD': '0.00015'}, True), + ({'EUR': '-0.00052'}, True), ]) def test_ge_zero(mapping, expected): balance = core.Balance(amounts_from_map(mapping)) @@ -123,6 +127,8 @@ def test_ge_zero(mapping, expected): ({'JPY': 10}, False), ({'JPY': 10, 'BRL': 0}, False), ({'JPY': 10, 'BRL': 20}, False), + ({'USD': '0.00015'}, True), + ({'EUR': '-0.00052'}, True), ]) def test_le_zero(mapping, expected): balance = core.Balance(amounts_from_map(mapping))