File diff 5231a1784f87 → ccbc447a353d
tests/test_reports_query.py
Show inline comments
 
new file 100644
 
"""test_reports_query.py - Unit tests for query report"""
 
# Copyright © 2021  Brett Smith
 
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
 
#
 
# Full copyright and licensing details can be found at toplevel file
 
# LICENSE.txt in the repository.
 

	
 
import argparse
 
import collections
 
import copy
 
import csv
 
import datetime
 
import io
 
import itertools
 
import re
 

	
 
import pytest
 

	
 
from . import testutil
 

	
 
from conservancy_beancount.books import FiscalYear
 
from conservancy_beancount.reports import query as qmod
 

	
 
from decimal import Decimal
 

	
 
class MockRewriteRuleset:
 
    def __init__(self, multiplier=2):
 
        self.multiplier = multiplier
 

	
 
    def rewrite(self, posts):
 
        for post in posts:
 
            number, currency = post.units
 
            number *= self.multiplier
 
            yield post._replace(units=testutil.Amount(number, currency))
 

	
 

	
 
@pytest.fixture(scope='module')
 
def fy():
 
    return FiscalYear(3, 1)
 

	
 
def pipe_main(arglist, config):
 
    stdout = io.StringIO()
 
    stderr = io.StringIO()
 
    returncode = qmod.main(arglist, stdout, stderr, config)
 
    return returncode, stdout, stderr
 

	
 
def query_args(query=None, start_date=None, stop_date=None, join='AND'):
 
    join = qmod.JoinOperator[join]
 
    return argparse.Namespace(**locals())
 

	
 
def test_books_loader_empty():
 
    result = qmod.BooksLoader(None)()
 
    assert not result.entries
 
    assert len(result.errors) == 1
 

	
 
def test_books_loader_plain():
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    loader = testutil.TestBooksLoader(books_path)
 
    result = qmod.BooksLoader(loader)()
 
    assert not result.errors
 
    assert result.entries
 
    min_date = datetime.date(2018, 3, 1)
 
    assert all(ent.date >= min_date for ent in result.entries)
 

	
 
def test_books_loader_rewrites():
 
    rewrites = [MockRewriteRuleset()]
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    loader = testutil.TestBooksLoader(books_path)
 
    result = qmod.BooksLoader(loader, None, None, rewrites)()
 
    assert not result.errors
 
    assert result.entries
 
    numbers = frozenset(
 
        abs(post.units.number)
 
        for entry in result.entries
 
        for post in getattr(entry, 'postings', ())
 
    )
 
    assert numbers
 
    assert all(abs(number) >= 40 for number in numbers)
 

	
 
@pytest.mark.parametrize('file_s', [None, '', ' \n \n\n'])
 
def test_build_query_empty(fy, file_s):
 
    args = query_args()
 
    if file_s is None:
 
        query = qmod.build_query(args, fy)
 
    else:
 
        with io.StringIO(file_s) as qfile:
 
            query = qmod.build_query(args, fy, qfile)
 
    assert query is None
 

	
 
@pytest.mark.parametrize('query_str', [
 
    'SELECT * WHERE date >= 2018-03-01',
 
    'select *',
 
    'JOURNAL "Income:Donations"',
 
    'journal',
 
    'BALANCES FROM year=2018',
 
    'balances',
 
])
 
def test_build_query_in_arglist(fy, query_str):
 
    args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
 
    assert qmod.build_query(args, fy) == query_str
 

	
 
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
 
def test_build_query_where_arglist_conditions(fy, count, join_op):
 
    conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
 
    args = query_args(conds, join=join_op.name)
 
    query = qmod.build_query(args, fy)
 
    assert query.startswith('SELECT ')
 
    cond_index = query.index(' WHERE ') + 7
 
    assert query[cond_index:] == '({})'.format(join_op.join(conds))
 

	
 
@pytest.mark.parametrize('argname,date_arg', itertools.product(
 
    ['start_date', 'stop_date'],
 
    [testutil.FY_START_DATE, testutil.FY_START_DATE.year],
 
))
 
def test_build_query_one_date_arg(fy, argname, date_arg):
 
    query_kwargs = {
 
        argname: date_arg,
 
        'query': ['flag = "*"', 'flag = "!"'],
 
        'join': 'OR',
 
    }
 
    args = query_args(**query_kwargs)
 
    query = qmod.build_query(args, fy)
 
    assert query.startswith('SELECT ')
 
    cond_index = query.index(' WHERE ') + 7
 
    if argname == 'start_date':
 
        expect_op = '>='
 
        year_to_date = fy.first_date
 
    else:
 
        expect_op = '<'
 
        year_to_date = fy.next_fy_date
 
    if not isinstance(date_arg, datetime.date):
 
        date_arg = year_to_date(date_arg)
 
    assert query[cond_index:] == '({}) AND date {} {}'.format(
 
        ' OR '.join(query_kwargs['query']), expect_op, date_arg.isoformat(),
 
    )
 

	
 
@pytest.mark.parametrize('start_date,stop_date', itertools.product(
 
    [testutil.PAST_DATE, testutil.PAST_DATE.year],
 
    [testutil.FUTURE_DATE, testutil.FUTURE_DATE.year],
 
))
 
def test_build_query_two_date_args(fy, start_date, stop_date):
 
    args = query_args(['account ~ "^Equity:"'], start_date, stop_date, 'AND')
 
    query = qmod.build_query(args, fy)
 
    assert query.startswith('SELECT ')
 
    cond_index = query.index(' WHERE ') + 7
 
    if isinstance(start_date, int):
 
        start_date = fy.first_date(start_date)
 
    if isinstance(stop_date, int):
 
        stop_date = fy.next_fy_date(stop_date)
 
    assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format(
 
        args.query[0], start_date.isoformat(), stop_date.isoformat(),
 
    )
 

	
 
def test_build_query_plain_from_file(fy):
 
    with io.StringIO("SELECT *\n WHERE account ~ '^Assets:';\n") as qfile:
 
        query = qmod.build_query(query_args(), fy, qfile)
 
    assert re.fullmatch(r"SELECT \*\s+WHERE account ~ '\^Assets:';\s*", query)
 

	
 
def test_build_query_from_file_where_clauses(fy):
 
    conds = ["account ~ '^Income:'", "account ~ '^Expenses:'"]
 
    args = query_args(None, testutil.PAST_DATE, testutil.FUTURE_DATE, 'OR')
 
    with io.StringIO(''.join(f'{s}\n' for s in conds)) as qfile:
 
        query = qmod.build_query(args, fy, qfile)
 
    assert query.startswith('SELECT ')
 
    cond_index = query.index(' WHERE ') + 7
 
    assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format(
 
        ' OR '.join(conds),
 
        testutil.PAST_DATE.isoformat(),
 
        testutil.FUTURE_DATE.isoformat(),
 
    )
 

	
 
@pytest.mark.parametrize('arglist,fy', testutil.combine_values(
 
    [['--report-type', 'text'], ['--format=text'], ['-f', 'txt']],
 
    range(2018, 2021),
 
))
 
def test_text_query(arglist, fy):
 
    books_path = testutil.test_path(f'books/books/{fy}.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist += ['select', 'date,', 'narration,', 'account,', 'position']
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    lines = iter(stdout)
 
    next(lines); next(lines)  # Skip header
 
    for count, line in enumerate(lines, 1):
 
        assert re.match(rf'^{fy}-\d\d-\d\d\s+{fy} donation\b', line)
 
    assert count >= 2
 

	
 
@pytest.mark.parametrize('arglist,fy', testutil.combine_values(
 
    [['--format=csv'], ['-f', 'csv'], ['-t', 'csv']],
 
    range(2018, 2021),
 
))
 
def test_csv_query(arglist, fy):
 
    books_path = testutil.test_path(f'books/books/{fy}.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist += ['select', 'date,', 'narration,', 'account,', 'position']
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    for count, row in enumerate(csv.DictReader(stdout), 1):
 
        assert re.fullmatch(rf'{fy}-\d\d-\d\d', row['date'])
 
        assert row['narration'] == f'{fy} donation'
 
    assert count >= 2
 

	
 
@pytest.mark.parametrize('end_index', range(3))
 
def test_rewrite_query(end_index):
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    accounts = ['Assets', 'Income']
 
    expected = frozenset(accounts[:end_index])
 
    rewrite_paths = [
 
        testutil.test_path(f'userconfig/Rewrite{s}.yml')
 
        for s in expected
 
    ]
 
    arglist = [f'--rewrite-rules={path}' for path in rewrite_paths]
 
    arglist.append('--format=txt')
 
    arglist.append('select any_meta("root") as root')
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    actual = frozenset(line.rstrip('\n') for line in stdout)
 
    assert expected.issubset(actual)
 
    assert frozenset(accounts).difference(expected).isdisjoint(actual)