Changeset - ef03893bfed3
[Not reviewed]
0 2 0
Brett Smith - 3 years ago 2021-03-12 15:56:43
brettcsmith@brettcsmith.org
query: Convert query functions that return List to Set.

Beancount's built-in renderers expect this and are better equipped for it.
2 files changed with 27 insertions and 34 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/query.py
Show inline comments
...
 
@@ -158,14 +158,12 @@ class MetaDocs(bc_query_env.AnyMeta):
 
    """Return a list of document links from metadata."""
 
    def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None:
 
        super(bc_query_env.AnyMeta, self).__init__(operands, list)
 
        super(bc_query_env.AnyMeta, self).__init__(operands, set)
 
        # The second argument is our return type.
 
        # It should match the annotated return type of __call__.
 

	
 
    def __call__(self, context: PostingContext) -> List[str]:
 
    def __call__(self, context: PostingContext) -> Set[str]:
 
        raw_value = super().__call__(context)
 
        if isinstance(raw_value, str):
 
            return raw_value.split()
 
        else:
 
            return []
 
        seq = raw_value.split() if isinstance(raw_value, str) else ''
 
        return set(seq)
 

	
 

	
...
 
@@ -248,5 +246,5 @@ class RTTicket(bc_query_compile.EvalFunction):
 
        if not rest:
 
            operands.append(bc_query_compile.EvalConstant(sys.maxsize))
 
        super().__init__(operands, list)
 
        super().__init__(operands, set)
 

	
 
    def _rt_key(self, key: str) -> RTField:
...
 
@@ -262,5 +260,5 @@ class RTTicket(bc_query_compile.EvalFunction):
 
            raise ValueError(f"metadata key {key!r} does not contain documentation links")
 

	
 
    def __call__(self, context: PostingContext) -> list:
 
    def __call__(self, context: PostingContext) -> Set[object]:
 
        rt_key: str
 
        meta_key: str
...
 
@@ -284,5 +282,5 @@ class RTTicket(bc_query_compile.EvalFunction):
 
                if len(ticket_ids) >= limit:
 
                    break
 
        retval: List[object] = []
 
        retval: Set[object] = set()
 
        for ticket_id in ticket_ids:
 
            try:
...
 
@@ -295,7 +293,7 @@ class RTTicket(bc_query_compile.EvalFunction):
 
                pass
 
            elif isinstance(field_value, list):
 
                retval.extend(field_value)
 
                retval.update(field_value)
 
            else:
 
                retval.append(field_value)
 
                retval.add(field_value)
 
        return retval
 

	
tests/test_reports_query.py
Show inline comments
...
 
@@ -87,8 +87,8 @@ def test_rt_ticket_bad_metadata(ticket_query, meta_name):
 

	
 
@pytest.mark.parametrize('field_name,meta_name,expected', [
 
    ('id', 'rt-id', 1),
 
    ('Queue', 'approval', 'general'),
 
    ('Requestors', 'invoice', ['mx1@example.org', 'requestor2@example.org']),
 
    ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)),
 
    ('id', 'rt-id', {1}),
 
    ('Queue', 'approval', {'general'}),
 
    ('Requestors', 'invoice', {'mx1@example.org', 'requestor2@example.org'}),
 
    ('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)}),
 
])
 
def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected):
...
 
@@ -98,13 +98,11 @@ def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected):
 
    ])
 
    context = RowContext(txn, txn.postings[0])
 
    if not isinstance(expected, list):
 
        expected = [expected]
 
    assert func(context) == expected
 

	
 
@pytest.mark.parametrize('field_name,meta_name,expected', [
 
    ('id', 'rt-id', 2),
 
    ('Queue', 'approval', 'general'),
 
    ('Requestors', 'invoice', ['mx2@example.org', 'requestor2@example.org']),
 
    ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)),
 
    ('id', 'rt-id', {2}),
 
    ('Queue', 'approval', {'general'}),
 
    ('Requestors', 'invoice', {'mx2@example.org', 'requestor2@example.org'}),
 
    ('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)}),
 
])
 
def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected):
...
 
@@ -114,17 +112,14 @@ def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected):
 
    ])
 
    context = RowContext(txn, txn.postings[0])
 
    if not isinstance(expected, list):
 
        expected = [expected]
 
    assert func(context) == expected
 

	
 
@pytest.mark.parametrize('field_name,meta_name,expected,on_txn', [
 
    ('id', 'approval', [1, 2], True),
 
    ('Queue', 'check', ['general', 'general'], False),
 
    ('Requestors', 'invoice', [
 
    ('id', 'approval', {1, 2}, True),
 
    ('Queue', 'check', {'general'}, False),
 
    ('Requestors', 'invoice', {
 
        'mx1@example.org',
 
        'mx2@example.org',
 
        'requestor2@example.org',
 
        'requestor2@example.org',
 
    ], False),
 
    }, False),
 
])
 
def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, on_txn):
...
 
@@ -137,5 +132,5 @@ def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected,
 
    meta[meta_name] = 'rt:1/2 Docs/12.pdf rt:2/8'
 
    context = RowContext(txn, post)
 
    assert sorted(func(context)) == expected
 
    assert func(context) == expected
 

	
 
@pytest.mark.parametrize('meta_value,on_txn', testutil.combine_values(
...
 
@@ -152,5 +147,5 @@ def test_rt_ticket_no_results(ticket_query, meta_value, on_txn):
 
    meta['check'] = meta_value
 
    context = RowContext(txn, post)
 
    assert func(context) == []
 
    assert func(context) == set()
 

	
 
def test_rt_ticket_caches_tickets():
...
 
@@ -163,7 +158,7 @@ def test_rt_ticket_caches_tickets():
 
    ])
 
    context = RowContext(txn, txn.postings[0])
 
    assert func(context) == [3]
 
    assert func(context) == {3}
 
    del rt_client.TICKET_DATA['3']
 
    assert func(context) == [3]
 
    assert func(context) == {3}
 

	
 
def test_rt_ticket_caches_tickets_not_found():
...
 
@@ -177,7 +172,7 @@ def test_rt_ticket_caches_tickets_not_found():
 
    ])
 
    context = RowContext(txn, txn.postings[0])
 
    assert func(context) == []
 
    assert func(context) == set()
 
    rt_client.TICKET_DATA['3'] = rt3
 
    assert func(context) == []
 
    assert func(context) == set()
 

	
 
def test_books_loader_empty():
0 comments (0 inline, 0 general)