diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index eb6d372f8a1228dddfb94ab2c20ba16072ce6574..86deed5b4d1c09f27f6ff7cae4213d4e0a5a5044 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -115,12 +115,12 @@ class Sentinel: class Account(NamedTuple): name: str - balance_paid: Callable[[core.Balance], bool] + norm_func: Callable[[core.Balance], core.Balance] class AccrualAccount(enum.Enum): - PAYABLE = Account('Liabilities:Payable', core.Balance.ge_zero) - RECEIVABLE = Account('Assets:Receivable', core.Balance.le_zero) + PAYABLE = Account('Liabilities:Payable', operator.neg) + RECEIVABLE = Account('Assets:Receivable', lambda bal: bal) @classmethod def account_names(cls) -> Iterator[str]: @@ -134,14 +134,6 @@ class AccrualAccount(enum.Enum): return account raise ValueError("unrecognized account set in related postings") - @classmethod - def filter_paid_accruals(cls, groups: PostGroups) -> PostGroups: - return { - key: related - for key, related in groups.items() - if not cls.classify(related).value.balance_paid(related.balance()) - } - class AccrualPostings(core.RelatedPostings): def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc] @@ -160,6 +152,7 @@ class AccrualPostings(core.RelatedPostings): INCONSISTENT = Sentinel() __slots__ = ( 'accrual_type', + 'final_bal', 'account', 'accounts', 'contract', @@ -198,8 +191,10 @@ class AccrualPostings(core.RelatedPostings): self.entities = self.entitys if self.account is self.INCONSISTENT: self.accrual_type: Optional[AccrualAccount] = None + self.final_bal = self.balance() else: self.accrual_type = AccrualAccount.classify(self) + self.final_bal = self.accrual_type.value.norm_func(self.balance()) def make_consistent(self) -> Iterator[Tuple[MetaValue, 'AccrualPostings']]: account_ok = isinstance(self.account, str) @@ -232,21 +227,33 @@ class AccrualPostings(core.RelatedPostings): ) yield Error(post.meta, errmsg, post.meta.txn) + def is_paid(self, default: Optional[bool]=None) -> Optional[bool]: + if self.accrual_type is None: + return default + else: + return self.final_bal.le_zero() -class BaseReport: - def __init__(self, out_file: TextIO) -> None: - self.out_file = out_file - self.logger = logger.getChild(type(self).__name__) + def is_zero(self, default: Optional[bool]=None) -> Optional[bool]: + if self.accrual_type is None: + return default + else: + return self.final_bal.is_zero() - def _since_last_nonzero(self, posts: AccrualPostings) -> AccrualPostings: - for index, (post, balance) in enumerate(posts.iter_with_balance()): + def since_last_nonzero(self) -> 'AccrualPostings': + for index, (post, balance) in enumerate(self.iter_with_balance()): if balance.is_zero(): start_index = index try: empty = start_index == index except NameError: empty = True - return posts if empty else AccrualPostings(posts[start_index + 1:]) + return self if empty else self[start_index + 1:] + + +class BaseReport: + def __init__(self, out_file: TextIO) -> None: + self.out_file = out_file + self.logger = logger.getChild(type(self).__name__) def _report(self, invoice: str, @@ -267,7 +274,7 @@ class BalanceReport(BaseReport): posts: AccrualPostings, index: int, ) -> Iterable[str]: - posts = self._since_last_nonzero(posts) + posts = posts.since_last_nonzero() balance = posts.balance() date_s = posts[0].meta.date.strftime('%Y-%m-%d') if index: @@ -298,7 +305,7 @@ class OutgoingReport(BaseReport): posts: AccrualPostings, index: int, ) -> Iterable[str]: - posts = self._since_last_nonzero(posts) + posts = posts.since_last_nonzero() try: ticket_id, _ = self._primary_rt_id(posts) ticket = self.rt_client.get_ticket(ticket_id) @@ -329,13 +336,12 @@ class OutgoingReport(BaseReport): ) requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip() - raw_balance = -posts.balance() cost_balance = -posts.balance_at_cost() cost_balance_s = cost_balance.format(None) - if raw_balance == cost_balance: + if posts.final_bal == cost_balance: balance_s = cost_balance_s else: - balance_s = f'{raw_balance} ({cost_balance_s})' + balance_s = f'{posts.final_bal} ({cost_balance_s})' contract_links = posts.all_meta_links('contract') if contract_links: @@ -382,8 +388,8 @@ class ReportType(enum.Enum): @classmethod def default_for(cls, groups: PostGroups) -> 'ReportType': if len(groups) == 1 and all( - AccrualAccount.classify(group) is AccrualAccount.PAYABLE - and not AccrualAccount.PAYABLE.value.balance_paid(group.balance()) + group.accrual_type is AccrualAccount.PAYABLE + and not group.is_paid() for group in groups.values() ): return cls.OUTGOING @@ -501,7 +507,7 @@ def main(arglist: Optional[Sequence[str]]=None, filters.remove_opening_balance_txn(entries) postings = filter_search(data.Posting.from_entries(entries), args.search_terms) groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice')) - groups = AccrualAccount.filter_paid_accruals(groups) or groups + groups = {key: group for key, group in groups.items() if not group.is_paid()} or groups returncode = 0 for error in load_errors: bc_printer.print_error(error, file=stderr)