diff --git a/tests/test_reports_accrual.py b/tests/test_reports_accrual.py index 2ca92b44de0b591fd9e7af08a1ba2bbc080f54d8..f99706c6276d5216e1ab5b36ff3462a80249ef89 100644 --- a/tests/test_reports_accrual.py +++ b/tests/test_reports_accrual.py @@ -577,15 +577,19 @@ def test_aging_report_does_not_include_too_recent_postings(accrual_postings): project='Development Grant'), ], []) -def run_main(arglist, config=None): +def run_main(arglist, config=None, out_type=io.StringIO): if config is None: config = testutil.TestConfig( books_path=testutil.test_path('books/accruals.beancount'), rt_client=RTClient(), ) - output = io.StringIO() + if out_type is io.BytesIO: + arglist.insert(0, '--output-file=-') + output = out_type() errors = io.StringIO() retcode = accrual.main(arglist, output, errors, config) + output.seek(0) + errors.seek(0) return retcode, output, errors def check_main_fails(arglist, config, error_flags): @@ -593,7 +597,6 @@ def check_main_fails(arglist, config, error_flags): assert retcode > 16 assert (retcode - 16) & error_flags assert not output.getvalue() - errors.seek(0) return errors @pytest.mark.parametrize('arglist', [ @@ -624,7 +627,7 @@ def test_output_payments_when_only_match(arglist, expect_invoice): @pytest.mark.parametrize('arglist,expect_amount', [ (['310'], 420), (['310/3120'], 220), - (['entity=Vendor'], 420), + (['-t', 'out', 'entity=Vendor'], 420), ]) def test_main_outgoing_report(arglist, expect_amount): retcode, output, errors = run_main(arglist) @@ -643,7 +646,6 @@ def test_main_outgoing_report(arglist, expect_amount): @pytest.mark.parametrize('arglist', [ ['-t', 'balance'], ['515/5150'], - ['entity=MatchingProgram'], ]) def test_main_balance_report(arglist): retcode, output, errors = run_main(arglist) @@ -666,23 +668,19 @@ def test_main_balance_report_because_no_rt_id(): @pytest.mark.parametrize('arglist', [ [], - ['-t', 'aging', 'entity=Lawyer'], + ['entity=Lawyer'], ]) -def test_main_aging_report(tmp_path, arglist): +def test_main_aging_report(arglist): if arglist: recv_rows = [row for row in AGING_AR if 'Lawyer' in row.entity] pay_rows = [row for row in AGING_AP if 'Lawyer' in row.entity] else: recv_rows = AGING_AR pay_rows = AGING_AP - output_path = tmp_path / 'AgingReport.ods' - arglist.insert(0, f'--output-file={output_path}') - retcode, output, errors = run_main(arglist) + retcode, output, errors = run_main(arglist, out_type=io.BytesIO) assert not errors.getvalue() assert retcode == 0 - assert not output.getvalue() - with output_path.open('rb') as ods_file: - check_aging_ods(ods_file, None, recv_rows, pay_rows) + check_aging_ods(output, None, recv_rows, pay_rows) def test_main_no_books(): errors = check_main_fails([], testutil.TestConfig(), 1 | 8) @@ -693,7 +691,7 @@ def test_main_no_books(): @pytest.mark.parametrize('arglist', [ ['499'], ['505/99999'], - ['entity=NonExistent'], + ['-t', 'balance', 'entity=NonExistent'], ]) def test_main_no_matches(arglist, caplog): check_main_fails(arglist, None, 8)