From b37d7a302407710d1ceff431d949aba2e816fab1 2020-06-03 22:51:38 From: Brett Smith Date: 2020-06-03 22:51:38 Subject: [PATCH] reports: Make RelatedPostings an immutable data structure. This was an early mistake, it makes data consistency mistakes too easy, and I only used it once so far in actual code. Going to fix this now so I can more safely build on top of this data structure. --- diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index 3e55727a6a6a6cf17e6ad01fa2c3ce38fbeeaceb..f80de7c48cc4631fa4b7b87d53822da7b36a26d7 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -63,6 +63,7 @@ import collections import datetime import enum import logging +import operator import re import sys @@ -72,6 +73,8 @@ from typing import ( Dict, Iterable, Iterator, + FrozenSet, + List, Mapping, NamedTuple, Optional, @@ -79,6 +82,7 @@ from typing import ( Set, TextIO, Tuple, + Union, ) from ..beancount_types import ( Error, @@ -100,11 +104,15 @@ from .. 
import rtutil PROGNAME = 'accrual-report' -PostGroups = Mapping[Optional[MetaValue], core.RelatedPostings] +PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings'] RTObject = Mapping[str, str] logger = logging.getLogger('conservancy_beancount.reports.accrual') +class Sentinel: + pass + + class Account(NamedTuple): name: str balance_paid: Callable[[core.Balance], bool] @@ -135,22 +143,95 @@ class AccrualAccount(enum.Enum): } +class AccrualPostings(core.RelatedPostings): + def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc] + def meta_getter(post: data.Posting) -> MetaValue: + return post.meta.get(key) + return meta_getter + + _FIELDS: Dict[str, Callable[[data.Posting], MetaValue]] = { + 'account': operator.attrgetter('account'), + 'contract': _meta_getter('contract'), + 'cost': operator.attrgetter('cost'), + 'entity': _meta_getter('entity'), + 'invoice': _meta_getter('invoice'), + 'purchase_order': _meta_getter('purchase-order'), + } + _INVOICE_COUNTER: Dict[str, int] = collections.defaultdict(int) + INCONSISTENT = Sentinel() + __slots__ = ( + 'accrual_type', + 'account', + 'accounts', + 'contract', + 'contracts', + 'cost', + 'costs', + 'entity', + 'entitys', + 'entities', + 'invoice', + 'invoices', + 'purchase_order', + 'purchase_orders', + ) + + def __init__(self, + source: Iterable[data.Posting]=(), + *, + _can_own: bool=False, + ) -> None: + super().__init__(source, _can_own=_can_own) + # The following type declarations tell mypy about values set in the for + # loop that are important enough to be referenced directly elsewhere. 
+ self.account: Union[data.Account, Sentinel] + self.entitys: FrozenSet[MetaValue] + self.invoice: Union[MetaValue, Sentinel] + for name, get_func in self._FIELDS.items(): + values = frozenset(get_func(post) for post in self) + setattr(self, f'{name}s', values) + if len(values) == 1: + one_value = next(iter(values)) + else: + one_value = self.INCONSISTENT + setattr(self, name, one_value) + # Correct spelling = bug prevention for future users of this class. + self.entities = self.entitys + if self.account is self.INCONSISTENT: + self.accrual_type: Optional[AccrualAccount] = None + else: + self.accrual_type = AccrualAccount.classify(self) + + def report_inconsistencies(self) -> Iterable[Error]: + for field_name, get_func in self._FIELDS.items(): + if getattr(self, field_name) is self.INCONSISTENT: + for post in self: + errmsg = 'inconsistent {} for invoice {}: {}'.format( + field_name.replace('_', '-'), + post.meta.get('invoice') or "", + get_func(post), + ) + yield Error(post.meta, errmsg, post.meta.txn) + + class BaseReport: def __init__(self, out_file: TextIO) -> None: self.out_file = out_file self.logger = logger.getChild(type(self).__name__) - def _since_last_nonzero(self, posts: core.RelatedPostings) -> core.RelatedPostings: - retval = core.RelatedPostings() - for post in posts: - if retval.balance().is_zero(): - retval.clear() - retval.add(post) - return retval + def _since_last_nonzero(self, posts: AccrualPostings) -> AccrualPostings: + # Postings after the last zero balance; all of them if it's never or last. + start_index: Optional[int] = None + for index, (_, balance) in enumerate(posts.iter_with_balance()): + if balance.is_zero(): + start_index = index + if start_index is None or start_index == len(posts) - 1: + return posts + return posts[start_index + 1:] def _report(self, invoice: str, - posts: core.RelatedPostings, + posts: AccrualPostings, index: int, ) -> Iterable[str]: raise NotImplementedError("BaseReport._report") @@ -164,7 +245,7 @@ class BaseReport: class BalanceReport(BaseReport): def _report(self, invoice: str, - posts:
core.RelatedPostings, + posts: AccrualPostings, index: int, ) -> Iterable[str]: posts = self._since_last_nonzero(posts) @@ -182,7 +263,7 @@ class OutgoingReport(BaseReport): self.rt_client = rt_client self.rt_wrapper = rtutil.RT(rt_client) - def _primary_rt_id(self, posts: core.RelatedPostings) -> rtutil.TicketAttachmentIds: + def _primary_rt_id(self, posts: AccrualPostings) -> rtutil.TicketAttachmentIds: rt_ids = posts.all_meta_links('rt-id') rt_ids_count = len(rt_ids) if rt_ids_count != 1: @@ -195,7 +276,7 @@ class OutgoingReport(BaseReport): def _report(self, invoice: str, - posts: core.RelatedPostings, + posts: AccrualPostings, index: int, ) -> Iterable[str]: posts = self._since_last_nonzero(posts) @@ -329,28 +410,6 @@ class SearchTerm(NamedTuple): ) return cls(key, pattern) -def _consistency_check_one_thing( - key: MetaValue, - related: core.RelatedPostings, - get_name: str, - get_func: Callable[[data.Posting], Any], -) -> Iterable[Error]: - values = {get_func(post) for post in related} - if len(values) != 1: - for post in related: - errmsg = f'inconsistent {get_name} for invoice {key}: {get_func(post)}' - yield Error(post.meta, errmsg, post.meta.txn) - -def consistency_check(groups: PostGroups) -> Iterable[Error]: - errfmt = 'inconsistent {} for invoice {}: {{}}' - for key, related in groups.items(): - yield from _consistency_check_one_thing( - key, related, 'cost', lambda post: post.cost, - ) - for checked_meta in ['contract', 'entity', 'purchase-order']: - yield from _consistency_check_one_thing( - key, related, checked_meta, lambda post: post.meta.get(checked_meta), - ) def filter_search(postings: Iterable[data.Posting], search_terms: Iterable[SearchTerm], @@ -421,16 +480,16 @@ def main(arglist: Optional[Sequence[str]]=None, } load_errors = [Error(source, "no books to load in configuration", None)] postings = filter_search(data.Posting.from_entries(entries), args.search_terms) - groups = core.RelatedPostings.group_by_meta(postings, 'invoice') + groups: 
PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice')) groups = AccrualAccount.filter_paid_accruals(groups) or groups - meta_errors = consistency_check(groups) returncode = 0 for error in load_errors: bc_printer.print_error(error, file=stderr) returncode |= ReturnFlag.LOAD_ERRORS - for error in meta_errors: - bc_printer.print_error(error, file=stderr) - returncode |= ReturnFlag.CONSISTENCY_ERRORS + for related in groups.values(): + for error in related.report_inconsistencies(): + bc_printer.print_error(error, file=stderr) + returncode |= ReturnFlag.CONSISTENCY_ERRORS if args.report_type is None: args.report_type = ReportType.default_for(groups) if not groups: diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index 8c8969ed4c62af957b12e4979cab34baee4d3538..dba92cd7b9dbe984506b1f40ee94230a41ea1b80 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -37,6 +37,8 @@ from typing import ( Sequence, Set, Tuple, + Type, + TypeVar, Union, ) from ..beancount_types import ( @@ -45,6 +47,7 @@ from ..beancount_types import ( ) DecimalCompat = data.DecimalCompat +RelatedType = TypeVar('RelatedType', bound='RelatedPostings') class Balance(Mapping[str, data.Amount]): """A collection of amounts mapped by currency @@ -162,15 +165,23 @@ class RelatedPostings(Sequence[data.Posting]): """ __slots__ = ('_postings',) - def __init__(self, source: Iterable[data.Posting]=()) -> None: - self._postings: List[data.Posting] = list(source) + def __init__(self, + source: Iterable[data.Posting]=(), + *, + _can_own: bool=False, + ) -> None: + self._postings: List[data.Posting] + if _can_own and isinstance(source, list): + self._postings = source + else: + self._postings = list(source) @classmethod - def group_by_meta(cls, + def group_by_meta(cls: Type[RelatedType], postings: Iterable[data.Posting], key: MetaKey, default: Optional[MetaValue]=None, - ) -> Mapping[Optional[MetaValue], 
'RelatedPostings']: + ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]: """Relate postings by metadata value This method takes an iterable of postings and returns a mapping. @@ -178,32 +189,29 @@ class RelatedPostings(Sequence[data.Posting]): The values are RelatedPostings instances that contain all the postings that had that same metadata value. """ - retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls) + mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list) for post in postings: - retval[post.meta.get(key, default)].add(post) - retval.default_factory = None - return retval + mapping[post.meta.get(key, default)].append(post) + for value, posts in mapping.items(): + yield value, cls(posts, _can_own=True) @overload - def __getitem__(self, index: int) -> data.Posting: ... + def __getitem__(self: RelatedType, index: int) -> data.Posting: ... @overload - def __getitem__(self, s: slice) -> Sequence[data.Posting]: ... + def __getitem__(self: RelatedType, s: slice) -> RelatedType: ... 
- def __getitem__(self, + def __getitem__(self: RelatedType, index: Union[int, slice], - ) -> Union[data.Posting, Sequence[data.Posting]]: + ) -> Union[data.Posting, RelatedType]: if isinstance(index, slice): - raise NotImplementedError("RelatedPostings[slice]") + return type(self)(self._postings[index], _can_own=True) else: return self._postings[index] def __len__(self) -> int: return len(self._postings) - def add(self, post: data.Posting) -> None: - self._postings.append(post) - def all_meta_links(self, key: MetaKey) -> Set[str]: retval: Set[str] = set() for post in self: @@ -213,9 +221,6 @@ class RelatedPostings(Sequence[data.Posting]): pass return retval - def clear(self) -> None: - self._postings.clear() - def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]: balance = MutableBalance() for post in self: diff --git a/tests/test_reports_accrual.py b/tests/test_reports_accrual.py index e880b0d33ffeb4de3b9b5f312c1db7cedf1c0036..8dd22f45fa960c95278136f98740f1c62de71a24 100644 --- a/tests/test_reports_accrual.py +++ b/tests/test_reports_accrual.py @@ -94,8 +94,8 @@ def check_link_regexp(regexp, match_s, first_link_only=False): else: assert end_match -def relate_accruals_by_meta(postings, value, key='invoice'): - return core.RelatedPostings( +def accruals_by_meta(postings, value, key='invoice', wrap_type=iter): + return wrap_type( post for post in postings if post.meta.get(key) == value and post.account.is_under('Assets:Receivable', 'Liabilities:Payable') @@ -200,22 +200,107 @@ def test_report_type_by_unknown_name(arg): with pytest.raises(ValueError): accrual.ReportType.by_name(arg) +@pytest.mark.parametrize('acct_name', ACCOUNTS) +def test_accrual_postings_consistent_account(acct_name): + meta = {'invoice': f'{acct_name} invoice.pdf'} + txn = testutil.Transaction(postings=[ + (acct_name, 50, meta), + (acct_name, 25, meta), + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert related.account == acct_name + assert
related.accounts == {acct_name} + +@pytest.mark.parametrize('cost', [ + testutil.Cost('1.2', 'USD'), + None, +]) +def test_accrual_postings_consistent_cost(cost): + meta = {'invoice': 'FXinvoice.pdf'} + txn = testutil.Transaction(postings=[ + (ACCOUNTS[0], 60, 'EUR', cost, meta), + (ACCOUNTS[0], 30, 'EUR', cost, meta), + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert related.cost == cost + assert related.costs == {cost} + +@pytest.mark.parametrize('meta_key,acct_name', testutil.combine_values( + CONSISTENT_METADATA, + ACCOUNTS, +)) +def test_accrual_postings_consistent_metadata(meta_key, acct_name): + meta_value = f'{meta_key}.pdf' + meta = { + meta_key: meta_value, + 'invoice': f'invoice with {meta_key}.pdf', + } + txn = testutil.Transaction(postings=[ + (acct_name, 70, meta), + (acct_name, 35, meta), + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + attr_name = meta_key.replace('-', '_') + assert getattr(related, attr_name) == meta_value + assert getattr(related, f'{attr_name}s') == {meta_value} + +def test_accrual_postings_inconsistent_account(): + meta = {'invoice': 'invoice.pdf'} + txn = testutil.Transaction(postings=[ + (acct_name, index, meta) + for index, acct_name in enumerate(ACCOUNTS) + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert related.account is related.INCONSISTENT + assert related.accounts == set(ACCOUNTS) + +def test_accrual_postings_inconsistent_cost(): + meta = {'invoice': 'FXinvoice.pdf'} + costs = { + testutil.Cost('1.1', 'USD'), + testutil.Cost('1.2', 'USD'), + } + txn = testutil.Transaction(postings=[ + (ACCOUNTS[0], 10, 'EUR', cost, meta) + for cost in costs + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert related.cost is related.INCONSISTENT + assert related.costs == costs + +@pytest.mark.parametrize('meta_key,acct_name', testutil.combine_values( + CONSISTENT_METADATA, + ACCOUNTS, +)) +def 
test_accrual_postings_inconsistent_metadata(meta_key, acct_name): + invoice = f'invoice with {meta_key}.pdf' + meta_value = f'{meta_key}.pdf' + txn = testutil.Transaction(postings=[ + (acct_name, 20, {'invoice': invoice, meta_key: meta_value}), + (acct_name, 35, {'invoice': invoice}), + ]) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + attr_name = meta_key.replace('-', '_') + assert getattr(related, attr_name) is related.INCONSISTENT + assert getattr(related, f'{attr_name}s') == {meta_value, None} + @pytest.mark.parametrize('meta_key,account', testutil.combine_values( CONSISTENT_METADATA, ACCOUNTS, )) def test_consistency_check_when_consistent(meta_key, account): invoice = f'test-{meta_key}-invoice' + meta_value = f'test-{meta_key}-value' meta = { 'invoice': invoice, - meta_key: f'test-{meta_key}-value', + meta_key: meta_value, } txn = testutil.Transaction(postings=[ (account, 100, meta), (account, -100, meta), ]) - related = core.RelatedPostings(data.Posting.from_txn(txn)) - assert not list(accrual.consistency_check({invoice: related})) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert not list(related.report_inconsistencies()) @pytest.mark.parametrize('meta_key,account', testutil.combine_values( ['approval', 'fx-rate', 'statement'], @@ -227,8 +312,8 @@ def test_consistency_check_ignored_metadata(meta_key, account): (account, 100, {'invoice': invoice, meta_key: 'credit'}), (account, -100, {'invoice': invoice, meta_key: 'debit'}), ]) - related = core.RelatedPostings(data.Posting.from_txn(txn)) - assert not list(accrual.consistency_check({invoice: related})) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + assert not list(related.report_inconsistencies()) @pytest.mark.parametrize('meta_key,account', testutil.combine_values( CONSISTENT_METADATA, @@ -240,8 +325,8 @@ def test_consistency_check_when_inconsistent(meta_key, account): (account, 100, {'invoice': invoice, meta_key: 'credit', 'lineno': 1}), (account,
-100, {'invoice': invoice, meta_key: 'debit', 'lineno': 2}), ]) - related = core.RelatedPostings(data.Posting.from_txn(txn)) - errors = list(accrual.consistency_check({invoice: related})) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + errors = list(related.report_inconsistencies()) for exp_lineno, (actual, exp_msg) in enumerate(itertools.zip_longest(errors, [ f'inconsistent {meta_key} for invoice {invoice}: credit', f'inconsistent {meta_key} for invoice {invoice}: debit', @@ -257,8 +342,8 @@ def test_consistency_check_cost(): (account, 100, 'EUR', ('1.1251', 'USD'), {'invoice': invoice, 'lineno': 1}), (account, -100, 'EUR', ('1.125', 'USD'), {'invoice': invoice, 'lineno': 2}), ]) - related = core.RelatedPostings(data.Posting.from_txn(txn)) - errors = list(accrual.consistency_check({invoice: related})) + related = accrual.AccrualPostings(data.Posting.from_txn(txn)) + errors = list(related.report_inconsistencies()) for post, err in itertools.zip_longest(txn.postings, errors): assert err.message == f'inconsistent cost for invoice {invoice}: {post.cost}' assert err.entry is txn @@ -272,7 +357,7 @@ def run_outgoing(invoice, postings, rt_client=None): if rt_client is None: rt_client = RTClient() if not isinstance(postings, core.RelatedPostings): - postings = relate_accruals_by_meta(postings, invoice) + postings = accruals_by_meta(postings, invoice, wrap_type=accrual.AccrualPostings) output = io.StringIO() report = accrual.OutgoingReport(rt_client, output) report.run({invoice: postings}) @@ -285,7 +370,7 @@ def run_outgoing(invoice, postings, rt_client=None): ('rt://ticket/515/attachments/5150', "1,500.00 USD outstanding since 2020-05-15",), ]) def test_balance_report(accrual_postings, invoice, expected, caplog): - related = relate_accruals_by_meta(accrual_postings, invoice) + related = accruals_by_meta(accrual_postings, invoice, wrap_type=accrual.AccrualPostings) output = io.StringIO() report = accrual.BalanceReport(output) report.run({invoice: 
related}) diff --git a/tests/test_reports_related_postings.py b/tests/test_reports_related_postings.py index 13d8ee90fa53d15d560c4380961f879d237e022f..6707a2414ed8626cfc0375695511341368750b31 100644 --- a/tests/test_reports_related_postings.py +++ b/tests/test_reports_related_postings.py @@ -80,42 +80,27 @@ def test_balance_empty(): assert not balance assert balance.is_zero() -def test_balance_credit_card(credit_card_cycle): - related = core.RelatedPostings() - assert related.balance() == testutil.balance_map() - expected = Decimal() - for txn in credit_card_cycle: - post = txn.postings[0] - expected += post.units.number - related.add(post) - assert related.balance() == testutil.balance_map(USD=expected) - assert expected == 0 - -def test_clear_after_add(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -10)) - assert related.balance() - related.clear() - assert not related.balance() - -def test_clear_after_initialization(): - related = core.RelatedPostings([ - testutil.Posting('Income:Donations', -12), - ]) - assert related.balance() - related.clear() - assert not related.balance() +@pytest.mark.parametrize('index,expected', enumerate([ + -110, + 0, + -120, + 0, +])) +def test_balance_credit_card(credit_card_cycle, index, expected): + related = core.RelatedPostings( + txn.postings[0] for txn in credit_card_cycle[:index + 1] + ) + assert related.balance() == testutil.balance_map(USD=expected) def check_iter_with_balance(entries): expect_posts = [txn.postings[0] for txn in entries] expect_balances = [] balance_tally = collections.defaultdict(Decimal) - related = core.RelatedPostings() for post in expect_posts: number, currency = post.units balance_tally[currency] += number expect_balances.append(testutil.balance_map(balance_tally.items())) - related.add(post) + related = core.RelatedPostings(expect_posts) for (post, balance), exp_post, exp_balance in zip( related.iter_with_balance(), expect_posts, @@ -195,48 +180,56 @@ def 
test_meta_values_empty(): assert related.meta_values('key') == set() def test_meta_values_no_match(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, metakey='metavalue')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, metakey='metavalue'), + ]) assert related.meta_values('key') == {None} def test_meta_values_no_match_default_given(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, metakey='metavalue')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, metakey='metavalue'), + ]) assert related.meta_values('key', '') == {''} def test_meta_values_one_match(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='metavalue')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='metavalue'), + ]) assert related.meta_values('key') == {'metavalue'} def test_meta_values_some_match(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='1')) - related.add(testutil.Posting('Income:Donations', -2, metakey='2')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='1'), + testutil.Posting('Income:Donations', -2, metakey='2'), + ]) assert related.meta_values('key') == {'1', None} def test_meta_values_some_match_default_given(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='1')) - related.add(testutil.Posting('Income:Donations', -2, metakey='2')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='1'), + testutil.Posting('Income:Donations', -2, metakey='2'), + ]) assert related.meta_values('key', '') == {'1', ''} def test_meta_values_all_match(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='1')) - related.add(testutil.Posting('Income:Donations', -2, key='2')) + related = 
core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='1'), + testutil.Posting('Income:Donations', -2, key='2'), + ]) assert related.meta_values('key') == {'1', '2'} def test_meta_values_all_match_one_value(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='1')) - related.add(testutil.Posting('Income:Donations', -2, key='1')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='1'), + testutil.Posting('Income:Donations', -2, key='1'), + ]) assert related.meta_values('key') == {'1'} def test_meta_values_all_match_default_given(): - related = core.RelatedPostings() - related.add(testutil.Posting('Income:Donations', -1, key='1')) - related.add(testutil.Posting('Income:Donations', -2, key='2')) + related = core.RelatedPostings([ + testutil.Posting('Income:Donations', -1, key='1'), + testutil.Posting('Income:Donations', -2, key='2'), + ]) assert related.meta_values('key', '') == {'1', '2'} def test_meta_values_many_types(): @@ -246,9 +239,10 @@ def test_meta_values_many_types(): testutil.Amount(5), 'rt:42', } - related = core.RelatedPostings() - for index, value in enumerate(expected): - related.add(testutil.Posting('Income:Donations', -index, key=value)) + related = core.RelatedPostings( + testutil.Posting('Income:Donations', -index, key=value) + for index, value in enumerate(expected) + ) assert related.meta_values('key') == expected @pytest.mark.parametrize('count', range(3)) @@ -289,23 +283,18 @@ def test_all_meta_links_multiples(): assert related.all_meta_links('approval') == testutil.LINK_METADATA_STRINGS def test_group_by_meta_zero(): - assert len(core.RelatedPostings.group_by_meta([], 'metacurrency')) == 0 - -def test_group_by_meta_key_error(): - # Make sure the return value doesn't act like a defaultdict. 
- with pytest.raises(KeyError): - core.RelatedPostings.group_by_meta([], 'metakey')['metavalue'] + assert not list(core.RelatedPostings.group_by_meta([], 'metacurrency')) def test_group_by_meta_one(credit_card_cycle): posting = next(post for post in data.Posting.from_entries(credit_card_cycle) if post.account.is_credit_card()) actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency') - assert set(actual) == {'USD'} + assert set(key for key, _ in actual) == {'USD'} def test_group_by_meta_many(two_accruals_three_payments): postings = [post for post in data.Posting.from_entries(two_accruals_three_payments) if post.account == 'Assets:Receivable:Accounts'] - actual = core.RelatedPostings.group_by_meta(postings, 'metacurrency') + actual = dict(core.RelatedPostings.group_by_meta(postings, 'metacurrency')) assert set(actual) == {'USD', 'EUR'} for key, group in actual.items(): assert 2 <= len(group) <= 3 @@ -314,6 +303,6 @@ def test_group_by_meta_many(two_accruals_three_payments): def test_group_by_meta_many_single_posts(two_accruals_three_payments): postings = [post for post in data.Posting.from_entries(two_accruals_three_payments) if post.account == 'Assets:Receivable:Accounts'] - actual = core.RelatedPostings.group_by_meta(postings, 'metanumber') + actual = dict(core.RelatedPostings.group_by_meta(postings, 'metanumber')) assert set(actual) == {post.units.number for post in postings} assert len(actual) == len(postings)