diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index cf5a39a7ca440f6073cfa68165533c37ceccc5ab..6b2845c70f7b3fcf29aef1af8c603fee4ab62ba0 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -132,10 +132,12 @@ class Balance(Mapping[str, data.Amount]): return type(self)(retval_map.values()) def __eq__(self, other: Any) -> bool: - if (self.is_zero() - and isinstance(other, Balance) - and other.is_zero()): - return True + if isinstance(other, Balance): + clean_self = self.clean_copy() + clean_other = other.clean_copy() + return len(clean_self) == len(clean_other) and all( + clean_self[key] == clean_other.get(key) for key in clean_self + ) else: return super().__eq__(other) @@ -160,6 +162,17 @@ class Balance(Mapping[str, data.Amount]): ) -> bool: return all(op_func(amt.number, operand) for amt in self.values()) + def copy(self: BalanceType) -> BalanceType: + return type(self)(self.values()) + + def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType: + if tolerance is None: + tolerance = self.tolerance + return type(self)( + amount for amount in self.values() + if abs(amount.number) >= tolerance + ) + @staticmethod def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool: dec = cast(Decimal, dec)