diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py index c1314c8055ff18e08b5f6d92c6d7231af7cca107..19dc17010d4d3b59bf7f4dc83449b72fee0e2a94 100644 --- a/tests/test_cliutil.py +++ b/tests/test_cliutil.py @@ -21,12 +21,18 @@ import inspect import logging import os import re +import sys import traceback import pytest +from pathlib import Path + from conservancy_beancount import cliutil +FILE_NAMES = ['-foobar', '-foo.bin'] +STREAM_PATHS = [None, Path('-')] + class AlwaysEqual: def __eq__(self, other): return True @@ -57,6 +63,45 @@ def argparser(): cliutil.add_version_argument(parser) return parser +@pytest.mark.parametrize('path_name', FILE_NAMES) +def test_bytes_output_path(path_name, tmp_path): + path = tmp_path / path_name + stream = io.BytesIO() + actual = cliutil.bytes_output(path, stream) + assert actual is not stream + assert str(actual.name) == str(path) + assert 'w' in actual.mode + assert 'b' in actual.mode + +@pytest.mark.parametrize('path', STREAM_PATHS) +def test_bytes_output_stream(path): + stream = io.BytesIO() + actual = cliutil.bytes_output(path, stream) + assert actual is stream + +@pytest.mark.parametrize('func_name', [ + 'bytes_output', + 'text_output', +]) +def test_default_output(func_name): + actual = getattr(cliutil, func_name)() + assert actual.fileno() == sys.stdout.fileno() + +@pytest.mark.parametrize('path_name', FILE_NAMES) +def test_text_output_path(path_name, tmp_path): + path = tmp_path / path_name + stream = io.StringIO() + actual = cliutil.text_output(path, stream) + assert actual is not stream + assert str(actual.name) == str(path) + assert 'w' in actual.mode + +@pytest.mark.parametrize('path', STREAM_PATHS) +def test_text_output_stream(path): + stream = io.StringIO() + actual = cliutil.text_output(path, stream) + assert actual is stream + @pytest.mark.parametrize('errnum', [ errno.EACCES, errno.EPERM,