Changeset - 902c313b4dfb
[Not reviewed]
0 2 0
Brett Smith - 4 years ago 2021-01-09 15:09:08
brettcsmith@brettcsmith.org
cliutil: New function can_run.
2 files changed with 30 insertions and 0 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
 
"""cliutil - Utilities for CLI tools"""
 
PKGNAME = 'conservancy_beancount'
 
LICENSE = """
 
Copyright © 2020  Brett Smith and other contributors
 

	
 
This program is free software: you can redistribute it and/or modify it.
 
Refer to the LICENSE.txt that came with the software for details.
 

	
 
This program is distributed in the hope that it will be useful,
 
but WITHOUT ANY WARRANTY; without even the implied warranty of
 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE."""
 

	
 
import argparse
 
import datetime
 
import enum
 
import io
 
import logging
 
import operator
 
import os
 
import pkg_resources
 
import re
 
import signal
 
import subprocess
 
import sys
 
import traceback
 
import types
 

	
 
from pathlib import Path
 

	
 
import rt.exceptions as rt_error
 

	
 
from . import data
 
from . import filters
 
from . import rtutil
 

	
 
from typing import (
 
    cast,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    Container,
 
    IO,
 
    Iterable,
 
    NamedTuple,
 
    NoReturn,
 
    Optional,
 
    Sequence,
 
    TextIO,
 
    Type,
 
    Union,
 
)
 
from .beancount_types import (
 
    MetaKey,
 
)
 

	
 
OutputFile = Union[int, IO]
 

	
 
CPU_COUNT = len(os.sched_getaffinity(0))
 
STDSTREAM_PATH = Path('-')
 
VERSION = pkg_resources.require(PKGNAME)[0].version
 

	
 
class ExceptHook:
 
    def __init__(self, logger: Optional[logging.Logger]=None) -> None:
 
        if logger is None:
 
            logger = logging.getLogger()
 
        self.logger = logger
 

	
 
    def __call__(self,
 
                 exc_type: Type[BaseException],
 
                 exc_value: BaseException,
 
                 exc_tb: types.TracebackType,
 
    ) -> NoReturn:
 
        error_type = type(exc_value).__name__
 
        msg = ": ".join(str(arg) for arg in exc_value.args)
 
        if isinstance(exc_value, KeyboardInterrupt):
 
            signal.signal(signal.SIGINT, signal.SIG_DFL)
 
            os.kill(0, signal.SIGINT)
 
            signal.pause()
 
        elif isinstance(exc_value, (
 
                rt_error.AuthorizationError,
 
                rt_error.NotAllowed,
 
        )):
 
            exitcode = os.EX_NOPERM
 
            error_type = "RT access denied"
 
        elif isinstance(exc_value, rt_error.ConnectionError):
 
            exitcode = os.EX_TEMPFAIL
 
            error_type = "RT connection error"
 
        elif isinstance(exc_value, rt_error.RtError):
 
            exitcode = os.EX_UNAVAILABLE
 
            error_type = f"RT {error_type}"
 
        elif isinstance(exc_value, OSError):
 
            if exc_value.filename is None:
 
                exitcode = os.EX_OSERR
 
                error_type = "OS error"
 
                msg = exc_value.strerror
 
            else:
 
                # There are more specific exit codes for input problems vs.
 
                # output problems, but without knowing how the file was
 
                # intended to be used, we can't use them.
 
                exitcode = os.EX_IOERR
 
                error_type = "I/O error"
 
                msg = f"{exc_value.filename}: {exc_value.strerror}"
 
        else:
 
            exitcode = os.EX_SOFTWARE
 
            error_type = f"internal {error_type}"
 
        self.logger.critical("%s%s%s", error_type, ": " if msg else "", msg)
 
        self.logger.debug(
 
            ''.join(traceback.format_exception(exc_type, exc_value, exc_tb)),
 
        )
 
        raise SystemExit(exitcode)
 

	
 

	
 
class ExitCode(enum.IntEnum):
 
    # BSD exit codes commonly used
 
    NoConfiguration = os.EX_CONFIG
 
    NoConfig = NoConfiguration
 
    NoDataFiltered = os.EX_DATAERR
 
    NoDataLoaded = os.EX_NOINPUT
 
    RewriteRulesError = os.EX_DATAERR
 

	
 
    # Our own exit codes, working down from that range
 
    BeancountErrors = 63
 

	
 

	
 
class InfoAction(argparse.Action):
 
    def __call__(self,
 
                 parser: argparse.ArgumentParser,
 
                 namespace: argparse.Namespace,
 
                 values: Union[Sequence[Any], str, None]=None,
 
                 option_string: Optional[str]=None,
 
    ) -> NoReturn:
 
        if isinstance(self.const, str):
 
            info = self.const
 
            exitcode = 0
 
        else:
 
            info, exitcode = self.const
 
        print(info)
 
        raise SystemExit(exitcode)
 

	
 

	
 
class LogLevel(enum.IntEnum):
 
    DEBUG = logging.DEBUG
 
    INFO = logging.INFO
 
    WARNING = logging.WARNING
 
    ERROR = logging.ERROR
 
    CRITICAL = logging.CRITICAL
 
    WARN = WARNING
 
    ERR = ERROR
 
    CRIT = CRITICAL
 

	
 
    @classmethod
 
    def from_arg(cls, arg: str) -> int:
 
        try:
 
            return cls[arg.upper()].value
 
        except KeyError:
 
            raise ValueError(f"unknown loglevel {arg!r}") from None
 

	
 
    @classmethod
 
    def choices(cls) -> Iterable[str]:
 
        for level in sorted(cls, key=operator.attrgetter('value')):
 
            yield level.name.lower()
 

	
 

	
 
class SearchTerm(NamedTuple):
 
    """NamedTuple representing a user's metadata filter
 

	
 
    SearchTerm knows how to parse and store posting metadata filters provided
 
    by the user in `key=value` format. Reporting tools can use this to filter
 
    postings that match the user's criteria, to report on subsets of the books.
 

	
 
    Typical usage looks like::
 

	
 
      argument_parser.add_argument(
 
        'search_terms',
 
        type=SearchTerm.arg_parser(),
 
        …,
 
      )
 

	
 
      args = argument_parser.parse_args(…)
 
      for query in args.search_terms:
 
        postings = query.filter_postings(postings)
 
    """
 
    meta_key: MetaKey
 
    pattern: str
 

	
 
    @classmethod
 
    def arg_parser(cls,
 
                   default_key: Optional[str]=None,
 
                   ticket_default_key: Optional[str]=None,
 
    ) -> Callable[[str], 'SearchTerm']:
 
        """Build a SearchTerm parser
 

	
 
        This method returns a function that can parse strings in ``key=value``
 
        format and return a corresponding SearchTerm.
 

	
 
        If you specify a default key, then strings that just specify a ``value``
 
        will be parsed as if they said ``default_key=value``. Otherwise,
 
        parsing strings without a metadata key will raise a ValueError.
 

	
 
        If you specify a default key ticket links, then values in the format
 
        ``number``, ``rt:number``, or ``rt://ticket/number`` will be parsed as
 
        if they said ``ticket_default_key=value``.
 
        """
 
        if ticket_default_key is None:
 
            ticket_default_key = default_key
 
        def parse_search_term(arg: str) -> 'SearchTerm':
 
            key: Optional[str] = None
 
            if re.match(r'^[a-z][-\w]*=', arg):
 
                key, _, raw_link = arg.partition('=')
 
            else:
 
                raw_link = arg
 
            rt_ids = rtutil.RT.parse(raw_link)
 
            if rt_ids is None:
 
                rt_ids = rtutil.RT.parse('rt:' + raw_link)
 
            if rt_ids is None:
 
                if key is None:
 
                    key = default_key
 
                pattern = r'(?:^|\s){}(?:\s|$)'.format(re.escape(raw_link))
 
            else:
 
                ticket_id, attachment_id = rt_ids
 
                if key is None:
 
                    if attachment_id is None:
 
                        key = ticket_default_key
 
                    else:
 
                        key = default_key
 
                pattern = rtutil.RT.metadata_regexp(
 
                    ticket_id,
 
                    attachment_id,
 
                    first_link_only=key == 'rt-id' and attachment_id is None,
 
                )
 
            if key is None:
 
                raise ValueError(f"invalid search term {arg!r}: no metadata key")
 
            return cls(key, pattern)
 
        return parse_search_term
 

	
 
    def filter_postings(self, postings: Iterable[data.Posting]) -> Iterable[data.Posting]:
 
        return filters.filter_meta_match(
 
            postings, self.meta_key, re.compile(self.pattern),
 
        )
 

	
 
def add_jobs_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    return parser.add_argument(
 
        '--jobs', '-j',
 
        metavar='NUM',
 
        type=jobs_arg,
 
        default=CPU_COUNT,
 
        help="""Maximum number of processes to run concurrently.
 
Can specify a positive integer or a percentage of CPU cores. Default all cores.
 
""")
 

	
 
def add_loglevel_argument(parser: argparse.ArgumentParser,
 
                          default: LogLevel=LogLevel.INFO) -> argparse.Action:
 
    return parser.add_argument(
 
        '--loglevel',
 
        metavar='LEVEL',
 
        default=default.value,
 
        type=LogLevel.from_arg,
 
        help="Show logs at this level and above."
 
        f" Specify one of {', '.join(LogLevel.choices())}."
 
        f" Default {default.name.lower()}.",
 
    )
 

	
 
def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    return parser.add_argument(
 
        '--rewrite-rules', '--rewrites', '-r',
 
        action='append',
 
        default=[],
 
        metavar='PATH',
 
        type=Path,
 
        help="""Use rewrite rules from the given YAML file. You can specify
 
this option multiple times to load multiple sets of rewrite rules in order.
 
""")
 

	
 
def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 
    progname = parser.prog or sys.argv[0]
 
    return parser.add_argument(
 
        '--version', '--copyright', '--license',
 
        action=InfoAction,
 
        nargs=0,
 
        const=f"{progname} version {VERSION}\n{LICENSE}",
 
        help="Show program version and license information",
 
    )
 

	
 
def can_run(
 
        cmd: Sequence[str],
 
        stdout: Optional[int]=subprocess.DEVNULL,
 
        stderr: Optional[int]=None,
 
        ok_returncodes: Container[int]=frozenset([0]),
 
) -> bool:
 
    try:
 
        with subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=stdout, stderr=stderr) as proc:
 
            # Typing says this can be None, but I don't think that's true
 
            # given that we passed stdin=PIPE.
 
            proc.stdin.close()  # type:ignore[union-attr]
 
    except (OSError, subprocess.SubprocessError):
 
        return False
 
    else:
 
        return proc.returncode in ok_returncodes
 

	
 
def date_arg(arg: str) -> datetime.date:
 
    return datetime.datetime.strptime(arg, '%Y-%m-%d').date()
 

	
 
def diff_year(date: datetime.date, diff: int) -> datetime.date:
 
    new_year = date.year + diff
 
    try:
 
        return date.replace(year=new_year)
 
    except ValueError:
 
        # The original date is Feb 29, which doesn't exist in the new year.
 
        if diff < 0:
 
            return datetime.date(new_year, 2, 28)
 
        else:
 
            return datetime.date(new_year, 3, 1)
 

	
 
def year_or_date_arg(arg: str) -> Union[int, datetime.date]:
 
    """Get either a date or a year (int) from an argument string
 

	
 
    This is a useful argument type for arguments that will be passed into
 
    Books loader methods which can accept either a fiscal year or a full date.
 
    """
 
    try:
 
        year = int(arg, 10)
 
    except ValueError:
 
        ok = False
 
    else:
 
        ok = datetime.MINYEAR <= year <= datetime.MAXYEAR
 
    if ok:
 
        return year
 
    else:
 
        return date_arg(arg)
 

	
 
def jobs_arg(arg: str) -> int:
 
    if arg.endswith('%'):
 
        arg_n = round(CPU_COUNT * 100 / int(arg[:-1]))
 
    else:
 
        arg_n = int(arg)
 
    if arg_n < 1:
 
        raise ValueError("--jobs argument must be a positive integer or percentage")
 
    else:
 
        return arg_n
 

	
 
def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
 
    """Create an entry_point function for a tool
 

	
 
    The returned function is suitable for use as an entry_point in setup.py.
 
    It sets up the root logger and excepthook, then calls the module's main
 
    function.
 
    """
 
    def entry_point():  # type:ignore
 
        prog_mod = sys.modules[mod_name]
 
        setup_logger()
 
        prog_mod.logger = logging.getLogger(prog_name)
 
        sys.excepthook = ExceptHook(prog_mod.logger)
 
        return prog_mod.main()
 
    return entry_point
 

	
 
def setup_logger(logger: Union[str, logging.Logger]='',
 
                 stream: TextIO=sys.stderr,
 
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 
) -> logging.Logger:
 
    """Set up a logger with a StreamHandler with the given format"""
 
    if isinstance(logger, str):
 
        logger = logging.getLogger(logger)
 
    formatter = logging.Formatter(fmt)
 
    handler = logging.StreamHandler(stream)
 
    handler.setFormatter(formatter)
 
    logger.addHandler(handler)
 
    return logger
 

	
 
def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None:
 
    """Set the loglevel for a tool or module
 

	
 
    If the given logger is not under a hierarchy, this function sets the
 
    loglevel for the root logger, along with some specific levels for libraries
 
    used by reporting tools. Otherwise, it's the same as
 
    ``logger.setLevel(loglevel)``.
 
    """
 
    if '.' not in logger.name:
 
        logger = logging.getLogger()
 
        if loglevel <= logging.DEBUG:
 
            # At the debug level, the rt module logs the full body of every
 
            # request and response. That's too much.
 
            logging.getLogger('rt.rt').setLevel(logging.INFO)
 
    logger.setLevel(loglevel)
 

	
 
def bytes_output(path: Optional[Path]=None,
 
                 default: OutputFile=sys.stdout,
 
                 mode: str='w',
 
) -> BinaryIO:
 
    """Get a file-like object suitable for binary output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or text IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    mode = f'{mode}b'
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode)
 
        elif isinstance(default, TextIO):
 
            retval = default.buffer
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode)
 
    return cast(BinaryIO, retval)
 

	
 
def text_output(path: Optional[Path]=None,
 
                default: OutputFile=sys.stdout,
 
                mode: str='w',
 
                encoding: Optional[str]=None,
 
) -> TextIO:
 
    """Get a file-like object suitable for text output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or binary IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode, encoding=encoding)
 
        elif isinstance(default, BinaryIO):
 
            retval = io.TextIOWrapper(default, encoding=encoding)
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode, encoding=encoding)
 
    return cast(TextIO, retval)
tests/test_cliutil.py
Show inline comments
 
"""Test CLI utilities"""
 
# Copyright © 2020  Brett Smith
 
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
 
#
 
# Full copyright and licensing details can be found at toplevel file
 
# LICENSE.txt in the repository.
 

	
 
import argparse
 
import datetime
 
import errno
 
import io
 
import inspect
 
import logging
 
import os
 
import re
 
import sys
 
import traceback
 

	
 
import pytest
 

	
 
from pathlib import Path
 

	
 
from . import testutil
 

	
 
from conservancy_beancount import cliutil
 

	
 
FILE_NAMES = ['-foobar', '-foo.bin']
 
STREAM_PATHS = [None, Path('-')]
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
 
            stack = inspect.stack(context=False)
 
        self._stack = stack
 
        self._index = index
 
        frame_record = self._stack[self._index]
 
        self.tb_frame = frame_record.frame
 
        self.tb_lineno = frame_record.lineno
 

	
 
    @property
 
    def tb_next(self):
 
        try:
 
            return type(self)(self._stack, self._index + 1)
 
        except IndexError:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
 
    cliutil.add_loglevel_argument(parser)
 
    cliutil.add_version_argument(parser)
 
    return parser
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_bytes_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 
    assert 'b' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_bytes_output_stream(path):
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_date_arg_valid(year, month, day):
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '2000',
 
    '20-02-12',
 
    '2019-02-29',
 
    'two thousand',
 
])
 
def test_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.date_arg(arg)
 

	
 
@pytest.mark.parametrize('year', [
 
    1990,
 
    2000,
 
    2009,
 
])
 
def test_year_or_date_arg_year(year):
 
    assert cliutil.year_or_date_arg(str(year)) == year
 

	
 
@pytest.mark.parametrize('year,month,day', [
 
    (2000, 1, 1),
 
    (2016, 2, 29),
 
    (2020, 12, 31),
 
])
 
def test_year_or_date_arg_date(year, month, day):
 
    expected = datetime.date(year, month, day)
 
    assert cliutil.year_or_date_arg(expected.isoformat()) == expected
 

	
 
@pytest.mark.parametrize('arg', [
 
    '-1',
 
    str(sys.maxsize),
 
    'MMDVIII',
 
    '2019-02-29',
 
])
 
def test_year_or_date_arg_invalid(arg):
 
    with pytest.raises(ValueError):
 
        cliutil.year_or_date_arg(arg)
 

	
 
@pytest.mark.parametrize('func_name', [
 
    'bytes_output',
 
    'text_output',
 
])
 
def test_default_output(func_name):
 
    actual = getattr(cliutil, func_name)()
 
    assert actual.fileno() == sys.stdout.fileno()
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_text_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_text_output_stream(path):
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('errnum', [
 
    errno.EACCES,
 
    errno.EPERM,
 
    errno.ENOENT,
 
])
 
def test_excepthook_oserror(errnum, caplog):
 
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(type(error), error, None)
 
    assert exc_check.value.args[0] == os.EX_IOERR
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"I/O error: {error.filename}: {error.strerror}"
 

	
 
@pytest.mark.parametrize('exc_type', [
 
    AttributeError,
 
    RuntimeError,
 
    ValueError,
 
])
 
def test_excepthook_bug(exc_type, caplog):
 
    error = exc_type("test message")
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(exc_type, error, None)
 
    assert exc_check.value.args[0] == os.EX_SOFTWARE
 
    assert caplog.records
 
    for log in caplog.records:
 
        assert log.levelname == 'CRITICAL'
 
        assert log.message == f"internal {exc_type.__name__}: {error.args[0]}"
 

	
 
def test_excepthook_traceback(caplog):
 
    error = KeyError('test')
 
    args = (type(error), error, MockTraceback())
 
    caplog.set_level(logging.DEBUG)
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(*args)
 
    assert caplog.records
 
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 

	
 
@pytest.mark.parametrize('arg,expected', [
 
    ('debug', logging.DEBUG),
 
    ('info', logging.INFO),
 
    ('warning', logging.WARNING),
 
    ('warn', logging.WARNING),
 
    ('error', logging.ERROR),
 
    ('err', logging.ERROR),
 
    ('critical', logging.CRITICAL),
 
    ('crit', logging.CRITICAL),
 
])
 
def test_loglevel_argument(argparser, arg, expected):
 
    for method in ['lower', 'title', 'upper']:
 
        args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
 
        assert args.loglevel is expected
 

	
 
def test_setup_logger():
 
    stream = io.StringIO()
 
    logger = cliutil.setup_logger(
 
        'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s',
 
    )
 
    logger.critical("test crit")
 
    assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n"
 

	
 
@pytest.mark.parametrize('arg', [
 
    '--license',
 
    '--version',
 
    '--copyright',
 
])
 
def test_version_argument(argparser, capsys, arg):
 
    with pytest.raises(SystemExit) as exc_check:
 
        args = argparser.parse_args(['--version'])
 
    assert exc_check.value.args[0] == 0
 
    stdout, _ = capsys.readouterr()
 
    lines = iter(stdout.splitlines())
 
    assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, "<EOF>"))
 

	
 
@pytest.mark.parametrize('date,diff,expected', [
 
    (datetime.date(2010, 2, 28), 0, datetime.date(2010, 2, 28)),
 
    (datetime.date(2010, 2, 28), 1, datetime.date(2011, 2, 28)),
 
    (datetime.date(2010, 2, 28), 2, datetime.date(2012, 2, 28)),
 
    (datetime.date(2010, 2, 28), -1, datetime.date(2009, 2, 28)),
 
    (datetime.date(2010, 2, 28), -2, datetime.date(2008, 2, 28)),
 
    (datetime.date(2012, 2, 29), 2, datetime.date(2014, 3, 1)),
 
    (datetime.date(2012, 2, 29), 4, datetime.date(2016, 2, 29)),
 
    (datetime.date(2012, 2, 29), -2, datetime.date(2010, 2, 28)),
 
    (datetime.date(2012, 2, 29), -4, datetime.date(2008, 2, 29)),
 
    (datetime.date(2010, 3, 1), 1, datetime.date(2011, 3, 1)),
 
    (datetime.date(2010, 3, 1), 2, datetime.date(2012, 3, 1)),
 
    (datetime.date(2010, 3, 1), -1, datetime.date(2009, 3, 1)),
 
    (datetime.date(2010, 3, 1), -2, datetime.date(2008, 3, 1)),
 
])
 
def test_diff_year(date, diff, expected):
 
    assert cliutil.diff_year(date, diff) == expected
 

	
 
@pytest.mark.parametrize('cmd,expected', [
 
    (['true'], True),
 
    (['true', '--version'], True),
 
    (['false'], False),
 
    (['false', '--version'], False),
 
    ([str(testutil.TESTS_DIR)], False),
 
])
 
def test_can_run(cmd, expected):
 
    assert cliutil.can_run(cmd) == expected
0 comments (0 inline, 0 general)