diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index 75881069cd22aaf1fdd3105b9a39a32f03166cc5..086ac09b019ba7a00c3e3cdee0a3511222403d87 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -157,16 +157,14 @@ class PostingContext: class MetaDocs(bc_query_env.AnyMeta): """Return a list of document links from metadata.""" def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None: - super(bc_query_env.AnyMeta, self).__init__(operands, list) + super(bc_query_env.AnyMeta, self).__init__(operands, set) # The second argument is our return type. # It should match the annotated return type of __call__. - def __call__(self, context: PostingContext) -> List[str]: + def __call__(self, context: PostingContext) -> Set[str]: raw_value = super().__call__(context) - if isinstance(raw_value, str): - return raw_value.split() - else: - return [] + seq = raw_value.split() if isinstance(raw_value, str) else '' + return set(seq) class RTField(NamedTuple): @@ -247,7 +245,7 @@ class RTTicket(bc_query_compile.EvalFunction): self._meta_key(meta_op.value) if not rest: operands.append(bc_query_compile.EvalConstant(sys.maxsize)) - super().__init__(operands, list) + super().__init__(operands, set) def _rt_key(self, key: str) -> RTField: try: @@ -261,7 +259,7 @@ class RTTicket(bc_query_compile.EvalFunction): else: raise ValueError(f"metadata key {key!r} does not contain documentation links") - def __call__(self, context: PostingContext) -> list: + def __call__(self, context: PostingContext) -> Set[object]: rt_key: str meta_key: str limit: int @@ -283,7 +281,7 @@ class RTTicket(bc_query_compile.EvalFunction): ticket_ids.add(rt_id[0]) if len(ticket_ids) >= limit: break - retval: List[object] = [] + retval: Set[object] = set() for ticket_id in ticket_ids: try: rt_ticket = self._rt_cache[ticket_id] @@ -294,9 +292,9 @@ class RTTicket(bc_query_compile.EvalFunction): if field_value is None: pass elif isinstance(field_value, list): - retval.extend(field_value) + retval.update(field_value) else: - retval.append(field_value) + retval.add(field_value) return retval diff --git a/tests/test_reports_query.py b/tests/test_reports_query.py index fa5270b27bfa19e7a39a0d6b6c113a8b2c0f9ea2..f3ba5947de28d183f993e3317cc397da6091fdfd 100644 --- a/tests/test_reports_query.py +++ b/tests/test_reports_query.py @@ -86,10 +86,10 @@ def test_rt_ticket_bad_metadata(ticket_query, meta_name): ticket_query(const_operands('id', meta_name)) @pytest.mark.parametrize('field_name,meta_name,expected', [ - ('id', 'rt-id', 1), - ('Queue', 'approval', 'general'), - ('Requestors', 'invoice', ['mx1@example.org', 'requestor2@example.org']), - ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)), + ('id', 'rt-id', {1}), + ('Queue', 'approval', {'general'}), + ('Requestors', 'invoice', {'mx1@example.org', 'requestor2@example.org'}), + ('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)}), ]) def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected): func = ticket_query(const_operands(field_name, meta_name)) @@ -97,15 +97,13 @@ def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected): ('Assets:Cash', 80), ]) context = RowContext(txn, txn.postings[0]) - if not isinstance(expected, list): - expected = [expected] assert func(context) == expected @pytest.mark.parametrize('field_name,meta_name,expected', [ - ('id', 'rt-id', 2), - ('Queue', 'approval', 'general'), - ('Requestors', 'invoice', ['mx2@example.org', 'requestor2@example.org']), - ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)), + ('id', 'rt-id', {2}), + ('Queue', 'approval', {'general'}), + ('Requestors', 'invoice', {'mx2@example.org', 'requestor2@example.org'}), + ('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)}), ]) def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected): func = ticket_query(const_operands(field_name, meta_name)) @@ -113,19 +111,16 @@ def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected): ('Assets:Cash', 110, {meta_name: 'rt:2/8'}), ]) context = RowContext(txn, txn.postings[0]) - if not isinstance(expected, list): - expected = [expected] assert func(context) == expected @pytest.mark.parametrize('field_name,meta_name,expected,on_txn', [ - ('id', 'approval', [1, 2], True), - ('Queue', 'check', ['general', 'general'], False), - ('Requestors', 'invoice', [ + ('id', 'approval', {1, 2}, True), + ('Queue', 'check', {'general'}, False), + ('Requestors', 'invoice', { 'mx1@example.org', 'mx2@example.org', 'requestor2@example.org', - 'requestor2@example.org', - ], False), + }, False), ]) def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, on_txn): func = ticket_query(const_operands(field_name, meta_name)) @@ -136,7 +131,7 @@ def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, meta = txn.meta if on_txn else post.meta meta[meta_name] = 'rt:1/2 Docs/12.pdf rt:2/8' context = RowContext(txn, post) - assert sorted(func(context)) == expected + assert func(context) == expected @pytest.mark.parametrize('meta_value,on_txn', testutil.combine_values( ['', 'Docs/34.pdf', 'Docs/100.pdf Docs/120.pdf'], @@ -151,7 +146,7 @@ def test_rt_ticket_no_results(ticket_query, meta_value, on_txn): meta = txn.meta if on_txn else post.meta meta['check'] = meta_value context = RowContext(txn, post) - assert func(context) == [] + assert func(context) == set() def test_rt_ticket_caches_tickets(): rt_client = testutil.RTClient() @@ -162,9 +157,9 @@ def test_rt_ticket_caches_tickets(): ('Assets:Cash', 160, {'rt-id': 'rt:3'}), ]) context = RowContext(txn, txn.postings[0]) - assert func(context) == [3] + assert func(context) == {3} del rt_client.TICKET_DATA['3'] - assert func(context) == [3] + assert func(context) == {3} def test_rt_ticket_caches_tickets_not_found(): rt_client = testutil.RTClient() @@ -176,9 +171,9 @@ def test_rt_ticket_caches_tickets_not_found(): ('Assets:Cash', 160, {'rt-id': 'rt:3'}), ]) context = RowContext(txn, txn.postings[0]) - assert func(context) == [] + assert func(context) == set() rt_client.TICKET_DATA['3'] = rt3 - assert func(context) == [] + assert func(context) == set() def test_books_loader_empty(): result = qmod.BooksLoader(None)()