diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py index 40d8ab9c8a8cd76b3c16e16d3f96892819dc3680..358c56876bc16a121012ed7365d8d9720c601ec6 100644 --- a/conservancy_beancount/cliutil.py +++ b/conservancy_beancount/cliutil.py @@ -39,6 +39,8 @@ from typing import ( BinaryIO, Callable, Container, + Generic, + Hashable, IO, Iterable, NamedTuple, @@ -47,18 +49,75 @@ from typing import ( Sequence, TextIO, Type, + TypeVar, Union, ) from .beancount_types import ( MetaKey, ) +ET = TypeVar('ET', bound=enum.Enum) OutputFile = Union[int, IO] CPU_COUNT = len(os.sched_getaffinity(0)) STDSTREAM_PATH = Path('-') VERSION = pkg_resources.require(PKGNAME)[0].version +class EnumArgument(Generic[ET]): + """Wrapper class to use an enum as argument values + + Use this class when the user can choose one of some arbitrary enum names + as an argument. It will let user abbreviate and use any case, and will + return the correct value as long as it's unambiguous. Typical usage + looks like:: + + enum_arg = EnumArgument(Enum) + arg_parser.add_argument( + '--choice', + type=enum_arg.enum_type, # or .value_type + help=f"Choices are {enum_arg.choices_str()}", + … + ) + """ + # I originally wrote this as a mixin class, to eliminate the need for the + # explicit wrapping in the example above. But Python 3.6 doesn't really + # support mixins with Enums; see . + # This functionality could be moved to a mixin when we drop support for + # Python 3.6. + + def __init__(self, base: Type[ET]) -> None: + self.base = base + + def enum_type(self, arg: str) -> ET: + """Return a single enum whose name matches the user argument""" + regexp = re.compile(re.escape(arg), re.IGNORECASE) + matches = frozenset(choice for choice in self.base if regexp.match(choice.name)) + count = len(matches) + if count == 1: + return next(iter(matches)) + elif count: + names = ', '.join(repr(choice.name) for choice in matches) + raise ValueError(f"ambiguous argument {arg!r}: matches {names}") + else: + raise ValueError(f"unknown argument {arg!r}") + + def value_type(self, arg: str) -> Any: + return self.enum_type(arg).value + + def choices_str(self, sep: str=', ', fmt: str='{!r}') -> str: + """Return a user-formatted string of enum names""" + sortkey: Callable[[ET], Hashable] = getattr( + self.base, '_choices_sortkey', self._choices_sortkey, + ) + return sep.join( + fmt.format(choice.name.lower()) + for choice in sorted(self.base, key=sortkey) + ) + + def _choices_sortkey(self, choice: ET) -> Hashable: + return choice.name + + class ExceptHook: def __init__(self, logger: Optional[logging.Logger]=None) -> None: if logger is None: @@ -148,17 +207,8 @@ class LogLevel(enum.IntEnum): ERR = ERROR CRIT = CRITICAL - @classmethod - def from_arg(cls, arg: str) -> int: - try: - return cls[arg.upper()].value - except KeyError: - raise ValueError(f"unknown loglevel {arg!r}") from None - - @classmethod - def choices(cls) -> Iterable[str]: - for level in sorted(cls, key=operator.attrgetter('value')): - yield level.name.lower() + def _choices_sortkey(self) -> Hashable: + return self.value class SearchTerm(NamedTuple): @@ -250,14 +300,15 @@ Can specify a positive integer or a percentage of CPU cores. Default all cores. def add_loglevel_argument(parser: argparse.ArgumentParser, default: LogLevel=LogLevel.INFO) -> argparse.Action: + arg_enum = EnumArgument(LogLevel) return parser.add_argument( '--loglevel', metavar='LEVEL', default=default.value, - type=LogLevel.from_arg, + type=arg_enum.value_type, help="Show logs at this level and above." - f" Specify one of {', '.join(LogLevel.choices())}." - f" Default {default.name.lower()}.", + f" Specify one of {arg_enum.choices_str()}." + f" Default {default.name.lower()!r}.", ) def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action: