Changeset - 2b5cb0eca6e0
[Not reviewed]
0 3 0
Brett Smith - 4 years ago 2020-06-09 13:04:27
brettcsmith@brettcsmith.org
cliutil: Add bytes_output() and text_output() functions.
3 files changed with 106 insertions and 32 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/cliutil.py
Show inline comments
...
 
@@ -10,56 +10,63 @@ the Free Software Foundation, either version 3 of the License, or
 

	
 
This program is distributed in the hope that it will be useful,
 
but WITHOUT ANY WARRANTY; without even the implied warranty of
 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
GNU Affero General Public License for more details.
 

	
 
You should have received a copy of the GNU Affero General Public License
 
along with this program.  If not, see <https://www.gnu.org/licenses/>."""
 

	
 
import argparse
 
import enum
 
import inspect
 
import io
 
import logging
 
import operator
 
import os
 
import pkg_resources
 
import re
 
import signal
 
import sys
 
import traceback
 
import types
 

	
 
from pathlib import Path
 

	
 
from . import data
 
from . import filters
 
from . import rtutil
 

	
 
from typing import (
 
    cast,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    IO,
 
    Iterable,
 
    NamedTuple,
 
    NoReturn,
 
    Optional,
 
    Sequence,
 
    TextIO,
 
    Type,
 
    Union,
 
)
 
from .beancount_types import (
 
    MetaKey,
 
)
 

	
 
OutputFile = Union[int, IO]
 

	
 
STDSTREAM_PATH = Path('-')
 
VERSION = pkg_resources.require(PKGNAME)[0].version
 

	
 
class ExceptHook:
 
    def __init__(self,
 
                 logger: Optional[logging.Logger]=None,
 
                 default_exitcode: int=3,
 
    ) -> None:
 
        if logger is None:
 
            logger = logging.getLogger()
 
        self.logger = logger
 
        self.default_exitcode = default_exitcode
 

	
...
 
@@ -238,12 +245,60 @@ def setup_logger(logger: Union[str, logging.Logger]='',
 
                 loglevel: int=logging.INFO,
 
                 stream: TextIO=sys.stderr,
 
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 
) -> logging.Logger:
 
    if isinstance(logger, str):
 
        logger = logging.getLogger(logger)
 
    formatter = logging.Formatter(fmt)
 
    handler = logging.StreamHandler(stream)
 
    handler.setFormatter(formatter)
 
    logger.addHandler(handler)
 
    logger.setLevel(loglevel)
 
    return logger
 

	
 
def bytes_output(path: Optional[Path]=None,
 
                 default: OutputFile=sys.stdout,
 
                 mode: str='w',
 
) -> BinaryIO:
 
    """Get a file-like object suitable for binary output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or text IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    mode = f'{mode}b'
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode)
 
        elif isinstance(default, TextIO):
 
            retval = default.buffer
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode)
 
    return cast(BinaryIO, retval)
 

	
 
def text_output(path: Optional[Path]=None,
 
                default: OutputFile=sys.stdout,
 
                mode: str='w',
 
                encoding: Optional[str]=None,
 
) -> TextIO:
 
    """Get a file-like object suitable for text output
 

	
 
    If ``path`` is ``None`` or ``-``, returns a file-like object backed by
 
    ``default``. If ``default`` is a file descriptor or binary IO object, this
 
    method returns a file-like object that writes to the same place.
 

	
 
    Otherwise, returns ``path.open(mode)``.
 
    """
 
    if path is None or path == STDSTREAM_PATH:
 
        if isinstance(default, int):
 
            retval = open(default, mode, encoding=encoding)
 
        elif isinstance(default, BinaryIO):
 
            retval = io.TextIOWrapper(default, encoding=encoding)
 
        else:
 
            retval = default
 
    else:
 
        retval = path.open(mode, encoding=encoding)
 
    return cast(TextIO, retval)
conservancy_beancount/reports/accrual.py
Show inline comments
...
 
@@ -111,25 +111,24 @@ import odf.table  # type:ignore[import]
 
import rt
 

	
 
from beancount.parser import printer as bc_printer
 

	
 
from . import core
 
from .. import cliutil
 
from .. import config as configmod
 
from .. import data
 
from .. import filters
 
from .. import rtutil
 

	
 
PROGNAME = 'accrual-report'
 
STANDARD_PATH = Path('-')
 

	
 
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
 
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
 
RTObject = Mapping[str, str]
 
T = TypeVar('T')
 

	
 
logger = logging.getLogger('conservancy_beancount.reports.accrual')
 

	
 
class Sentinel:
 
    pass
 

	
 

	
...
 
@@ -668,46 +667,24 @@ filename for other reports.
 
        nargs=argparse.ZERO_OR_MORE,
 
        help="""Report on accruals that match this criteria. The format is
 
NAME=TERM. TERM is a link or word that must exist in a posting's NAME
 
metadata to match. A single ticket number is a shortcut for
 
`rt-id=rt:NUMBER`. Any other link, including an RT attachment link in
 
`TIK/ATT` format, is a shortcut for `invoice=LINK`.
 
""")
 
    args = parser.parse_args(arglist)
 
    if args.report_type is None and not args.search_terms:
 
        args.report_type = ReportType.AGING
 
    return args
 

	
 
def get_output_path(output_path: Optional[Path],
 
                    default_path: Path=STANDARD_PATH,
 
) -> Optional[Path]:
 
    if output_path is None:
 
        output_path = default_path
 
    if output_path == STANDARD_PATH:
 
        return None
 
    else:
 
        return output_path
 

	
 
def get_output_bin(path: Optional[Path], stdout: TextIO) -> BinaryIO:
 
    if path is None:
 
        return open(stdout.fileno(), 'wb')
 
    else:
 
        return path.open('wb')
 

	
 
def get_output_text(path: Optional[Path], stdout: TextIO) -> TextIO:
 
    if path is None:
 
        return stdout
 
    else:
 
        return path.open('w')
 

	
 
def main(arglist: Optional[Sequence[str]]=None,
 
         stdout: TextIO=sys.stdout,
 
         stderr: TextIO=sys.stderr,
 
         config: Optional[configmod.Config]=None,
 
) -> int:
 
    if cliutil.is_main_script(PROGNAME):
 
        global logger
 
        logger = logging.getLogger(PROGNAME)
 
        sys.excepthook = cliutil.ExceptHook(logger)
 
    args = parse_arguments(arglist)
 
    cliutil.setup_logger(logger, args.loglevel, stderr)
 
    if config is None:
...
 
@@ -753,39 +730,36 @@ def main(arglist: Optional[Sequence[str]]=None,
 
        } or groups
 

	
 
    if args.report_type is None:
 
        args.report_type = ReportType.default_for(groups)
 
    report: Optional[BaseReport] = None
 
    output_path: Optional[Path] = None
 
    if args.report_type is ReportType.AGING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate aging report: RT client is required")
 
        else:
 
            now = datetime.datetime.now()
 
            default_path = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
 
            output_path = get_output_path(args.output_file, default_path)
 
            out_bin = get_output_bin(output_path, stdout)
 
            if args.output_file is None:
 
                args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
 
                logger.info("Writing report to %s", args.output_file)
 
            out_bin = cliutil.bytes_output(args.output_file, stdout)
 
            report = AgingReport(rt_client, out_bin)
 
    elif args.report_type is ReportType.OUTGOING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate outgoing report: RT client is required")
 
        else:
 
            output_path = get_output_path(args.output_file)
 
            out_file = get_output_text(output_path, stdout)
 
            out_file = cliutil.text_output(args.output_file, stdout)
 
            report = OutgoingReport(rt_client, out_file)
 
    else:
 
        output_path = get_output_path(args.output_file)
 
        out_file = get_output_text(output_path, stdout)
 
        out_file = cliutil.text_output(args.output_file, stdout)
 
        report = args.report_type.value(out_file)
 

	
 
    if report is None:
 
        returncode |= ReturnFlag.REPORT_ERRORS
 
    else:
 
        report.run(groups)
 
        if args.output_file != output_path:
 
            logger.info("Report saved to %s", output_path)
 
    return 0 if returncode == 0 else 16 + returncode
 

	
 
if __name__ == '__main__':
 
    exit(main())
tests/test_cliutil.py
Show inline comments
...
 
@@ -12,30 +12,36 @@
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import argparse
 
import errno
 
import io
 
import inspect
 
import logging
 
import os
 
import re
 
import sys
 
import traceback
 

	
 
import pytest
 

	
 
from pathlib import Path
 

	
 
from conservancy_beancount import cliutil
 

	
 
FILE_NAMES = ['-foobar', '-foo.bin']
 
STREAM_PATHS = [None, Path('-')]
 

	
 
class AlwaysEqual:
 
    def __eq__(self, other):
 
        return True
 

	
 

	
 
class MockTraceback:
 
    def __init__(self, stack=None, index=0):
 
        if stack is None:
 
            stack = inspect.stack(context=False)
 
        self._stack = stack
 
        self._index = index
 
        frame_record = self._stack[self._index]
...
 
@@ -48,24 +54,63 @@ class MockTraceback:
 
            return type(self)(self._stack, self._index + 1)
 
        except IndexError:
 
            return None
 

	
 

	
 
@pytest.fixture(scope='module')
 
def argparser():
 
    parser = argparse.ArgumentParser(prog='test_cliutil')
 
    cliutil.add_loglevel_argument(parser)
 
    cliutil.add_version_argument(parser)
 
    return parser
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_bytes_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 
    assert 'b' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_bytes_output_stream(path):
 
    stream = io.BytesIO()
 
    actual = cliutil.bytes_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('func_name', [
 
    'bytes_output',
 
    'text_output',
 
])
 
def test_default_output(func_name):
 
    actual = getattr(cliutil, func_name)()
 
    assert actual.fileno() == sys.stdout.fileno()
 

	
 
@pytest.mark.parametrize('path_name', FILE_NAMES)
 
def test_text_output_path(path_name, tmp_path):
 
    path = tmp_path / path_name
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is not stream
 
    assert str(actual.name) == str(path)
 
    assert 'w' in actual.mode
 

	
 
@pytest.mark.parametrize('path', STREAM_PATHS)
 
def test_text_output_stream(path):
 
    stream = io.StringIO()
 
    actual = cliutil.text_output(path, stream)
 
    assert actual is stream
 

	
 
@pytest.mark.parametrize('errnum', [
 
    errno.EACCES,
 
    errno.EPERM,
 
    errno.ENOENT,
 
])
 
def test_excepthook_oserror(errnum, caplog):
 
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
 
    with pytest.raises(SystemExit) as exc_check:
 
        cliutil.ExceptHook()(type(error), error, None)
 
    assert exc_check.value.args[0] == 4
 
    assert caplog.records
 
    for log in caplog.records:
0 comments (0 inline, 0 general)