From 21f6d078a285a938a30e579c72543f751c7d412c 2018-04-03 21:32:54 From: Brett Smith Date: 2018-04-03 21:32:54 Subject: [PATCH] outfile: New module to take over functionality from config. With the concept of the general "output path" going away, it doesn't make sense for Config to have an open_output_file method. --- diff --git a/import2ledger/config.py b/import2ledger/config.py index acc924dbdd81e0b2d5f3f22f3df5ad0e5889e4a1..220485813bfbf8dd6ee01332d64da115a3af4f16 100644 --- a/import2ledger/config.py +++ b/import2ledger/config.py @@ -176,14 +176,6 @@ class Configuration: default_section['output_path'] = self.args.output_path self._read_dates_args() - @contextlib.contextmanager - def _open_path(self, path, fallback_file, *args, **kwargs): - if path is None: - yield fallback_file - else: - with path.open(*args, **kwargs) as open_file: - yield open_file - def get_section(self, section_name): if section_name is None: section_name = self.args.use_config @@ -215,14 +207,6 @@ class Configuration: except AttributeError: raise errors.UserInputConfigurationError("not a valid loglevel", level_name) - def get_output_path(self, section_name=None): - section_config = self.get_section(section_name) - return self._s_to_path(section_config['output_path']) - - def open_output_file(self, section_name=None): - path = self.get_output_path(section_name) - return self._open_path(path, self.stdout, 'a') - def setup_logger(self, logger, section_name=None): logger.setLevel(self.get_loglevel(section_name)) diff --git a/import2ledger/hooks/ledger_entry.py b/import2ledger/hooks/ledger_entry.py index ef03f704561ecd70cb98438161d8148851c88e93..cd3a01b1247759a4d4c4f2ac0a3dd4f7062041e1 100644 --- a/import2ledger/hooks/ledger_entry.py +++ b/import2ledger/hooks/ledger_entry.py @@ -11,7 +11,7 @@ import tokenize import babel.numbers from . import HOOK_KINDS -from .. import errors, strparse +from .. import errors, outfile, strparse logger = logging.getLogger('import2ledger.hooks.ledger_entry') @@ -324,5 +324,8 @@ class LedgerEntryHook: if template.is_empty(): logger.warning("no Ledger template defined as %r", template_name) else: - with self.config.open_output_file() as out_file: + with outfile.open( + self.config_section.get('output_path', '-'), + self.config.stdout, + ) as out_file: print(template.render(entry_data), file=out_file, end='') diff --git a/import2ledger/outfile.py b/import2ledger/outfile.py new file mode 100644 index 0000000000000000000000000000000000000000..0647d2fcbc6cad0a2c78f682889bd22a7db3813c --- /dev/null +++ b/import2ledger/outfile.py @@ -0,0 +1,10 @@ +import contextlib +import pathlib + +@contextlib.contextmanager +def open(path, stdpipe, mode='a', **kwargs): + if path == '-': + yield stdpipe + else: + with pathlib.Path(path).open(mode, **kwargs) as retval: + yield retval diff --git a/tests/test_config.py b/tests/test_config.py index a71825b672c5bf0664846230d08d827cfb86cfc7..f068a6048979ab6e54c6d6290525a199c80ff463 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -34,16 +34,6 @@ def test_get_section_falls_back_to_default(): assert section.get('output_path') != 'Template.output' assert section['signed_currencies'] == 'EUR' -@pytest.mark.parametrize('arg_s', [None, '-', 'output.ledger']) -def test_output_path(arg_s): - arglist = [] if arg_s is None else ['-O', arg_s] - config = config_from_file(os.devnull, arglist) - output_path = config.get_output_path() - if (arg_s is None) or (arg_s == '-'): - assert output_path is None - else: - assert output_path == pathlib.Path(arg_s) - def _fix_date_s(date_s, new_sep, old_sep='/'): return date_s.replace(old_sep, new_sep) diff --git a/tests/test_outfile.py b/tests/test_outfile.py new file mode 100644 index 0000000000000000000000000000000000000000..37f5c9722c9a1a6134d1b3ba2c89c10abead6bf4 --- /dev/null +++ b/tests/test_outfile.py @@ -0,0 +1,23 @@ +import pathlib +import tempfile + +import pytest + +from import2ledger import outfile + +@pytest.mark.parametrize('to_path', [ + lambda s: s, + pathlib.Path, +]) +def test_output_path(to_path): + with tempfile.NamedTemporaryFile() as source_file: + path_arg = to_path(source_file.name) + with outfile.open(path_arg, None) as actual: + assert actual.name == source_file.name + assert 'a' in actual.mode + +def test_fallback(): + with tempfile.NamedTemporaryFile() as source_file: + with outfile.open('-', source_file) as actual: + assert actual is source_file + assert not source_file.closed