diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py new file mode 100644 index 0000000000000000000000000000000000000000..58340e72de2ebc23f38cf892ae5b085f543af305 --- /dev/null +++ b/conservancy_beancount/cliutil.py @@ -0,0 +1,148 @@ +"""cliutil - Utilities for CLI tools""" +PKGNAME = 'conservancy_beancount' +LICENSE = """ +Copyright © 2020 Brett Smith + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see .""" + +import argparse +import enum +import logging +import operator +import os +import pkg_resources +import signal +import sys +import traceback +import types + +from typing import ( + Any, + Iterable, + NoReturn, + Optional, + Sequence, + TextIO, + Type, + Union, +) + +VERSION = pkg_resources.require(PKGNAME)[0].version + +class ExceptHook: + def __init__(self, + logger: Optional[logging.Logger]=None, + default_exitcode: int=3, + ) -> None: + if logger is None: + logger = logging.getLogger() + self.logger = logger + self.default_exitcode = default_exitcode + + def __call__(self, + exc_type: Type[BaseException], + exc_value: BaseException, + exc_tb: types.TracebackType, + ) -> NoReturn: + exitcode = self.default_exitcode + if isinstance(exc_value, KeyboardInterrupt): + signal.signal(signal.SIGINT, signal.SIG_DFL) + os.kill(0, signal.SIGINT) + signal.pause() + elif isinstance(exc_value, OSError): + exitcode += 1 + msg = "I/O error: {e.filename}: {e.strerror}".format(e=exc_value) + else: + parts = [type(exc_value).__name__, *exc_value.args] + msg = "internal " + ": ".join(parts) + self.logger.critical(msg) + self.logger.debug( + ''.join(traceback.format_exception(exc_type, exc_value, exc_tb)), + ) + raise SystemExit(exitcode) + + +class InfoAction(argparse.Action): + def __call__(self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Union[Sequence[Any], str, None]=None, + option_string: Optional[str]=None, + ) -> NoReturn: + if isinstance(self.const, str): + info = self.const + exitcode = 0 + else: + info, exitcode = self.const + print(info) + raise SystemExit(exitcode) + + +class LogLevel(enum.IntEnum): + DEBUG = logging.DEBUG + INFO = logging.INFO + WARNING = logging.WARNING + ERROR = logging.ERROR + CRITICAL = logging.CRITICAL + WARN = WARNING + ERR = ERROR + CRIT = CRITICAL + + @classmethod + def from_arg(cls, arg: str) -> int: + try: + return cls[arg.upper()].value + except KeyError: + raise ValueError(f"unknown loglevel {arg!r}") from None + + @classmethod + def choices(cls) -> Iterable[str]: + for level in sorted(cls, key=operator.attrgetter('value')): + yield level.name.lower() + +def add_loglevel_argument(parser: argparse.ArgumentParser, + default: LogLevel=LogLevel.INFO) -> argparse.Action: + return parser.add_argument( + '--loglevel', + metavar='LEVEL', + default=default.value, + type=LogLevel.from_arg, + help="Show logs at this level and above." + f" Specify one of {', '.join(LogLevel.choices())}." + f" Default {default.name.lower()}.", + ) + +def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action: + progname = parser.prog or sys.argv[0] + return parser.add_argument( + '--version', '--copyright', '--license', + action=InfoAction, + nargs=0, + const=f"{progname} version {VERSION}\n{LICENSE}", + help="Show program version and license information", + ) + +def setup_logger(logger: Union[str, logging.Logger]='', + loglevel: int=logging.INFO, + stream: TextIO=sys.stderr, + fmt: str='%(name)s: %(levelname)s: %(message)s', +) -> logging.Logger: + if isinstance(logger, str): + logger = logging.getLogger(logger) + formatter = logging.Formatter(fmt) + handler = logging.StreamHandler(stream) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(loglevel) + return logger