From 32b62df5405f4510be631345394573b46a43a58d 2020-05-30 03:39:27 From: Brett Smith Date: 2020-05-30 03:39:27 Subject: [PATCH] cliutil: Better implementation of is_main_script. The old one could return True if you called accrual.main() directly from one-off test scripts. --- diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py index 9f960cb987082f217f8de5832f58a5fabb5d6748..5a6fde3279c3321829fde7d3223e5ffc28215f27 100644 --- a/conservancy_beancount/cliutil.py +++ b/conservancy_beancount/cliutil.py @@ -28,6 +28,8 @@ import sys import traceback import types +from pathlib import Path + from typing import ( Any, Iterable, @@ -134,10 +136,14 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action: help="Show program version and license information", ) -def is_main_script() -> bool: +def is_main_script(prog_name: str) -> bool: """Return true if the caller is the "main" program.""" - stack = inspect.stack(context=False) - return len(stack) <= 3 and stack[-1].function.startswith('<') + stack = iter(inspect.stack(context=False)) + next(stack) # Discard the frame for calling this function + caller_filename = next(stack).filename + return all(frame.filename == caller_filename + or Path(frame.filename).stem == prog_name + for frame in stack) def setup_logger(logger: Union[str, logging.Logger]='', loglevel: int=logging.INFO, diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index 215d8afd7e038d729e226ce4b02ba54ed4d2368e..02fd3419a2272d88b43fd0d88c449bbdb6da9bba 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -391,7 +391,7 @@ def main(arglist: Optional[Sequence[str]]=None, stderr: TextIO=sys.stderr, config: Optional[configmod.Config]=None, ) -> int: - if cliutil.is_main_script(): + if cliutil.is_main_script(PROGNAME): global logger logger = logging.getLogger(PROGNAME) sys.excepthook = cliutil.ExceptHook(logger) diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py index 08cddf36c92ef980a362624b2451df9655fba289..c1314c8055ff18e08b5f6d92c6d7231af7cca107 100644 --- a/tests/test_cliutil.py +++ b/tests/test_cliutil.py @@ -27,6 +27,11 @@ import pytest from conservancy_beancount import cliutil +class AlwaysEqual: + def __eq__(self, other): + return True + + class MockTraceback: def __init__(self, stack=None, index=0): if stack is None: @@ -91,8 +96,12 @@ def test_excepthook_traceback(caplog): assert caplog.records assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) -def test_is_main_script(): - assert not cliutil.is_main_script() +@pytest.mark.parametrize('prog_name,expected', [ + ('', False), + (AlwaysEqual(), True), +]) +def test_is_main_script(prog_name, expected): + assert cliutil.is_main_script(prog_name) == expected @pytest.mark.parametrize('arg,expected', [ ('debug', logging.DEBUG),