diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py index 08cddf36c92ef980a362624b2451df9655fba289..c1314c8055ff18e08b5f6d92c6d7231af7cca107 100644 --- a/tests/test_cliutil.py +++ b/tests/test_cliutil.py @@ -27,6 +27,11 @@ import pytest from conservancy_beancount import cliutil +class AlwaysEqual: + def __eq__(self, other): + return True + + class MockTraceback: def __init__(self, stack=None, index=0): if stack is None: @@ -91,8 +96,12 @@ def test_excepthook_traceback(caplog): assert caplog.records assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) -def test_is_main_script(): - assert not cliutil.is_main_script() +@pytest.mark.parametrize('prog_name,expected', [ + ('', False), + (AlwaysEqual(), True), +]) +def test_is_main_script(prog_name, expected): + assert cliutil.is_main_script(prog_name) == expected @pytest.mark.parametrize('arg,expected', [ ('debug', logging.DEBUG),