diff --git a/README.rst b/README.rst index 38956ac5e6c0509bc3f2f4bf16935b0291e3d813..e4270156a210d89e8b28c15387ac88b6c9ad261b 100644 --- a/README.rst +++ b/README.rst @@ -76,6 +76,13 @@ date The date of the transaction, in your configured output format ------------------ ---------------------------------------------------------- payee The name of the transaction payee +------------------ ---------------------------------------------------------- +source_abspath The absolute path of the file being imported +------------------ ---------------------------------------------------------- +source_name The filename of the file being imported +------------------ ---------------------------------------------------------- +source_path The path of the file being imported, as specified on the + command line ================== ========================================================== Specific importers and hooks may provide additional variables. diff --git a/import2ledger/__main__.py b/import2ledger/__main__.py index 76fb9210209f8002c30addbf68e00bfb4d45d5f1..068460118add55a3249c7ad1b695a7c9504c93ad 100644 --- a/import2ledger/__main__.py +++ b/import2ledger/__main__.py @@ -1,3 +1,4 @@ +import collections import contextlib import logging import sys @@ -13,7 +14,9 @@ class FileImporter: self.hooks = [hook(config) for hook in hooks.load_all()] self.stdout = stdout - def import_file(self, in_file): + def import_file(self, in_file, in_path=None): + if in_path is None: + in_path = pathlib.Path(in_file.name) importers = [] for importer in self.importers: in_file.seek(0) @@ -31,6 +34,11 @@ class FileImporter: importers.append((importer, template)) if not importers: raise errors.UserInputFileError("no importers available", in_file.name) + source_vars = { + 'source_abspath': in_path.absolute().as_posix(), + 'source_name': in_path.name, + 'source_path': in_path.as_posix(), + } with contextlib.ExitStack() as exit_stack: output_path = self.config.get_output_path() if output_path is None: @@ -48,7 +56,8 @@ class FileImporter: break else: del entry_data['_hook_cancel'] - print(template.render(**entry_data), file=out_file, end='') + render_vars = collections.ChainMap(entry_data, source_vars) + print(template.render(render_vars), file=out_file, end='') def import_path(self, in_path): if in_path is None: @@ -56,7 +65,7 @@ class FileImporter: with in_path.open(errors='replace') as in_file: if not in_file.seekable(): raise errors.UserInputFileError("only seekable files are supported", in_path) - return self.import_file(in_file) + return self.import_file(in_file, in_path) def import_paths(self, path_seq): for in_path in path_seq: diff --git a/tests/data/test_main.ini b/tests/data/test_main.ini index 32c3779990d016160cb5a3f077629fad3050ce57..5babf5c7773bf383f8db8598979e902ac43bf6ec 100644 --- a/tests/data/test_main.ini +++ b/tests/data/test_main.ini @@ -8,5 +8,7 @@ template patreon cardfees = Accrued:Accounts Receivable -{amount} Expenses:Fees:Credit Card {amount} template patreon svcfees = + ;SourcePath: {source_abspath} + ;SourceName: {source_name} Accrued:Accounts Receivable -{amount} Expenses:Fundraising {amount} diff --git a/tests/data/test_main_fees_import.ledger b/tests/data/test_main_fees_import.ledger index b7acce49a4e3892b7988c46d37a3a6551e64f3fd..a2ec9b4b4dc128288dcd213b1e4e50ae1db04ac7 100644 --- a/tests/data/test_main_fees_import.ledger +++ b/tests/data/test_main_fees_import.ledger @@ -2,14 +2,18 @@ Accrued:Accounts Receivable $-52.47 Expenses:Fees:Credit Card $52.47 -2017/09/01 Patreon - Accrued:Accounts Receivable $-61.73 - Expenses:Fundraising $61.73 - 2017/10/01 Patreon Accrued:Accounts Receivable $-99.47 Expenses:Fees:Credit Card $99.47 +2017/09/01 Patreon + ;SourcePath: {source_abspath} + ;SourceName: {source_name} + Accrued:Accounts Receivable $-61.73 + Expenses:Fundraising $61.73 + 2017/10/01 Patreon + ;SourcePath: {source_abspath} + ;SourceName: {source_name} Accrued:Accounts Receivable $-117.03 Expenses:Fundraising $117.03 diff --git a/tests/test_main.py b/tests/test_main.py index cb1264678ad7d54da02ca06d24c4ec4315f6577b..d515aedf35657834295dd71fb9d493321c2668e3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -30,35 +30,50 @@ def iter_entries(in_file): if lines: yield ''.join(lines) -def entries2set(in_file): - return set(normalize_whitespace(e) for e in iter_entries(in_file)) +def format_entry(entry_s, format_vars): + return normalize_whitespace(entry_s).format_map(format_vars) -def expected_entries(path): +def format_entries(source, format_vars=None): + if format_vars is None: + format_vars = {} + return (format_entry(e, format_vars) for e in iter_entries(source)) + +def expected_entries(path, format_vars=None): path = pathlib.Path(path) if not path.is_absolute(): path = DATA_DIR / path with path.open() as in_file: - return entries2set(in_file) + return list(format_entries(in_file, format_vars)) + +def path_vars(path): + return { + 'source_abspath': str(path), + 'source_name': path.name, + 'source_path': str(path), + } def test_fees_import(): + source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv') arglist = ARGLIST + [ '-c', 'One', - pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), + source_path.as_posix(), ] exitcode, stdout, _ = run_main(arglist) assert exitcode == 0 - actual = entries2set(stdout) - assert actual == expected_entries('test_main_fees_import.ledger') + actual = list(format_entries(stdout)) + expected = expected_entries('test_main_fees_import.ledger', path_vars(source_path)) + assert actual == expected def test_date_range_import(): + source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv') arglist = ARGLIST + [ '-c', 'One', '--date-range', '2017/10/01-', - pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), + source_path.as_posix(), ] exitcode, stdout, _ = run_main(arglist) assert exitcode == 0 - actual = entries2set(stdout) - expected = {entry for entry in expected_entries('test_main_fees_import.ledger') - if entry.startswith('2017/10/')} + actual = list(format_entries(stdout)) + valid = expected_entries('test_main_fees_import.ledger', path_vars(source_path)) + expected = [entry for entry in valid if entry.startswith('2017/10/')] assert actual == expected