Changeset - b038ec827ccc
[Not reviewed]
0 2 0
Brett Smith - 4 years ago 2020-06-25 12:43:28
brettcsmith@brettcsmith.org
cliutil: Add year_or_date_arg() function.
2 files changed with 45 insertions and 2 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
...
 
@@ -190,96 +190,113 @@ class SearchTerm(NamedTuple):
 
            else:
 
                ticket_id, attachment_id = rt_ids
 
                if key is None:
 
                    if attachment_id is None:
 
                        key = ticket_default_key
 
                    else:
 
                        key = default_key
 
                pattern = rtutil.RT.metadata_regexp(
 
                    ticket_id,
 
                    attachment_id,
 
                    first_link_only=key == 'rt-id' and attachment_id is None,
 
                )
 
            if key is None:
 
                raise ValueError(f"invalid search term {arg!r}: no metadata key")
 
            return cls(key, pattern)
 
        return parse_search_term
 

	
 
    def filter_postings(self, postings: Iterable[data.Posting]) -> Iterable[data.Posting]:
 
        return filters.filter_meta_match(
 
            postings, self.meta_key, re.compile(self.pattern),
 
        )
 

	
 

	
 
def add_loglevel_argument(parser: argparse.ArgumentParser,
 
                          default: LogLevel=LogLevel.INFO) -> argparse.Action:
 
    return parser.add_argument(
 
        '--loglevel',
 
        metavar='LEVEL',
 
        default=default.value,
 
        type=LogLevel.from_arg,
 
        help="Show logs at this level and above."
 
        f" Specify one of {', '.join(LogLevel.choices())}."
 
        f" Default {default.name.lower()}.",
 
    )
 

	
 
def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    progname = parser.prog or sys.argv[0]
 
    return parser.add_argument(
 
        '--version', '--copyright', '--license',
 
        action=InfoAction,
 
        nargs=0,
 
        const=f"{progname} version {VERSION}\n{LICENSE}",
 
        help="Show program version and license information",
 
    )
 

	
 
def date_arg(arg: str) -> datetime.date:
 
    return datetime.datetime.strptime(arg, '%Y-%m-%d').date()
 

	
 
def year_or_date_arg(arg: str) -> Union[int, datetime.date]:
 
    """Get either a date or a year (int) from an argument string
 

	
 
    This is a useful argument type for arguments that will be passed into
 
    Books loader methods which can accept either a fiscal year or a full date.
 
    """
 
    try:
 
        year = int(arg, 10)
 
    except ValueError:
 
        ok = False
 
    else:
 
        ok = datetime.MINYEAR <= year <= datetime.MAXYEAR
 
    if ok:
 
        return year
 
    else:
 
        return date_arg(arg)
 

	
 
def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
 
    """Create an entry_point function for a tool
 

	
 
    The returned function is suitable for use as an entry_point in setup.py.
 
    It sets up the root logger and excepthook, then calls the module's main
 
    function.
 
    """
 
    def entry_point():  # type:ignore
 
        prog_mod = sys.modules[mod_name]
 
        setup_logger()
 
        prog_mod.logger = logging.getLogger(prog_name)
 
        sys.excepthook = ExceptHook(prog_mod.logger)
 
        return prog_mod.main()
 
    return entry_point
 

	
 
def setup_logger(logger: Union[str, logging.Logger]='',
 
                 stream: TextIO=sys.stderr,
 
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 
) -> logging.Logger:
 
    """Set up a logger with a StreamHandler with the given format"""
 
    if isinstance(logger, str):
 
        logger = logging.getLogger(logger)
 
    formatter = logging.Formatter(fmt)
 
    handler = logging.StreamHandler(stream)
 
    handler.setFormatter(formatter)
 
    logger.addHandler(handler)
 
    return logger
 

	
 
def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None:
 
    """Set the loglevel for a tool or module
 

	
 
    If the given logger is not under a hierarchy, this function sets the
 
    loglevel for the root logger, along with some specific levels for libraries
 
    used by reporting tools. Otherwise, it's the same as
 
    ``logger.setLevel(loglevel)``.
 
    """
 
    if '.' not in logger.name:
 
        logger = logging.getLogger()
 
        if loglevel <= logging.DEBUG:
 
            # At the debug level, the rt module logs the full body of every
 
            # request and response. That's too much.
 
            logging.getLogger('rt.rt').setLevel(logging.INFO)
 
    logger.setLevel(loglevel)
 

	
 
def bytes_output(path: Optional[Path]=None,
 
                 default: OutputFile=sys.stdout,
 
                 mode: str='w',
 
) -> BinaryIO:
tests/test_cliutil.py
Show inline comments
...
 
@@ -36,110 +36,136 @@ STREAM_PATHS = [None, Path('-')]
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
 
            stack = inspect.stack(context=False)
 
        self._stack = stack
 
        self._index = index
 
        frame_record = self._stack[self._index]
 
        self.tb_frame = frame_record.frame
 
        self.tb_lineno = frame_record.lineno
 

	
 
    @property
 
    def tb_next(self):
 
        try:
 
            return type(self)(self._stack, self._index + 1)
 
        except IndexError:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
 
    cliutil.add_loglevel_argument(parser)
 
    cliutil.add_version_argument(parser)
 
    return parser
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_bytes_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 
    assert 'b' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_bytes_output_stream(path):
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_date_arg_valid(year, month, day):
 
    arg = f'{year}-{month}-{day}'
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.date_arg(arg) == expected
 
    assert cliutil.date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '2000',
 
    '20-02-12',
 
    '2019-02-29',
 
    'two thousand',
 
])
 
def test_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.date_arg(arg)
 

	
 
@pytest.mark.parametrize('year', [
 
    1990,
 
    2000,
 
    2009,
 
])
 
def test_year_or_date_arg_year(year):
 
    assert cliutil.year_or_date_arg(str(year)) == year
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_year_or_date_arg_date(year, month, day):
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.year_or_date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '-1',
 
    str(sys.maxsize),
 
    'MMDVIII',
 
    '2019-02-29',
 
])
 
def test_year_or_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.year_or_date_arg(arg)
 

	
 
@pytest.mark.parametrize('func_name', [
 
    'bytes_output',
 
    'text_output',
 
])
 
def test_default_output(func_name):
 
    actual = getattr(cliutil, func_name)()
 
    assert actual.fileno() == sys.stdout.fileno()
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_text_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_text_output_stream(path):
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('errnum', [
 
    errno.EACCES,
 
    errno.EPERM,
 
    errno.ENOENT,
 
])
 
def test_excepthook_oserror(errnum, caplog):
 
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(type(error), error, None)
 
    assert exc_check.value.args[0] == 4
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"I/O error: {error.filename}: {error.strerror}"
 

	
 
@pytest.mark.parametrize('exc_type', [
 
    AttributeError,
 
    RuntimeError,
 
    ValueError,
 
])
 
def test_excepthook_bug(exc_type, caplog):
 
    error = exc_type("test message")
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(exc_type, error, None)
 
    assert exc_check.value.args[0] == 3
0 comments (0 inline, 0 general)