diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a386775a8dbdfb8f5295df827ba3951b56581c
--- /dev/null
+++ b/tests/test_cliutil.py
@@ -0,0 +1,129 @@
+"""Test CLI utilities"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import argparse
+import errno
+import io
+import inspect
+import logging
+import os
+import re
+import traceback
+
+import pytest
+
+from conservancy_beancount import cliutil
+
+class MockTraceback:
+ def __init__(self, stack=None, index=0):
+ if stack is None:
+ stack = inspect.stack(context=False)
+ self._stack = stack
+ self._index = index
+ frame_record = self._stack[self._index]
+ self.tb_frame = frame_record.frame
+ self.tb_lineno = frame_record.lineno
+
+ @property
+ def tb_next(self):
+ try:
+ return type(self)(self._stack, self._index + 1)
+ except IndexError:
+ return None
+
+
+@pytest.fixture(scope='module')
+def argparser():
+ parser = argparse.ArgumentParser(prog='test_cliutil')
+ cliutil.add_loglevel_argument(parser)
+ cliutil.add_version_argument(parser)
+ return parser
+
+@pytest.mark.parametrize('errnum', [
+ errno.EACCES,
+ errno.EPERM,
+ errno.ENOENT,
+])
+def test_excepthook_oserror(errnum, caplog):
+ error = OSError(errnum, os.strerror(errnum), 'TestFilename')
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(type(error), error, None)
+ assert exc_check.value.args[0] == 4
+ assert caplog.records
+ for log in caplog.records:
+ assert log.levelname == 'CRITICAL'
+ assert log.message == f"I/O error: {error.filename}: {error.strerror}"
+
+@pytest.mark.parametrize('exc_type', [
+ AttributeError,
+ RuntimeError,
+ ValueError,
+])
+def test_excepthook_bug(exc_type, caplog):
+ error = exc_type("test message")
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(exc_type, error, None)
+ assert exc_check.value.args[0] == 3
+ assert caplog.records
+ for log in caplog.records:
+ assert log.levelname == 'CRITICAL'
+ assert log.message == f"internal {exc_type.__name__}: {error.args[0]}"
+
+def test_excepthook_traceback(caplog):
+ error = KeyError('test')
+ args = (type(error), error, MockTraceback())
+ caplog.set_level(logging.DEBUG)
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(*args)
+ assert caplog.records
+ assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
+
+@pytest.mark.parametrize('arg,expected', [
+ ('debug', logging.DEBUG),
+ ('info', logging.INFO),
+ ('warning', logging.WARNING),
+ ('warn', logging.WARNING),
+ ('error', logging.ERROR),
+ ('err', logging.ERROR),
+ ('critical', logging.CRITICAL),
+ ('crit', logging.CRITICAL),
+])
+def test_loglevel_argument(argparser, arg, expected):
+ for method in ['lower', 'title', 'upper']:
+ args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
+ assert args.loglevel is expected
+
+def test_setup_logger():
+ stream = io.StringIO()
+ logger = cliutil.setup_logger(
+ 'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s',
+ )
+ logger.debug("test debug")
+ logger.info("test info")
+ assert stream.getvalue() == "test_cliutil INFO: test info\n"
+
+@pytest.mark.parametrize('arg', [
+ '--license',
+ '--version',
+ '--copyright',
+])
+def test_version_argument(argparser, capsys, arg):
+ with pytest.raises(SystemExit) as exc_check:
+ args = argparser.parse_args(['--version'])
+ assert exc_check.value.args[0] == 0
+ stdout, _ = capsys.readouterr()
+ lines = iter(stdout.splitlines())
+ assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, ""))