Files @ 50376b2b78f0
Branch filter:

Location: NPO-Accounting/import2ledger/import2ledger/__main__.py

Brett Smith
README: Fix typo in Benevity Ledger entry key.
import collections
import contextlib
import decimal
import logging
import pathlib
import sys

from . import config, errors, hooks, importers

# Shared logger for this program; main() attaches a stderr handler to it.
logger = logging.getLogger(name='import2ledger')

class FileImporter:
    """Run every matching importer over input files and pass entries to hooks.

    Importers come from ``importers.load_all()``.  Hooks are built by calling
    each class from ``hooks.load_all()`` with the configuration.
    """

    def __init__(self, config, stdout):
        self.config = config
        self.importers = list(importers.load_all())
        self.hooks = [hook(config) for hook in hooks.load_all()]
        self.stdout = stdout

    def import_file(self, in_file, in_path=None):
        """Import every entry from the open, seekable file ``in_file``.

        ``in_path`` defaults to ``pathlib.Path(in_file.name)``.  The file is
        rewound before each importer's ``can_import`` check.  Every importer
        that accepts the file is run over it, and the file is rewound again
        before each run.

        Each entry is layered (via ChainMap) over ``source_*`` path variables
        and the ``importer_class`` / ``importer_module`` of the importer that
        produced it.  The entry is then passed through the hooks in order:

        - a hook returning ``None`` leaves the entry unchanged;
        - a hook returning ``False`` stops processing that entry;
        - any other return value replaces the entry for the remaining hooks.

        Raises errors.UserInputFileError if no importer accepts the file.
        """
        if in_path is None:
            in_path = pathlib.Path(in_file.name)
        # Named so it does not shadow the ``importers`` module.
        matching_importers = []
        for importer in self.importers:
            in_file.seek(0)
            if importer.can_import(in_file):
                matching_importers.append(importer)
        if not matching_importers:
            raise errors.UserInputFileError("no importers available", in_file.name)
        source_vars = {
            'source_abspath': in_path.absolute().as_posix(),
            'source_absdir': in_path.absolute().parent.as_posix(),
            'source_dir': in_path.parent.as_posix(),
            'source_name': in_path.name,
            'source_path': in_path.as_posix(),
            'source_stem': in_path.stem,
        }
        for importer in matching_importers:
            # Build a fresh mapping for each importer rather than mutating
            # source_vars in place.  Every ChainMap handed to the hooks would
            # otherwise share one dict, and entries from an earlier importer
            # would end up reporting a later importer's class and module.
            importer_vars = dict(
                source_vars,
                importer_class=importer.__name__,
                importer_module=importer.__module__,
            )
            in_file.seek(0)
            for entry_data in importer(in_file):
                entry_data = collections.ChainMap(entry_data, importer_vars)
                for hook in self.hooks:
                    hook_retval = hook.run(entry_data)
                    if hook_retval is False:
                        break
                    if hook_retval is not None:
                        entry_data = hook_retval

    def import_path(self, in_path):
        """Open ``in_path`` and import it with import_file().

        ``None`` (reported as '<stdin>') is rejected because only seekable
        files are supported.  The file is opened in text mode with
        ``errors='replace'``, so undecodable bytes do not raise.

        Raises errors.UserInputFileError for ``None`` or a non-seekable file.
        OSError from opening propagates.
        """
        if in_path is None:
            raise errors.UserInputFileError("only seekable files are supported", '<stdin>')
        with in_path.open(errors='replace') as in_file:
            if not in_file.seekable():
                raise errors.UserInputFileError("only seekable files are supported", in_path)
            return self.import_file(in_file, in_path)

    def import_paths(self, path_seq):
        """Import each path in ``path_seq``, yielding ``(path, outcome)`` pairs.

        ``outcome`` is either the return value of import_path(), which is
        None on success, or the OSError / errors.UserInputError it raised.
        One failing path does not stop the remaining ones.
        """
        for in_path in path_seq:
            try:
                retval = self.import_path(in_path)
            except (OSError, errors.UserInputError) as error:
                yield in_path, error
            else:
                yield in_path, retval


def setup_logger(logger, main_config, stream):
    """Attach a handler to ``logger`` that writes to ``stream``.

    Records are formatted as ``name: LEVEL: message``.  ``main_config`` is
    accepted but not used here.
    """
    stream_handler = logging.StreamHandler(stream)
    stream_handler.setFormatter(
        logging.Formatter('%(name)s: %(levelname)s: %(message)s'))
    logger.addHandler(stream_handler)

def decimal_context(base=decimal.BasicContext):
    """Return a copy of ``base`` tuned for strict decimal arithmetic.

    Rounding is set to ROUND_HALF_EVEN.  Inexact and Rounded results pass
    silently; every other decimal signal, including FloatOperation, raises.
    Precision and exponent limits are inherited from ``base`` (BasicContext
    has 9 digits of precision).  ``base`` itself is left unmodified.
    """
    raising_signals = (
        decimal.Clamped,
        decimal.DivisionByZero,
        decimal.FloatOperation,
        decimal.InvalidOperation,
        decimal.Overflow,
        decimal.Subnormal,
        decimal.Underflow,
    )
    quiet_signals = (decimal.Inexact, decimal.Rounded)
    # The traps setter expects a dict naming every signal, so list all nine.
    trap_map = dict.fromkeys(raising_signals, True)
    trap_map.update(dict.fromkeys(quiet_signals, False))
    ctx = base.copy()
    ctx.rounding = decimal.ROUND_HALF_EVEN
    ctx.traps = trap_map
    return ctx

def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
    """Import every input path named in the configuration.

    ``arglist``, ``stdout`` and ``stderr`` are passed to
    ``config.Configuration``.  ``arglist`` is presumably the command-line
    argument list, with None meaning the process arguments — confirm in
    config.  Log output goes to ``stderr``.  All importing runs under
    decimal_context().

    Returns a process exit status:

    - 0 when every input imported;
    - 3 when building the configuration raised errors.UserInputError;
    - otherwise 10 plus the number of failed inputs, capped at 99.
    """
    try:
        my_config = config.Configuration(arglist, stdout, stderr)
    except errors.UserInputError as error:
        # The constructor raised, so my_config was never bound.  Calling
        # my_config.error() here would raise UnboundLocalError and hide the
        # real problem.  Report it directly, in the same
        # "name: LEVEL: message" shape that setup_logger() uses.
        print(
            "{}: ERROR: {}: {!r}".format(logger.name, error.strerror, error.user_input),
            file=stderr,
        )
        return 3
    setup_logger(logger, my_config, stderr)
    failures = 0
    with decimal.localcontext(decimal_context()):
        importer = FileImporter(my_config, stdout)
        for input_path, error in importer.import_paths(my_config.args.input_paths):
            if error is None:
                logger.info("%s: imported", input_path)
            else:
                logger.warning("%s: failed to import: %s", input_path or error.path, error)
                failures += 1
    if failures == 0:
        return 0
    return min(10 + failures, 99)

if __name__ == '__main__':
    # Use sys.exit rather than the bare exit() builtin.  exit() is installed
    # by the site module, is intended for the interactive shell, and is
    # missing when Python runs with -S.
    sys.exit(main())