Changeset - 4188dc6a64a8
[Not reviewed]
0 3 0
Brett Smith - 3 years ago 2021-02-17 19:00:06
brettcsmith@brettcsmith.org
cliutil: Add EnumArgument.

This functionality already existed in the code three times, and it's about
to get more important for the ledger report, so now was the time to abstract
it.
3 files changed with 104 insertions and 22 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
...
 
@@ -39,6 +39,8 @@ from typing import (
 
    BinaryIO,
 
    Callable,
 
    Container,
 
    Generic,
 
    Hashable,
 
    IO,
 
    Iterable,
 
    NamedTuple,
...
 
@@ -47,18 +49,75 @@ from typing import (
 
    Sequence,
 
    TextIO,
 
    Type,
 
    TypeVar,
 
    Union,
 
)
 
from .beancount_types import (
 
    MetaKey,
 
)
 

	
 
ET = TypeVar('ET', bound=enum.Enum)
 
OutputFile = Union[int, IO]
 

	
 
CPU_COUNT = len(os.sched_getaffinity(0))
 
STDSTREAM_PATH = Path('-')
 
VERSION = pkg_resources.require(PKGNAME)[0].version
 

	
 
class EnumArgument(Generic[ET]):
 
    """Wrapper class to use an enum as argument values
 

	
 
    Use this class when the user can choose one of some arbitrary enum names
 
    as an argument. It will let user abbreviate and use any case, and will
 
    return the correct value as long as it's unambiguous. Typical usage
 
    looks like::
 

	
 
        enum_arg = EnumArgument(Enum)
 
        arg_parser.add_argument(
 
          '--choice',
 
          type=enum_arg.enum_type,  # or .value_type
 
          help=f"Choices are {enum_arg.choices_str()}",
 
 
        )
 
    """
 
    # I originally wrote this as a mixin class, to eliminate the need for the
 
    # explicit wrapping in the example above. But Python 3.6 doesn't really
 
    # support mixins with Enums; see <https://bugs.python.org/issue29577>.
 
    # This functionality could be moved to a mixin when we drop support for
 
    # Python 3.6.
 

	
 
    def __init__(self, base: Type[ET]) -> None:
 
        self.base = base
 

	
 
    def enum_type(self, arg: str) -> ET:
 
        """Return a single enum whose name matches the user argument"""
 
        regexp = re.compile(re.escape(arg), re.IGNORECASE)
 
        matches = frozenset(choice for choice in self.base if regexp.match(choice.name))
 
        count = len(matches)
 
        if count == 1:
 
            return next(iter(matches))
 
        elif count:
 
            names = ', '.join(repr(choice.name) for choice in matches)
 
            raise ValueError(f"ambiguous argument {arg!r}: matches {names}")
 
        else:
 
            raise ValueError(f"unknown argument {arg!r}")
 

	
 
    def value_type(self, arg: str) -> Any:
 
        return self.enum_type(arg).value
 

	
 
    def choices_str(self, sep: str=', ', fmt: str='{!r}') -> str:
 
        """Return a user-formatted string of enum names"""
 
        sortkey: Callable[[ET], Hashable] = getattr(
 
            self.base, '_choices_sortkey', self._choices_sortkey,
 
        )
 
        return sep.join(
 
            fmt.format(choice.name.lower())
 
            for choice in sorted(self.base, key=sortkey)
 
        )
 

	
 
    def _choices_sortkey(self, choice: ET) -> Hashable:
 
        return choice.name
 

	
 

	
 
class ExceptHook:
 
    def __init__(self, logger: Optional[logging.Logger]=None) -> None:
 
        if logger is None:
...
 
@@ -148,17 +207,8 @@ class LogLevel(enum.IntEnum):
 
    ERR = ERROR
 
    CRIT = CRITICAL
 

	
 
    @classmethod
 
    def from_arg(cls, arg: str) -> int:
 
        try:
 
            return cls[arg.upper()].value
 
        except KeyError:
 
            raise ValueError(f"unknown loglevel {arg!r}") from None
 

	
 
    @classmethod
 
    def choices(cls) -> Iterable[str]:
 
        for level in sorted(cls, key=operator.attrgetter('value')):
 
            yield level.name.lower()
 
    def _choices_sortkey(self) -> Hashable:
 
        return self.value
 

	
 

	
 
class SearchTerm(NamedTuple):
...
 
@@ -250,14 +300,15 @@ Can specify a positive integer or a percentage of CPU cores. Default all cores.
 

	
 
def add_loglevel_argument(parser: argparse.ArgumentParser,
 
                          default: LogLevel=LogLevel.INFO) -> argparse.Action:
 
    arg_enum = EnumArgument(LogLevel)
 
    return parser.add_argument(
 
        '--loglevel',
 
        metavar='LEVEL',
 
        default=default.value,
 
        type=LogLevel.from_arg,
 
        type=arg_enum.value_type,
 
        help="Show logs at this level and above."
 
        f" Specify one of {', '.join(LogLevel.choices())}."
 
        f" Default {default.name.lower()}.",
 
        f" Specify one of {arg_enum.choices_str()}."
 
        f" Default {default.name.lower()!r}.",
 
    )
 

	
 
def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action:
conservancy_beancount/reports/fund.py
Show inline comments
...
 
@@ -306,13 +306,6 @@ class ReportType(enum.Enum):
 
    TXT = TEXT
 
    SPREADSHEET = ODS
 

	
 
    @classmethod
 
    def from_arg(cls, s: str) -> 'ReportType':
 
        try:
 
            return cls[s.upper()]
 
        except KeyError:
 
            raise ValueError(f"no report type matches {s!r}") from None
 

	
 

	
 
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 
    parser = argparse.ArgumentParser(prog=PROGNAME)
...
 
@@ -337,7 +330,7 @@ The default is a year after the start date.
 
    parser.add_argument(
 
        '--report-type', '-t',
 
        metavar='TYPE',
 
        type=ReportType.from_arg,
 
        type=cliutil.EnumArgument(ReportType).enum_type,
 
        help="""Type of report to generate. `text` gives a plain two-column text
 
report listing accounts and balances over the period, and is the default when
 
you search for a specific project/fund. `ods` produces a higher-level
tests/test_cliutil.py
Show inline comments
...
 
@@ -7,6 +7,7 @@
 

	
 
import argparse
 
import datetime
 
import enum
 
import errno
 
import io
 
import inspect
...
 
@@ -27,6 +28,12 @@ from conservancy_beancount import cliutil
 
FILE_NAMES = ['-foobar', '-foo.bin']
 
STREAM_PATHS = [None, Path('-')]
 

	
 
class ArgChoices(enum.Enum):
 
    AA = 'aa'
 
    AB = 'ab'
 
    BB = 'bb'
 

	
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
...
 
@@ -45,6 +52,10 @@ class MockTraceback:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def arg_enum():
 
    return cliutil.EnumArgument(ArgChoices)
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
...
 
@@ -239,3 +250,30 @@ def test_diff_year(date, diff, expected):
 
])
 
def test_can_run(cmd, expected):
 
    assert cliutil.can_run(cmd) == expected
 

	
 
@pytest.mark.parametrize('choice', ArgChoices)
 
def test_enum_arg_enum_type(arg_enum, choice):
 
    assert arg_enum.enum_type(choice.name) is choice
 
    assert arg_enum.enum_type(choice.value) is choice
 

	
 
@pytest.mark.parametrize('arg', 'az\0')
 
def test_enum_arg_no_enum_match(arg_enum, arg):
 
    with pytest.raises(ValueError):
 
        arg_enum.enum_type(arg)
 

	
 
@pytest.mark.parametrize('choice', ArgChoices)
 
def test_enum_arg_value_type(arg_enum, choice):
 
    assert arg_enum.value_type(choice.name) == choice.value
 
    assert arg_enum.value_type(choice.value) == choice.value
 

	
 
@pytest.mark.parametrize('arg', 'az\0')
 
def test_enum_arg_no_value_match(arg_enum, arg):
 
    with pytest.raises(ValueError):
 
        arg_enum.value_type(arg)
 

	
 
def test_enum_arg_choices_str_defaults(arg_enum):
 
    assert arg_enum.choices_str() == ', '.join(repr(c.value) for c in ArgChoices)
 

	
 
def test_enum_arg_choices_str_args(arg_enum):
 
    sep = '/'
 
    assert arg_enum.choices_str(sep, '{}') == sep.join(c.value for c in ArgChoices)
0 comments (0 inline, 0 general)