diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py
index 3a0fbbebb918f76973158cc202919a1949be92ed..25a872b3b7e0ea90e70d2cb276196e8c17b8152c 100644
--- a/conservancy_beancount/reports/core.py
+++ b/conservancy_beancount/reports/core.py
@@ -22,6 +22,7 @@ from .. import data
 
 from typing import (
     overload,
+    DefaultDict,
     Dict,
     Iterable,
     Iterator,
@@ -93,19 +94,36 @@ class RelatedPostings(Sequence[data.Posting]):
     entirely up to the caller.
 
     A common pattern is to use this class with collections.defaultdict
-    to organize postings based on some key::
-
-        report = collections.defaultdict(RelatedPostings)
-        for txn in transactions:
-            for post in Posting.from_txn(txn):
-                if should_report(post):
-                    key = post_key(post)
-                    report[key].add(post)
+    to organize postings based on some key. See the group_by_meta classmethod
+    for an example.
     """
 
     def __init__(self) -> None:
         self._postings: List[data.Posting] = []
 
+    @classmethod
+    def group_by_meta(cls,
+                      postings: Iterable[data.Posting],
+                      key: MetaKey,
+                      default: Optional[MetaValue]=None,
+    ) -> Mapping[Optional[MetaValue], 'RelatedPostings']:
+        """Relate postings by metadata value
+
+        This method takes an iterable of postings and returns a mapping.
+        The keys of the mapping are the values of post.meta.get(key, default).
+        The values are RelatedPostings instances that contain all the postings
+        that had that same metadata value.
+
+        Looking up a metadata value that no posting had raises KeyError;
+        the returned mapping does not create new groups on access.
+        """
+        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
+        for post in postings:
+            retval[post.meta.get(key, default)].add(post)
+        # Turn off auto-creation so callers get a KeyError for missing keys.
+        retval.default_factory = None
+        return retval
+
     @overload
     def __getitem__(self, index: int) -> data.Posting: ...
 
diff --git a/tests/test_reports_related_postings.py b/tests/test_reports_related_postings.py
index 1848c0f29729b144502dec6dbb0d5b2fc691315a..ba0066ddfca9a4144d80fd27261e1208605711ca 100644
--- a/tests/test_reports_related_postings.py
+++ b/tests/test_reports_related_postings.py
@@ -30,9 +30,10 @@ from conservancy_beancount.reports import core
 def accruals_and_payments(acct, src_acct, dst_acct, start_date, *amounts):
     dates = testutil.date_seq(start_date)
     for amt, currency in amounts:
+        post_meta = {'metanumber': amt, 'metacurrency': currency}
         yield testutil.Transaction(date=next(dates), postings=[
-            (acct, amt, currency),
-            (dst_acct if amt < 0 else src_acct, -amt, currency),
+            (acct, amt, currency, post_meta),
+            (dst_acct if amt < 0 else src_acct, -amt, currency, post_meta),
         ])
 
 @pytest.fixture
@@ -174,3 +175,33 @@ def test_meta_values_many_types():
     for index, value in enumerate(expected):
         related.add(testutil.Posting('Income:Donations', -index, key=value))
     assert related.meta_values('key') == expected
+
+def test_group_by_meta_zero():
+    assert len(core.RelatedPostings.group_by_meta([], 'metacurrency')) == 0
+
+def test_group_by_meta_key_error():
+    # Make sure the return value doesn't act like a defaultdict.
+    with pytest.raises(KeyError):
+        core.RelatedPostings.group_by_meta([], 'metakey')['metavalue']
+
+def test_group_by_meta_one(credit_card_cycle):
+    posting = next(post for post in data.Posting.from_entries(credit_card_cycle)
+                   if post.account.is_credit_card())
+    actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency')
+    assert set(actual) == {'USD'}
+
+def test_group_by_meta_many(two_accruals_three_payments):
+    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
+                if post.account == 'Assets:Receivable:Accounts']
+    actual = core.RelatedPostings.group_by_meta(postings, 'metacurrency')
+    assert set(actual) == {'USD', 'EUR'}
+    for group in actual.values():
+        assert 2 <= len(group) <= 3
+        assert group.balance().is_zero()
+
+def test_group_by_meta_many_single_posts(two_accruals_three_payments):
+    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
+                if post.account == 'Assets:Receivable:Accounts']
+    actual = core.RelatedPostings.group_by_meta(postings, 'metanumber')
+    assert set(actual) == {post.units.number for post in postings}
+    assert len(actual) == len(postings)