diff --git a/tests/test_main.py b/tests/test_main.py index cb1264678ad7d54da02ca06d24c4ec4315f6577b..d515aedf35657834295dd71fb9d493321c2668e3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -30,35 +30,50 @@ def iter_entries(in_file): if lines: yield ''.join(lines) -def entries2set(in_file): - return set(normalize_whitespace(e) for e in iter_entries(in_file)) +def format_entry(entry_s, format_vars): + return normalize_whitespace(entry_s).format_map(format_vars) -def expected_entries(path): +def format_entries(source, format_vars=None): + if format_vars is None: + format_vars = {} + return (format_entry(e, format_vars) for e in iter_entries(source)) + +def expected_entries(path, format_vars=None): path = pathlib.Path(path) if not path.is_absolute(): path = DATA_DIR / path with path.open() as in_file: - return entries2set(in_file) + return list(format_entries(in_file, format_vars)) + +def path_vars(path): + return { + 'source_abspath': str(path), + 'source_name': path.name, + 'source_path': str(path), + } def test_fees_import(): + source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv') arglist = ARGLIST + [ '-c', 'One', - pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), + source_path.as_posix(), ] exitcode, stdout, _ = run_main(arglist) assert exitcode == 0 - actual = entries2set(stdout) - assert actual == expected_entries('test_main_fees_import.ledger') + actual = list(format_entries(stdout)) + expected = expected_entries('test_main_fees_import.ledger', path_vars(source_path)) + assert actual == expected def test_date_range_import(): + source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv') arglist = ARGLIST + [ '-c', 'One', '--date-range', '2017/10/01-', - pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), + source_path.as_posix(), ] exitcode, stdout, _ = run_main(arglist) assert exitcode == 0 - actual = entries2set(stdout) - expected = {entry for entry in expected_entries('test_main_fees_import.ledger') - if entry.startswith('2017/10/')} + actual = list(format_entries(stdout)) + valid = expected_entries('test_main_fees_import.ledger', path_vars(source_path)) + expected = [entry for entry in valid if entry.startswith('2017/10/')] assert actual == expected