def sort_and_filter_accounts(
        accounts: Iterable[data.Account],
        order: Sequence[str],
) -> Iterator[Tuple[int, data.Account]]:
    """Group accounts under an ordered list of account names, then sort them

    ``accounts`` is an iterable of Account objects. ``order`` is a sequence
    of account names, typically higher levels of the account hierarchy such
    as Income, Equity, or Assets:Receivable.

    The result is an iterator of ``(index, account)`` 2-tuples. ``account``
    is one of the input Account objects, and ``index`` is the position in
    ``order`` of the name that ``account.is_under(*order)`` reported for it.
    Tuples come out with ``index`` ascending, and accounts that share an
    index come out in their own sorted order.

    For example, given the order
    ``['Liabilities:Payable', 'Assets:Receivable']``, the iterator first
    yields zero or more tuples with index 0 and an account under
    Liabilities:Payable, then zero or more tuples with index 1 and an
    account under Assets:Receivable.

    Filtering: an input account for which ``is_under`` returns None (i.e.,
    it is not under any name in ``order``) is left out of the results.

    Conversely, a name in ``order`` that no input account falls under never
    has its index yielded. Returning indexes rather than the name strings is
    meant to make it easy for callers to notice those gaps and handle the
    unused ordering names themselves.

    The input iterable is consumed and grouped when this function is called;
    only the final sorted walk over the groups is lazy.
    """
    # Position of each name within ``order``. Should a name be repeated,
    # the later position replaces the earlier one.
    positions = {}
    for position, name in enumerate(order):
        positions[name] = position

    grouped: Mapping[int, List[data.Account]] = collections.defaultdict(list)
    for acct in accounts:
        parent = acct.is_under(*order)
        if parent is None:
            # Not under any requested name: filtered out.
            continue
        grouped[positions[parent]].append(acct)

    def _walk() -> Iterator[Tuple[int, data.Account]]:
        for position in sorted(grouped):
            for acct in sorted(grouped[position]):
                yield position, acct

    return _walk()
def test_sort_and_filter_accounts():
    """Accounts outside the ordering are dropped; the rest are grouped and sorted."""
    names = [
        'Expenses:Services',
        'Assets:Receivable',
        'Income:Other',
        'Liabilities:Payable',
        'Equity:Funds:Unrestricted',
        'Income:Donations',
        'Expenses:Other',
    ]
    expected = [
        (0, 'Equity:Funds:Unrestricted'),
        (1, 'Income:Donations'),
        (1, 'Income:Other'),
        (2, 'Expenses:Other'),
        (2, 'Expenses:Services'),
    ]
    # Pass a one-shot iterator, not a list, to exercise the Iterable contract.
    actual = core.sort_and_filter_accounts(
        map(Account, names), ['Equity', 'Income', 'Expenses'],
    )
    assert list(actual) == expected


def test_sort_and_filter_accounts_unused_name():
    """An ordering name with no matching account leaves a gap in the indexes."""
    names = [
        'Liabilities:CreditCard',
        'Assets:Cash',
        'Assets:Receivable:Accounts',
    ]
    order = ['Assets:Receivable', 'Liabilities:Payable', 'Assets', 'Liabilities']
    expected = [
        (0, 'Assets:Receivable:Accounts'),
        (2, 'Assets:Cash'),
        (3, 'Liabilities:CreditCard'),
    ]
    actual = core.sort_and_filter_accounts(map(Account, names), order)
    assert list(actual) == expected


def test_sort_and_filter_accounts_with_subaccounts():
    """With the more specific ordering name listed first, it claims its subaccounts."""
    names = [
        'Assets:Checking',
        'Assets:Receivable:Fraud',
        'Assets:Cash',
        'Assets:Receivable:Accounts',
    ]
    expected = [
        (0, 'Assets:Receivable:Accounts'),
        (0, 'Assets:Receivable:Fraud'),
        (1, 'Assets:Cash'),
        (1, 'Assets:Checking'),
    ]
    actual = core.sort_and_filter_accounts(
        map(Account, names), ['Assets:Receivable', 'Assets'],
    )
    assert list(actual) == expected


@pytest.mark.parametrize('empty_arg', ['accounts', 'order'])
def test_sort_and_filter_accounts_empty_accounts(empty_arg):
    """Leaving either argument empty produces an iterator with no results."""
    accounts = [Account(s) for s in ['Expenses:Other', 'Income:Other']]
    # When 'accounts' is the empty argument, the Account list stands in as
    # the ordering sequence instead.
    args = ([], accounts) if empty_arg == 'accounts' else (accounts, [])
    actual = core.sort_and_filter_accounts(*args)
    assert next(actual, None) is None