From 2b5cb0eca6e0b4700ef72f09829299460afab140 2020-06-09 13:04:27
From: Brett Smith
Date: 2020-06-09 13:04:27
Subject: [PATCH] cliutil: Add bytes_output() and text_output() functions.

---
diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py
index e12652bc768a61459d6e7fcb96f78800cbf493e3..f2ba4acd5f023ead87f9c0d3143c8744d2fc5b7b 100644
--- a/conservancy_beancount/cliutil.py
+++ b/conservancy_beancount/cliutil.py
@@ -19,6 +19,7 @@ along with this program. If not, see ."""
 import argparse
 import enum
 import inspect
+import io
 import logging
 import operator
 import os
@@ -36,8 +37,11 @@ from . import filters
 from . import rtutil

 from typing import (
+    cast,
     Any,
+    BinaryIO,
     Callable,
+    IO,
     Iterable,
     NamedTuple,
     NoReturn,
@@ -51,6 +55,9 @@ from .beancount_types import (
     MetaKey,
 )

+OutputFile = Union[int, IO]
+
+STDSTREAM_PATH = Path('-')
 VERSION = pkg_resources.require(PKGNAME)[0].version

 class ExceptHook:
@@ -247,3 +254,69 @@ def setup_logger(logger: Union[str, logging.Logger]='',
     logger.addHandler(handler)
     logger.setLevel(loglevel)
     return logger
+
+def bytes_output(path: Optional[Path]=None,
+                 default: OutputFile=sys.stdout,
+                 mode: str='w',
+) -> BinaryIO:
+    """Get a file-like object suitable for binary output
+
+    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
+    ``default``. If ``default`` is a file descriptor or text IO object, this
+    method returns a file-like object that writes to the same place. A file
+    descriptor passed as ``default`` stays open when the returned object is
+    closed.
+
+    Otherwise, returns ``path.open(mode)``, with ``b`` added to ``mode`` if
+    it is not already there.
+    """
+    if 'b' not in mode:
+        mode = f'{mode}b'
+    if path is None or path == STDSTREAM_PATH:
+        if isinstance(default, int):
+            # The caller owns this descriptor (e.g., 1 for stdout), so
+            # closing the returned object must not close it.
+            retval = open(default, mode, closefd=False)
+        elif isinstance(default, io.TextIOWrapper):
+            # typing.TextIO is not a runtime base class of real text
+            # streams, so isinstance() must check the io classes instead.
+            # Flush pending text so it lands before any bytes we write.
+            default.flush()
+            retval = default.buffer
+        else:
+            retval = default
+    else:
+        retval = path.open(mode)
+    return cast(BinaryIO, retval)
+
+def text_output(path: Optional[Path]=None,
+                default: OutputFile=sys.stdout,
+                mode: str='w',
+                encoding: Optional[str]=None,
+) -> TextIO:
+    """Get a file-like object suitable for text output
+
+    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
+    ``default``. If ``default`` is a file descriptor or buffered binary IO
+    object, this method returns a file-like object that writes to the same
+    place. A file descriptor passed as ``default`` stays open when the
+    returned object is closed; a binary IO object is wrapped in an
+    ``io.TextIOWrapper``, and closing that wrapper closes ``default`` too.
+
+    Otherwise, returns ``path.open(mode, encoding=encoding)``.
+    """
+    if path is None or path == STDSTREAM_PATH:
+        if isinstance(default, int):
+            # The caller owns this descriptor (e.g., 1 for stdout), so
+            # closing the returned object must not close it.
+            retval = open(default, mode, encoding=encoding, closefd=False)
+        elif isinstance(default, io.BufferedIOBase):
+            # typing.BinaryIO is not a runtime base class of real binary
+            # streams, so isinstance() must check the io classes instead.
+            retval = io.TextIOWrapper(cast(BinaryIO, default),
+                                      encoding=encoding)
+        else:
+            retval = default
+    else:
+        retval = path.open(mode, encoding=encoding)
+    return cast(TextIO, retval)
diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py
index 6f1e713459011d77e151ad374062c3513f3a3fdf..f8c2e9a8d39b098bde0ab8f4881f6af6e1901366 100644
--- a/conservancy_beancount/reports/accrual.py
+++ b/conservancy_beancount/reports/accrual.py
@@ -120,7 +120,6 @@ from .. import filters
 from .. import rtutil

 PROGNAME = 'accrual-report'
-STANDARD_PATH = Path('-')

 CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
 PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
@@ -677,28 +676,6 @@ metadata to match. A single ticket number is a shortcut for
         args.report_type = ReportType.AGING
     return args

-def get_output_path(output_path: Optional[Path],
-                    default_path: Path=STANDARD_PATH,
-) -> Optional[Path]:
-    if output_path is None:
-        output_path = default_path
-    if output_path == STANDARD_PATH:
-        return None
-    else:
-        return output_path
-
-def get_output_bin(path: Optional[Path], stdout: TextIO) -> BinaryIO:
-    if path is None:
-        return open(stdout.fileno(), 'wb')
-    else:
-        return path.open('wb')
-
-def get_output_text(path: Optional[Path], stdout: TextIO) -> TextIO:
-    if path is None:
-        return stdout
-    else:
-        return path.open('w')
-
 def main(arglist: Optional[Sequence[str]]=None,
          stdout: TextIO=sys.stdout,
          stderr: TextIO=sys.stderr,
@@ -762,29 +739,26 @@ def main(arglist: Optional[Sequence[str]]=None,
             logger.error("unable to generate aging report: RT client is required")
         else:
             now = datetime.datetime.now()
-            default_path = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
-            output_path = get_output_path(args.output_file, default_path)
-            out_bin = get_output_bin(output_path, stdout)
+            if args.output_file is None:
+                args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
+                logger.info("Writing report to %s", args.output_file)
+            out_bin = cliutil.bytes_output(args.output_file, stdout)
             report = AgingReport(rt_client, out_bin)
     elif args.report_type is ReportType.OUTGOING:
         rt_client = config.rt_client()
         if rt_client is None:
             logger.error("unable to generate outgoing report: RT client is required")
         else:
-            output_path = get_output_path(args.output_file)
-            out_file = get_output_text(output_path, stdout)
+            out_file = cliutil.text_output(args.output_file, stdout)
             report = OutgoingReport(rt_client, out_file)
     else:
-        output_path = get_output_path(args.output_file)
-        out_file = get_output_text(output_path, stdout)
+        out_file = cliutil.text_output(args.output_file, stdout)
         report = args.report_type.value(out_file)
     if report is None:
         returncode |= ReturnFlag.REPORT_ERRORS
     else:
         report.run(groups)
-    if args.output_file != output_path:
-        logger.info("Report saved to %s", output_path)
     return 0 if returncode == 0 else 16 + returncode


 if __name__ == '__main__':
diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py
index c1314c8055ff18e08b5f6d92c6d7231af7cca107..19dc17010d4d3b59bf7f4dc83449b72fee0e2a94 100644
--- a/tests/test_cliutil.py
+++ b/tests/test_cliutil.py
@@ -21,12 +21,18 @@ import inspect
 import logging
 import os
 import re
+import sys
 import traceback

 import pytest

+from pathlib import Path
+
 from conservancy_beancount import cliutil

+FILE_NAMES = ['-foobar', '-foo.bin']
+STREAM_PATHS = [None, Path('-')]
+
 class AlwaysEqual:
     def __eq__(self, other):
         return True
@@ -57,6 +63,64 @@ def argparser():
     cliutil.add_version_argument(parser)
     return parser

+@pytest.mark.parametrize('path_name', FILE_NAMES)
+def test_bytes_output_path(path_name, tmp_path):
+    path = tmp_path / path_name
+    stream = io.BytesIO()
+    actual = cliutil.bytes_output(path, stream)
+    assert actual is not stream
+    assert str(actual.name) == str(path)
+    assert 'w' in actual.mode
+    assert 'b' in actual.mode
+
+@pytest.mark.parametrize('path', STREAM_PATHS)
+def test_bytes_output_stream(path):
+    stream = io.BytesIO()
+    actual = cliutil.bytes_output(path, stream)
+    assert actual is stream
+
+@pytest.mark.parametrize('path', STREAM_PATHS)
+def test_bytes_output_text_stream(path):
+    stream = io.TextIOWrapper(io.BytesIO())
+    actual = cliutil.bytes_output(path, stream)
+    assert actual is stream.buffer
+
+def test_bytes_output_binary_mode(tmp_path):
+    path = tmp_path / 'foo.bin'
+    with cliutil.bytes_output(path, mode='wb') as actual:
+        assert actual.mode == 'wb'
+
+@pytest.mark.parametrize('func_name', [
+    'bytes_output',
+    'text_output',
+])
+def test_default_output(func_name):
+    actual = getattr(cliutil, func_name)()
+    assert actual.fileno() == sys.stdout.fileno()
+
+@pytest.mark.parametrize('path_name', FILE_NAMES)
+def test_text_output_path(path_name, tmp_path):
+    path = tmp_path / path_name
+    stream = io.StringIO()
+    actual = cliutil.text_output(path, stream)
+    assert actual is not stream
+    assert str(actual.name) == str(path)
+    assert 'w' in actual.mode
+
+@pytest.mark.parametrize('path', STREAM_PATHS)
+def test_text_output_stream(path):
+    stream = io.StringIO()
+    actual = cliutil.text_output(path, stream)
+    assert actual is stream
+
+@pytest.mark.parametrize('path', STREAM_PATHS)
+def test_text_output_binary_stream(path):
+    stream = io.BytesIO()
+    actual = cliutil.text_output(path, stream, encoding='utf-8')
+    actual.write('\u2014')
+    actual.flush()
+    assert stream.getvalue() == '\u2014'.encode('utf-8')
+
 @pytest.mark.parametrize('errnum', [
     errno.EACCES,
     errno.EPERM,