diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py index f2ba4acd5f023ead87f9c0d3143c8744d2fc5b7b..ceb3c4391dc110c79dc3a7b1b81f5b3662c42ffb 100644 --- a/conservancy_beancount/cliutil.py +++ b/conservancy_beancount/cliutil.py @@ -18,7 +18,6 @@ along with this program. If not, see .""" import argparse import enum -import inspect import io import logging import operator @@ -232,29 +231,50 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action: help="Show program version and license information", ) -def is_main_script(prog_name: str) -> bool: - """Return true if the caller is the "main" program.""" - stack = iter(inspect.stack(context=False)) - next(stack) # Discard the frame for calling this function - caller_filename = next(stack).filename - return all(frame.filename == caller_filename - or Path(frame.filename).stem == prog_name - for frame in stack) +def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]: + """Create an entry_point function for a tool + + The returned function is suitable for use as an entry_point in setup.py. + It sets up the root logger and excepthook, then calls the module's main + function. + """ + def entry_point(): # type:ignore + prog_mod = sys.modules[mod_name] + setup_logger() + prog_mod.logger = logging.getLogger(prog_name) + sys.excepthook = ExceptHook(prog_mod.logger) + return prog_mod.main() + return entry_point def setup_logger(logger: Union[str, logging.Logger]='', - loglevel: int=logging.INFO, stream: TextIO=sys.stderr, fmt: str='%(name)s: %(levelname)s: %(message)s', ) -> logging.Logger: + """Set up a logger with a StreamHandler with the given format""" if isinstance(logger, str): logger = logging.getLogger(logger) formatter = logging.Formatter(fmt) handler = logging.StreamHandler(stream) handler.setFormatter(formatter) logger.addHandler(handler) - logger.setLevel(loglevel) return logger +def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None: + """Set the loglevel for a tool or module + + If the given logger is not under a hierarchy, this function sets the + loglevel for the root logger, along with some specific levels for libraries + used by reporting tools. Otherwise, it's the same as + ``logger.setLevel(loglevel)``. + """ + if '.' not in logger.name: + logger = logging.getLogger() + if loglevel <= logging.DEBUG: + # At the debug level, the rt module logs the full body of every + # request and response. That's too much. + logging.getLogger('rt.rt').setLevel(logging.INFO) + logger.setLevel(loglevel) + def bytes_output(path: Optional[Path]=None, default: OutputFile=sys.stdout, mode: str='w', diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index ec2530679cc74cca5dea8d7e5f00e01a9ce71393..404ce31c5521d1a3743fb52c3c96caa23fb01c8c 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -663,12 +663,8 @@ def main(arglist: Optional[Sequence[str]]=None, stderr: TextIO=sys.stderr, config: Optional[configmod.Config]=None, ) -> int: - if cliutil.is_main_script(PROGNAME): - global logger - logger = logging.getLogger(PROGNAME) - sys.excepthook = cliutil.ExceptHook(logger) args = parse_arguments(arglist) - cliutil.setup_logger(logger, args.loglevel, stderr) + cliutil.set_loglevel(logger, args.loglevel) if config is None: config = configmod.Config() config.load_file() @@ -753,5 +749,7 @@ def main(arglist: Optional[Sequence[str]]=None, report.run(groups) return 0 if returncode == 0 else 16 + returncode +entry_point = cliutil.make_entry_point(__name__, PROGNAME) + if __name__ == '__main__': - exit(main()) + exit(entry_point()) diff --git a/setup.py b/setup.py index 94576a1b99c79d1d47a4257ba72471945610da3e..6996aa305bfc1428cb7f64c8d71f64b2ee56ca1d 100755 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setup( ], entry_points={ 'console_scripts': [ - 'accrual-report = conservancy_beancount.reports.accrual:main', + 'accrual-report = conservancy_beancount.reports.accrual:entry_point', ], }, ) diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py index 19dc17010d4d3b59bf7f4dc83449b72fee0e2a94..8623461e99dd233870451202adf2c26a1f8f4aec 100644 --- a/tests/test_cliutil.py +++ b/tests/test_cliutil.py @@ -33,11 +33,6 @@ from conservancy_beancount import cliutil FILE_NAMES = ['-foobar', '-foo.bin'] STREAM_PATHS = [None, Path('-')] -class AlwaysEqual: - def __eq__(self, other): - return True - - class MockTraceback: def __init__(self, stack=None, index=0): if stack is None: @@ -141,13 +136,6 @@ def test_excepthook_traceback(caplog): assert caplog.records assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) -@pytest.mark.parametrize('prog_name,expected', [ - ('', False), - (AlwaysEqual(), True), -]) -def test_is_main_script(prog_name, expected): - assert cliutil.is_main_script(prog_name) == expected - @pytest.mark.parametrize('arg,expected', [ ('debug', logging.DEBUG), ('info', logging.INFO), @@ -166,11 +154,10 @@ def test_loglevel_argument(argparser, arg, expected): def test_setup_logger(): stream = io.StringIO() logger = cliutil.setup_logger( - 'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s', + 'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s', ) - logger.debug("test debug") - logger.info("test info") - assert stream.getvalue() == "test_cliutil INFO: test info\n" + logger.critical("test crit") + assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n" @pytest.mark.parametrize('arg', [ '--license', diff --git a/tests/test_reports_accrual.py b/tests/test_reports_accrual.py index fec4b0e736bb0bc0fae2a9e9e58688af7a4a890e..a6c1c78d9845f1dc3c54db97cbea166fc7cb880b 100644 --- a/tests/test_reports_accrual.py +++ b/tests/test_reports_accrual.py @@ -590,12 +590,13 @@ def run_main(arglist, config=None): retcode = accrual.main(arglist, output, errors, config) return retcode, output, errors -def check_main_fails(arglist, config, error_flags, error_patterns): +def check_main_fails(arglist, config, error_flags): retcode, output, errors = run_main(arglist, config) assert retcode > 16 assert (retcode - 16) & error_flags - check_output(errors, error_patterns) assert not output.getvalue() + errors.seek(0) + return errors @pytest.mark.parametrize('arglist', [ ['--report-type=balance', 'entity=EarlyBird'], @@ -686,7 +687,8 @@ def test_main_aging_report(tmp_path, arglist): check_aging_ods(ods_file, None, recv_rows, pay_rows) def test_main_no_books(): - check_main_fails([], testutil.TestConfig(), 1 | 8, [ + errors = check_main_fails([], testutil.TestConfig(), 1 | 8) + testutil.check_lines_match(iter(errors), [ r':[01]: +no books to load in configuration\b', ]) @@ -695,15 +697,17 @@ def test_main_no_books(): ['505/99999'], ['entity=NonExistent'], ]) -def test_main_no_matches(arglist): - check_main_fails(arglist, None, 8, [ - r': WARNING: no matching entries found to report$', +def test_main_no_matches(arglist, caplog): + check_main_fails(arglist, None, 8) + testutil.check_logs_match(caplog, [ + ('WARNING', 'no matching entries found to report'), ]) -def test_main_no_rt(): +def test_main_no_rt(caplog): config = testutil.TestConfig( books_path=testutil.test_path('books/accruals.beancount'), ) - check_main_fails(['-t', 'out'], config, 4, [ - r': ERROR: unable to generate outgoing report: RT client is required\b', + check_main_fails(['-t', 'out'], config, 4) + testutil.check_logs_match(caplog, [ + ('ERROR', 'unable to generate outgoing report: RT client is required'), ]) diff --git a/tests/testutil.py b/tests/testutil.py index a18ef2065e8d1420e28e832a264bc8238259a242..110369be208696a7a1ee547fc9f8c7625084124d 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -69,6 +69,14 @@ def check_lines_match(lines, expect_patterns, source='output'): assert any(re.search(pattern, line) for line in lines), \ f"{pattern!r} not found in {source}" +def check_logs_match(caplog, expected): + records = iter(caplog.records) + for exp_level, exp_msg in expected: + exp_level = exp_level.upper() + assert any( + log.levelname == exp_level and log.message == exp_msg for log in records + ), f"{exp_level} log {exp_msg!r} not found" + def check_post_meta(txn, *expected_meta, default=None): assert len(txn.postings) == len(expected_meta) for post, expected in zip(txn.postings, expected_meta):