diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a386775a8dbdfb8f5295df827ba3951b56581c --- /dev/null +++ b/tests/test_cliutil.py @@ -0,0 +1,129 @@ +"""Test CLI utilities""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import argparse +import errno +import io +import inspect +import logging +import os +import re +import traceback + +import pytest + +from conservancy_beancount import cliutil + +class MockTraceback: + def __init__(self, stack=None, index=0): + if stack is None: + stack = inspect.stack(context=False) + self._stack = stack + self._index = index + frame_record = self._stack[self._index] + self.tb_frame = frame_record.frame + self.tb_lineno = frame_record.lineno + + @property + def tb_next(self): + try: + return type(self)(self._stack, self._index + 1) + except IndexError: + return None + + +@pytest.fixture(scope='module') +def argparser(): + parser = argparse.ArgumentParser(prog='test_cliutil') + cliutil.add_loglevel_argument(parser) + cliutil.add_version_argument(parser) + return parser + +@pytest.mark.parametrize('errnum', [ + errno.EACCES, + errno.EPERM, + errno.ENOENT, +]) +def test_excepthook_oserror(errnum, caplog): + error = OSError(errnum, os.strerror(errnum), 'TestFilename') + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(type(error), error, None) + assert exc_check.value.args[0] == 4 + assert caplog.records + for log in caplog.records: + assert log.levelname == 'CRITICAL' + assert log.message == f"I/O error: {error.filename}: {error.strerror}" + +@pytest.mark.parametrize('exc_type', [ + AttributeError, + RuntimeError, + ValueError, +]) +def test_excepthook_bug(exc_type, caplog): + error = exc_type("test message") + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(exc_type, error, None) + assert exc_check.value.args[0] == 3 + assert caplog.records + for log in caplog.records: + assert log.levelname == 'CRITICAL' + assert log.message == f"internal {exc_type.__name__}: {error.args[0]}" + +def test_excepthook_traceback(caplog): + error = KeyError('test') + args = (type(error), error, MockTraceback()) + caplog.set_level(logging.DEBUG) + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(*args) + assert caplog.records + assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) + +@pytest.mark.parametrize('arg,expected', [ + ('debug', logging.DEBUG), + ('info', logging.INFO), + ('warning', logging.WARNING), + ('warn', logging.WARNING), + ('error', logging.ERROR), + ('err', logging.ERROR), + ('critical', logging.CRITICAL), + ('crit', logging.CRITICAL), +]) +def test_loglevel_argument(argparser, arg, expected): + for method in ['lower', 'title', 'upper']: + args = argparser.parse_args(['--loglevel', getattr(arg, method)()]) + assert args.loglevel is expected + +def test_setup_logger(): + stream = io.StringIO() + logger = cliutil.setup_logger( + 'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s', + ) + logger.debug("test debug") + logger.info("test info") + assert stream.getvalue() == "test_cliutil INFO: test info\n" + +@pytest.mark.parametrize('arg', [ + '--license', + '--version', + '--copyright', +]) +def test_version_argument(argparser, capsys, arg): + with pytest.raises(SystemExit) as exc_check: + args = argparser.parse_args(['--version']) + assert exc_check.value.args[0] == 0 + stdout, _ = capsys.readouterr() + lines = iter(stdout.splitlines()) + assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, ""))