Changeset - 32b62df5405f
[Not reviewed]
0 3 0
Brett Smith - 4 years ago 2020-05-30 03:39:27
brettcsmith@brettcsmith.org
cliutil: Better implementation of is_main_script.

The old one could return True if you called accrual.main()
directly from one-off test scripts.
3 files changed with 21 insertions and 6 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
...
 
@@ -7,48 +7,50 @@ This program is free software: you can redistribute it and/or modify
 
it under the terms of the GNU Affero General Public License as published by
 
the Free Software Foundation, either version 3 of the License, or
 
(at your option) any later version.
 

	
 
This program is distributed in the hope that it will be useful,
 
but WITHOUT ANY WARRANTY; without even the implied warranty of
 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
GNU Affero General Public License for more details.
 

	
 
You should have received a copy of the GNU Affero General Public License
 
along with this program.  If not, see <https://www.gnu.org/licenses/>."""
 

	
 
import argparse
 
import enum
 
import inspect
 
import logging
 
import operator
 
import os
 
import pkg_resources
 
import signal
 
import sys
 
import traceback
 
import types
 

	
 
from pathlib import Path
 

	
 
from typing import (
 
    Any,
 
    Iterable,
 
    NoReturn,
 
    Optional,
 
    Sequence,
 
    TextIO,
 
    Type,
 
    Union,
 
)
 

	
 
VERSION = pkg_resources.require(PKGNAME)[0].version
 

	
 
class ExceptHook:
 
    def __init__(self,
 
                 logger: Optional[logging.Logger]=None,
 
                 default_exitcode: int=3,
 
    ) -> None:
 
        if logger is None:
 
            logger = logging.getLogger()
 
        self.logger = logger
 
        self.default_exitcode = default_exitcode
 

	
 
    def __call__(self,
...
 
@@ -113,42 +115,46 @@ class LogLevel(enum.IntEnum):
 
            yield level.name.lower()
 

	
 
def add_loglevel_argument(parser: argparse.ArgumentParser,
 
                          default: LogLevel=LogLevel.INFO) -> argparse.Action:
 
    return parser.add_argument(
 
        '--loglevel',
 
        metavar='LEVEL',
 
        default=default.value,
 
        type=LogLevel.from_arg,
 
        help="Show logs at this level and above."
 
        f" Specify one of {', '.join(LogLevel.choices())}."
 
        f" Default {default.name.lower()}.",
 
    )
 

	
 
def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    progname = parser.prog or sys.argv[0]
 
    return parser.add_argument(
 
        '--version', '--copyright', '--license',
 
        action=InfoAction,
 
        nargs=0,
 
        const=f"{progname} version {VERSION}\n{LICENSE}",
 
        help="Show program version and license information",
 
    )
 

	
 
def is_main_script() -> bool:
 
def is_main_script(prog_name: str) -> bool:
 
    """Return true if the caller is the "main" program."""
 
    stack = inspect.stack(context=False)
 
    return len(stack) <= 3 and stack[-1].function.startswith('<')
 
    stack = iter(inspect.stack(context=False))
 
    next(stack)  # Discard the frame for calling this function
 
    caller_filename = next(stack).filename
 
    return all(frame.filename == caller_filename
 
               or Path(frame.filename).stem == prog_name
 
               for frame in stack)
 

	
 
def setup_logger(logger: Union[str, logging.Logger]='',
 
                 loglevel: int=logging.INFO,
 
                 stream: TextIO=sys.stderr,
 
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 
) -> logging.Logger:
 
    if isinstance(logger, str):
 
        logger = logging.getLogger(logger)
 
    formatter = logging.Formatter(fmt)
 
    handler = logging.StreamHandler(stream)
 
    handler.setFormatter(formatter)
 
    logger.addHandler(handler)
 
    logger.setLevel(loglevel)
 
    return logger
conservancy_beancount/reports/accrual.py
Show inline comments
...
 
@@ -370,49 +370,49 @@ single outstanding payable, and `balance` any other time.
 
        help="""How far back to search the books for related transactions.
 
You can either specify a fiscal year, or a negative offset from the current
 
fiscal year, to start loading entries from. The default is -1 (start from the
 
previous fiscal year).
 
""")
 
    cliutil.add_loglevel_argument(parser)
 
    parser.add_argument(
 
        'search',
 
        nargs=argparse.ZERO_OR_MORE,
 
        help="""Report on accruals that match this criteria. The format is
 
NAME=TERM. TERM is a link or word that must exist in a posting's NAME
 
metadata to match. A single ticket number is a shortcut for
 
`rt-id=rt:NUMBER`. Any other link, including an RT attachment link in
 
`TIK/ATT` format, is a shortcut for `invoice=LINK`.
 
""")
 
    args = parser.parse_args(arglist)
 
    args.search_terms = [SearchTerm.parse(s) for s in args.search]
 
    return args
 

	
 
def main(arglist: Optional[Sequence[str]]=None,
 
         stdout: TextIO=sys.stdout,
 
         stderr: TextIO=sys.stderr,
 
         config: Optional[configmod.Config]=None,
 
) -> int:
 
    if cliutil.is_main_script():
 
    if cliutil.is_main_script(PROGNAME):
 
        global logger
 
        logger = logging.getLogger(PROGNAME)
 
        sys.excepthook = cliutil.ExceptHook(logger)
 
    args = parse_arguments(arglist)
 
    cliutil.setup_logger(logger, args.loglevel, stderr)
 
    if config is None:
 
        config = configmod.Config()
 
        config.load_file()
 
    books_loader = config.books_loader()
 
    if books_loader is not None:
 
        entries, load_errors, _ = books_loader.load_fy_range(args.since)
 
    else:
 
        entries = []
 
        source = {
 
            'filename': str(config.config_file_path()),
 
            'lineno': 1,
 
        }
 
        load_errors = [Error(source, "no books to load in configuration", None)]
 
    postings = filter_search(data.Posting.from_entries(entries), args.search_terms)
 
    groups = core.RelatedPostings.group_by_meta(postings, 'invoice')
 
    groups = AccrualAccount.filter_paid_accruals(groups) or groups
 
    meta_errors = consistency_check(groups)
 
    returncode = 0
 
    for error in load_errors:
tests/test_cliutil.py
Show inline comments
...
 
@@ -6,48 +6,53 @@
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import argparse
 
import errno
 
import io
 
import inspect
 
import logging
 
import os
 
import re
 
import traceback
 

	
 
import pytest
 

	
 
from conservancy_beancount import cliutil
 

	
 
class AlwaysEqual:
 
    def __eq__(self, other):
 
        return True
 

	
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
 
            stack = inspect.stack(context=False)
 
        self._stack = stack
 
        self._index = index
 
        frame_record = self._stack[self._index]
 
        self.tb_frame = frame_record.frame
 
        self.tb_lineno = frame_record.lineno
 

	
 
    @property
 
    def tb_next(self):
 
        try:
 
            return type(self)(self._stack, self._index + 1)
 
        except IndexError:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
 
    cliutil.add_loglevel_argument(parser)
 
    cliutil.add_version_argument(parser)
 
    return parser
...
 
@@ -70,50 +75,54 @@ def test_excepthook_oserror(errnum, caplog):
 
@pytest.mark.parametrize('exc_type', [
 
    AttributeError,
 
    RuntimeError,
 
    ValueError,
 
])
 
def test_excepthook_bug(exc_type, caplog):
 
    error = exc_type("test message")
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(exc_type, error, None)
 
    assert exc_check.value.args[0] == 3
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"internal {exc_type.__name__}: {error.args[0]}"
 

	
 
def test_excepthook_traceback(caplog):
 
    error = KeyError('test')
 
    args = (type(error), error, MockTraceback())
 
    caplog.set_level(logging.DEBUG)
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(*args)
 
    assert caplog.records
 
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 

	
 
def test_is_main_script():
 
    assert not cliutil.is_main_script()
 
@pytest.mark.parametrize('prog_name,expected', [
 
    ('', False),
 
    (AlwaysEqual(), True),
 
])
 
def test_is_main_script(prog_name, expected):
 
    assert cliutil.is_main_script(prog_name) == expected
 

	
 
@pytest.mark.parametrize('arg,expected', [
 
    ('debug', logging.DEBUG),
 
    ('info', logging.INFO),
 
    ('warning', logging.WARNING),
 
    ('warn', logging.WARNING),
 
    ('error', logging.ERROR),
 
    ('err', logging.ERROR),
 
    ('critical', logging.CRITICAL),
 
    ('crit', logging.CRITICAL),
 
])
 
def test_loglevel_argument(argparser, arg, expected):
 
    for method in ['lower', 'title', 'upper']:
 
        args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
 
        assert args.loglevel is expected
 

	
 
def test_setup_logger():
 
    stream = io.StringIO()
 
    logger = cliutil.setup_logger(
 
        'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s',
 
    )
 
    logger.debug("test debug")
 
    logger.info("test info")
 
    assert stream.getvalue() == "test_cliutil INFO: test info\n"
0 comments (0 inline, 0 general)