From 8b2683d96290014d85cb33dc6430614f47787dc3 2020-05-28 19:52:10 From: Brett Smith Date: 2020-05-28 19:52:10 Subject: [PATCH] accrual: Refactor reports into classes. Preparation for introducing the aging report. This helps us distinguish each report's setup requirements (different __init__ arguments). --- diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index 155cc12b2c4b5380842e8b43d7eb28b8cde8bc73..fe4ac6d4c514849bb0066cf3e7326e367fc1f3d1 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -96,10 +96,6 @@ from .. import filters from .. import rtutil PostGroups = Mapping[Optional[MetaValue], core.RelatedPostings] -ReportFunc = Callable[ - [PostGroups, TextIO, TextIO, Optional[rt.Rt], Optional[rtutil.RT]], - None -] RTObject = Mapping[str, str] class Account(NamedTuple): @@ -132,37 +128,161 @@ class AccrualAccount(enum.Enum): } -class ReportType: - NAMES: Set[str] = set() - BY_NAME: Dict[str, ReportFunc] = {} +class BaseReport: + def __init__(self, out_file: TextIO, err_file: TextIO) -> None: + self.out_file = out_file + self.err_file = err_file + + def _since_last_nonzero(self, posts: core.RelatedPostings) -> core.RelatedPostings: + retval = core.RelatedPostings() + for post in posts: + if retval.balance().is_zero(): + retval.clear() + retval.add(post) + return retval + + def _report(self, + invoice: str, + posts: core.RelatedPostings, + index: int, + ) -> Iterable[str]: + raise NotImplementedError("BaseReport._report") + + def run(self, groups: PostGroups) -> None: + for index, invoice in enumerate(groups): + for line in self._report(str(invoice), groups[invoice], index): + print(line, file=self.out_file) + + +class BalanceReport(BaseReport): + def _report(self, + invoice: str, + posts: core.RelatedPostings, + index: int, + ) -> Iterable[str]: + posts = self._since_last_nonzero(posts) + balance = posts.balance() + date_s = posts[0].meta.date.strftime('%Y-%m-%d') + if index: + yield "" + yield f"{invoice}:" + yield f" {balance} outstanding since {date_s}" + + +class OutgoingReport(BaseReport): + def __init__(self, rt_client: rt.Rt, out_file: TextIO, err_file: TextIO) -> None: + self.rt_client = rt_client + self.rt_wrapper = rtutil.RT(rt_client) + self.out_file = out_file + self.err_file = err_file + + def _primary_rt_id(self, posts: core.RelatedPostings) -> rtutil.TicketAttachmentIds: + rt_ids = posts.all_meta_links('rt-id') + rt_ids_count = len(rt_ids) + if rt_ids_count != 1: + raise ValueError(f"{rt_ids_count} rt-id links found") + parsed = rtutil.RT.parse(rt_ids.pop()) + if parsed is None: + raise ValueError("rt-id is not a valid RT reference") + else: + return parsed + + def _report(self, + invoice: str, + posts: core.RelatedPostings, + index: int, + ) -> Iterable[str]: + posts = self._since_last_nonzero(posts) + try: + ticket_id, _ = self._primary_rt_id(posts) + ticket = self.rt_client.get_ticket(ticket_id) + # Note we only use this when ticket is None. + errmsg = f"ticket {ticket_id} not found" + except (ValueError, rt.RtError) as error: + ticket = None + errmsg = error.args[0] + if ticket is None: + print("error: can't generate outgoings report for {}" + " because no RT ticket available: {}".format( + invoice, errmsg, + ), file=self.err_file) + return + + try: + rt_requestor = self.rt_client.get_user(ticket['Requestors'][0]) + except (IndexError, rt.RtError): + rt_requestor = None + if rt_requestor is None: + requestor = '' + requestor_name = '' + else: + requestor_name = ( + rt_requestor.get('RealName') + or ticket.get('CF.{payment-to}') + or '' + ) + requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip() + + raw_balance = -posts.balance() + cost_balance = -posts.balance_at_cost() + cost_balance_s = cost_balance.format(None) + if raw_balance == cost_balance: + balance_s = cost_balance_s + else: + balance_s = f'{raw_balance} ({cost_balance_s})' + + contract_links = posts.all_meta_links('contract') + if contract_links: + contract_s = ' , '.join(self.rt_wrapper.iter_urls( + contract_links, missing_fmt='', + )) + else: + contract_s = "NO CONTRACT GOVERNS THIS TRANSACTION" + projects = [v for v in posts.meta_values('project') + if isinstance(v, str)] + + yield "PAYMENT FOR APPROVAL:" + yield f"REQUESTOR: {requestor}" + yield f"TOTAL TO PAY: {balance_s}" + yield f"AGREEMENT: {contract_s}" + yield f"PAYMENT TO: {ticket.get('CF.{payment-to}') or requestor_name}" + yield f"PAYMENT METHOD: {ticket.get('CF.{payment-method}', '')}" + yield f"PROJECT: {', '.join(projects)}" + yield "\nBEANCOUNT ENTRIES:\n" + + last_txn: Optional[Transaction] = None + for post in posts: + txn = post.meta.txn + if txn is not last_txn: + last_txn = txn + txn = self.rt_wrapper.txn_with_urls(txn, '{}') + yield bc_printer.format_entry(txn) - @classmethod - def register(cls, *names: str) -> Callable[[ReportFunc], ReportFunc]: - def register_wrapper(func: ReportFunc) -> ReportFunc: - for name in names: - cls.BY_NAME[name] = func - cls.NAMES.add(names[0]) - return func - return register_wrapper + +class ReportType(enum.Enum): + BALANCE = BalanceReport + OUTGOING = OutgoingReport + BAL = BALANCE + OUT = OUTGOING + OUTGOINGS = OUTGOING @classmethod - def by_name(cls, name: str) -> ReportFunc: + def by_name(cls, name: str) -> 'ReportType': try: - return cls.BY_NAME[name.lower()] + return cls[name.upper()] except KeyError: raise ValueError(f"unknown report type {name!r}") from None @classmethod - def default_for(cls, groups: PostGroups) -> ReportFunc: + def default_for(cls, groups: PostGroups) -> 'ReportType': if len(groups) == 1 and all( AccrualAccount.classify(group) is AccrualAccount.PAYABLE and not AccrualAccount.PAYABLE.value.balance_paid(group.balance()) for group in groups.values() ): - report_name = 'outgoing' + return cls.OUTGOING else: - report_name = 'balance' - return cls.BY_NAME[report_name] + return cls.BALANCE class ReturnFlag(enum.IntFlag): @@ -217,123 +337,6 @@ def consistency_check(groups: PostGroups) -> Iterable[Error]: post.meta.txn, ) -def _since_last_nonzero(posts: core.RelatedPostings) -> core.RelatedPostings: - retval = core.RelatedPostings() - for post in posts: - if retval.balance().is_zero(): - retval.clear() - retval.add(post) - return retval - -@ReportType.register('balance', 'bal') -def balance_report(groups: PostGroups, - out_file: TextIO, - err_file: TextIO=sys.stderr, - rt_client: Optional[rt.Rt]=None, - rt_wrapper: Optional[rtutil.RT]=None, -) -> None: - prefix = '' - for invoice, related in groups.items(): - related = _since_last_nonzero(related) - balance = related.balance() - date_s = related[0].meta.date.strftime('%Y-%m-%d') - print( - f"{prefix}{invoice}:", - f" {balance} outstanding since {date_s}", - sep='\n', file=out_file, - ) - prefix = '\n' - -def _primary_rt_id(related: core.RelatedPostings) -> rtutil.TicketAttachmentIds: - rt_ids = related.all_meta_links('rt-id') - rt_ids_count = len(rt_ids) - if rt_ids_count != 1: - raise ValueError(f"{rt_ids_count} rt-id links found") - parsed = rtutil.RT.parse(rt_ids.pop()) - if parsed is None: - raise ValueError("rt-id is not a valid RT reference") - else: - return parsed - -@ReportType.register('outgoing', 'outgoings', 'out') -def outgoing_report(groups: PostGroups, - out_file: TextIO, - err_file: TextIO=sys.stderr, - rt_client: Optional[rt.Rt]=None, - rt_wrapper: Optional[rtutil.RT]=None, -) -> None: - if rt_client is None or rt_wrapper is None: - raise ValueError("RT client is required but not configured") - for invoice, related in groups.items(): - related = _since_last_nonzero(related) - try: - ticket_id, _ = _primary_rt_id(related) - ticket = rt_client.get_ticket(ticket_id) - # Note we only use this when ticket is None. - errmsg = f"ticket {ticket_id} not found" - except (ValueError, rt.RtError) as error: - ticket = None - errmsg = error.args[0] - if ticket is None: - print("error: can't generate outgoings report for {}" - " because no RT ticket available: {}".format( - invoice, errmsg, - ), file=err_file) - continue - - try: - rt_requestor = rt_client.get_user(ticket['Requestors'][0]) - except (IndexError, rt.RtError): - rt_requestor = None - if rt_requestor is None: - requestor = '' - requestor_name = '' - else: - requestor_name = ( - rt_requestor.get('RealName') - or ticket.get('CF.{payment-to}') - or '' - ) - requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip() - - raw_balance = -related.balance() - cost_balance = -related.balance_at_cost() - cost_balance_s = cost_balance.format(None) - if raw_balance == cost_balance: - balance_s = cost_balance_s - else: - balance_s = f'{raw_balance} ({cost_balance_s})' - - contract_links = related.all_meta_links('contract') - if contract_links: - contract_s = ' , '.join(rt_wrapper.iter_urls( - contract_links, missing_fmt='', - )) - else: - contract_s = "NO CONTRACT GOVERNS THIS TRANSACTION" - projects = [v for v in related.meta_values('project') - if isinstance(v, str)] - - print( - "PAYMENT FOR APPROVAL:", - f"REQUESTOR: {requestor}", - f"TOTAL TO PAY: {balance_s}", - f"AGREEMENT: {contract_s}", - f"PAYMENT TO: {ticket.get('CF.{payment-to}') or requestor_name}", - f"PAYMENT METHOD: {ticket.get('CF.{payment-method}', '')}", - f"PROJECT: {', '.join(projects)}", - "\nBEANCOUNT ENTRIES:\n", - sep='\n', file=out_file, - ) - - last_txn: Optional[Transaction] = None - for post in related: - txn = post.meta.txn - if txn is not last_txn: - last_txn = txn - txn = rt_wrapper.txn_with_urls(txn, '{}') - bc_printer.print_entry(txn, file=out_file) - def filter_search(postings: Iterable[data.Posting], search_terms: Iterable[SearchTerm], ) -> Iterable[data.Posting]: @@ -411,21 +414,22 @@ def main(arglist: Optional[Sequence[str]]=None, if not groups: print("warning: no matching entries found to report", file=stderr) returncode |= ReturnFlag.NOTHING_TO_REPORT - else: - try: - args.report_type( - groups, - stdout, - stderr, - config.rt_client(), - config.rt_wrapper(), + report: Optional[BaseReport] = None + if args.report_type is ReportType.OUTGOING: + rt_client = config.rt_client() + if rt_client is None: + print( + "error: unable to generate outgoing report: RT client is required", + file=stderr, ) - except ValueError as exc: - print("error: unable to generate {}: {}".format( - args.report_type.__name__.replace('_', ' '), - exc.args[0], - ), file=stderr) - returncode |= ReturnFlag.REPORT_ERRORS + else: + report = OutgoingReport(rt_client, stdout, stderr) + else: + report = args.report_type.value(stdout, stderr) + if report is None: + returncode |= ReturnFlag.REPORT_ERRORS + else: + report.run(groups) return 0 if returncode == 0 else 16 + returncode if __name__ == '__main__': diff --git a/tests/test_reports_accrual.py b/tests/test_reports_accrual.py index 0e661f77563df7daf2b091f781dd3a084a8b2a45..1af316d940ba6dca473628ab4a9ba979b9e091a6 100644 --- a/tests/test_reports_accrual.py +++ b/tests/test_reports_accrual.py @@ -178,16 +178,16 @@ def test_filter_search(accrual_postings, search_terms, expect_count, check_func) assert check_func(post) @pytest.mark.parametrize('arg,expected', [ - ('balance', accrual.balance_report), - ('outgoing', accrual.outgoing_report), - ('bal', accrual.balance_report), - ('out', accrual.outgoing_report), - ('outgoings', accrual.outgoing_report), + ('balance', accrual.BalanceReport), + ('outgoing', accrual.OutgoingReport), + ('bal', accrual.BalanceReport), + ('out', accrual.OutgoingReport), + ('outgoings', accrual.OutgoingReport), ]) def test_report_type_by_name(arg, expected): - assert accrual.ReportType.by_name(arg.lower()) is expected - assert accrual.ReportType.by_name(arg.title()) is expected - assert accrual.ReportType.by_name(arg.upper()) is expected + assert accrual.ReportType.by_name(arg.lower()).value is expected + assert accrual.ReportType.by_name(arg.title()).value is expected + assert accrual.ReportType.by_name(arg.upper()).value is expected @pytest.mark.parametrize('arg', [ 'unknown', @@ -260,8 +260,8 @@ def run_outgoing(invoice, postings, rt_client=None): postings = relate_accruals_by_meta(postings, invoice) output = io.StringIO() errors = io.StringIO() - rt_cache = rtutil.RT(rt_client) - accrual.outgoing_report({invoice: postings}, output, errors, rt_client, rt_cache) + report = accrual.OutgoingReport(rt_client, output, errors) + report.run({invoice: postings}) return output, errors @pytest.mark.parametrize('invoice,expected', [ @@ -273,7 +273,10 @@ def run_outgoing(invoice, postings, rt_client=None): def test_balance_report(accrual_postings, invoice, expected): related = relate_accruals_by_meta(accrual_postings, invoice) output = io.StringIO() - accrual.balance_report({invoice: related}, output) + errors = io.StringIO() + report = accrual.BalanceReport(output, errors) + report.run({invoice: related}) + assert not errors.getvalue() check_output(output, [invoice, expected]) def test_outgoing_report(accrual_postings):