Changeset - b038ec827ccc
[Not reviewed]
0 2 0
Brett Smith - 4 years ago 2020-06-25 12:43:28
brettcsmith@brettcsmith.org
cliutil: Add year_or_date_arg() function.
2 files changed with 45 insertions and 2 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
...
 
@@ -142,187 +142,204 @@ class SearchTerm(NamedTuple):
 
    Typical usage looks like::
 

	
 
      argument_parser.add_argument(
 
        'search_terms',
 
        type=SearchTerm.arg_parser(),
 
        …,
 
      )
 

	
 
      args = argument_parser.parse_args(…)
 
      for query in args.search_terms:
 
        postings = query.filter_postings(postings)
 
    """
 
    meta_key: MetaKey
 
    pattern: str
 

	
 
    @classmethod
 
    def arg_parser(cls,
 
                   default_key: Optional[str]=None,
 
                   ticket_default_key: Optional[str]=None,
 
    ) -> Callable[[str], 'SearchTerm']:
 
        """Build a SearchTerm parser
 

	
 
        This method returns a function that can parse strings in ``key=value``
 
        format and return a corresponding SearchTerm.
 

	
 
        If you specify a default key, then strings that just specify a ``value``
 
        will be parsed as if they said ``default_key=value``. Otherwise,
 
        parsing strings without a metadata key will raise a ValueError.
 

	
 
        If you specify a default key ticket links, then values in the format
 
        ``number``, ``rt:number``, or ``rt://ticket/number`` will be parsed as
 
        if they said ``ticket_default_key=value``.
 
        """
 
        if ticket_default_key is None:
 
            ticket_default_key = default_key
 
        def parse_search_term(arg: str) -> 'SearchTerm':
 
            key: Optional[str] = None
 
            if re.match(r'^[a-z][-\w]*=', arg):
 
                key, _, raw_link = arg.partition('=')
 
            else:
 
                raw_link = arg
 
            rt_ids = rtutil.RT.parse(raw_link)
 
            if rt_ids is None:
 
                rt_ids = rtutil.RT.parse('rt:' + raw_link)
 
            if rt_ids is None:
 
                if key is None:
 
                    key = default_key
 
                pattern = r'(?:^|\s){}(?:\s|$)'.format(re.escape(raw_link))
 
            else:
 
                ticket_id, attachment_id = rt_ids
 
                if key is None:
 
                    if attachment_id is None:
 
                        key = ticket_default_key
 
                    else:
 
                        key = default_key
 
                pattern = rtutil.RT.metadata_regexp(
 
                    ticket_id,
 
                    attachment_id,
 
                    first_link_only=key == 'rt-id' and attachment_id is None,
 
                )
 
            if key is None:
 
                raise ValueError(f"invalid search term {arg!r}: no metadata key")
 
            return cls(key, pattern)
 
        return parse_search_term
 

	
 
    def filter_postings(self, postings: Iterable[data.Posting]) -> Iterable[data.Posting]:
 
        return filters.filter_meta_match(
 
            postings, self.meta_key, re.compile(self.pattern),
 
        )
 

	
 

	
 
def add_loglevel_argument(parser: argparse.ArgumentParser,
 
                          default: LogLevel=LogLevel.INFO) -> argparse.Action:
 
    return parser.add_argument(
 
        '--loglevel',
 
        metavar='LEVEL',
 
        default=default.value,
 
        type=LogLevel.from_arg,
 
        help="Show logs at this level and above."
 
        f" Specify one of {', '.join(LogLevel.choices())}."
 
        f" Default {default.name.lower()}.",
 
    )
 

	
 
def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    progname = parser.prog or sys.argv[0]
 
    return parser.add_argument(
 
        '--version', '--copyright', '--license',
 
        action=InfoAction,
 
        nargs=0,
 
        const=f"{progname} version {VERSION}\n{LICENSE}",
 
        help="Show program version and license information",
 
    )
 

	
 
def date_arg(arg: str) -> datetime.date:
 
    return datetime.datetime.strptime(arg, '%Y-%m-%d').date()
 

	
 
def year_or_date_arg(arg: str) -> Union[int, datetime.date]:
 
    """Get either a date or a year (int) from an argument string
 

	
 
    This is a useful argument type for arguments that will be passed into
 
    Books loader methods which can accept either a fiscal year or a full date.
 
    """
 
    try:
 
        year = int(arg, 10)
 
    except ValueError:
 
        ok = False
 
    else:
 
        ok = datetime.MINYEAR <= year <= datetime.MAXYEAR
 
    if ok:
 
        return year
 
    else:
 
        return date_arg(arg)
 

	
 
def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
 
    """Create an entry_point function for a tool
 

	
 
    The returned function is suitable for use as an entry_point in setup.py.
 
    It sets up the root logger and excepthook, then calls the module's main
 
    function.
 
    """
 
    def entry_point():  # type:ignore
 
        prog_mod = sys.modules[mod_name]
 
        setup_logger()
 
        prog_mod.logger = logging.getLogger(prog_name)
 
        sys.excepthook = ExceptHook(prog_mod.logger)
 
        return prog_mod.main()
 
    return entry_point
 

	
 
def setup_logger(logger: Union[str, logging.Logger]='',
 
                 stream: TextIO=sys.stderr,
 
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 
) -> logging.Logger:
 
    """Set up a logger with a StreamHandler with the given format"""
 
    if isinstance(logger, str):
 
        logger = logging.getLogger(logger)
 
    formatter = logging.Formatter(fmt)
 
    handler = logging.StreamHandler(stream)
 
    handler.setFormatter(formatter)
 
    logger.addHandler(handler)
 
    return logger
 

	
 
def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None:
 
    """Set the loglevel for a tool or module
 

	
 
    If the given logger is not under a hierarchy, this function sets the
 
    loglevel for the root logger, along with some specific levels for libraries
 
    used by reporting tools. Otherwise, it's the same as
 
    ``logger.setLevel(loglevel)``.
 
    """
 
    if '.' not in logger.name:
 
        logger = logging.getLogger()
 
        if loglevel <= logging.DEBUG:
 
            # At the debug level, the rt module logs the full body of every
 
            # request and response. That's too much.
 
            logging.getLogger('rt.rt').setLevel(logging.INFO)
 
    logger.setLevel(loglevel)
 

	
 
def bytes_output(path: Optional[Path]=None,
 
                 default: OutputFile=sys.stdout,
 
                 mode: str='w',
 
) -> BinaryIO:
 
    """Get a file-like object suitable for binary output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or text IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    mode = f'{mode}b'
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode)
 
        elif isinstance(default, TextIO):
 
            retval = default.buffer
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode)
 
    return cast(BinaryIO, retval)
 

	
 
def text_output(path: Optional[Path]=None,
 
                default: OutputFile=sys.stdout,
 
                mode: str='w',
 
                encoding: Optional[str]=None,
 
) -> TextIO:
 
    """Get a file-like object suitable for text output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or binary IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode, encoding=encoding)
 
        elif isinstance(default, BinaryIO):
 
            retval = io.TextIOWrapper(default, encoding=encoding)
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode, encoding=encoding)
 
    return cast(TextIO, retval)
tests/test_cliutil.py
Show inline comments
 
"""Test CLI utilities"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import argparse
 
import datetime
 
import errno
 
import io
 
import inspect
 
import logging
 
import os
 
import re
 
import sys
 
import traceback
 

	
 
import pytest
 

	
 
from pathlib import Path
 

	
 
from conservancy_beancount import cliutil
 

	
 
FILE_NAMES = ['-foobar', '-foo.bin']
 
STREAM_PATHS = [None, Path('-')]
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
 
            stack = inspect.stack(context=False)
 
        self._stack = stack
 
        self._index = index
 
        frame_record = self._stack[self._index]
 
        self.tb_frame = frame_record.frame
 
        self.tb_lineno = frame_record.lineno
 

	
 
    @property
 
    def tb_next(self):
 
        try:
 
            return type(self)(self._stack, self._index + 1)
 
        except IndexError:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
 
    cliutil.add_loglevel_argument(parser)
 
    cliutil.add_version_argument(parser)
 
    return parser
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_bytes_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 
    assert 'b' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_bytes_output_stream(path):
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_date_arg_valid(year, month, day):
 
    arg = f'{year}-{month}-{day}'
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.date_arg(arg) == expected
 
    assert cliutil.date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '2000',
 
    '20-02-12',
 
    '2019-02-29',
 
    'two thousand',
 
])
 
def test_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.date_arg(arg)
 

	
 
@pytest.mark.parametrize('year', [
 
    1990,
 
    2000,
 
    2009,
 
])
 
def test_year_or_date_arg_year(year):
 
    assert cliutil.year_or_date_arg(str(year)) == year
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_year_or_date_arg_date(year, month, day):
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.year_or_date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '-1',
 
    str(sys.maxsize),
 
    'MMDVIII',
 
    '2019-02-29',
 
])
 
def test_year_or_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.year_or_date_arg(arg)
 

	
 
@pytest.mark.parametrize('func_name', [
 
    'bytes_output',
 
    'text_output',
 
])
 
def test_default_output(func_name):
 
    actual = getattr(cliutil, func_name)()
 
    assert actual.fileno() == sys.stdout.fileno()
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_text_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_text_output_stream(path):
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('errnum', [
 
    errno.EACCES,
 
    errno.EPERM,
 
    errno.ENOENT,
 
])
 
def test_excepthook_oserror(errnum, caplog):
 
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(type(error), error, None)
 
    assert exc_check.value.args[0] == 4
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"I/O error: {error.filename}: {error.strerror}"
 

	
 
@pytest.mark.parametrize('exc_type', [
 
    AttributeError,
 
    RuntimeError,
 
    ValueError,
 
])
 
def test_excepthook_bug(exc_type, caplog):
 
    error = exc_type("test message")
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(exc_type, error, None)
 
    assert exc_check.value.args[0] == 3
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"internal {exc_type.__name__}: {error.args[0]}"
 

	
 
def test_excepthook_traceback(caplog):
 
    error = KeyError('test')
 
    args = (type(error), error, MockTraceback())
 
    caplog.set_level(logging.DEBUG)
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(*args)
 
    assert caplog.records
 
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 

	
 
@pytest.mark.parametrize('arg,expected', [
 
    ('debug', logging.DEBUG),
 
    ('info', logging.INFO),
 
    ('warning', logging.WARNING),
 
    ('warn', logging.WARNING),
 
    ('error', logging.ERROR),
 
    ('err', logging.ERROR),
 
    ('critical', logging.CRITICAL),
 
    ('crit', logging.CRITICAL),
 
])
 
def test_loglevel_argument(argparser, arg, expected):
 
    for method in ['lower', 'title', 'upper']:
 
        args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
 
        assert args.loglevel is expected
 

	
 
def test_setup_logger():
 
    stream = io.StringIO()
 
    logger = cliutil.setup_logger(
 
        'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s',
 
    )
 
    logger.critical("test crit")
 
    assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n"
 

	
 
@pytest.mark.parametrize('arg', [
 
    '--license',
 
    '--version',
 
    '--copyright',
 
])
 
def test_version_argument(argparser, capsys, arg):
 
    with pytest.raises(SystemExit) as exc_check:
 
        args = argparser.parse_args(['--version'])
 
    assert exc_check.value.args[0] == 0
 
    stdout, _ = capsys.readouterr()
 
    lines = iter(stdout.splitlines())
0 comments (0 inline, 0 general)