Files @ 18eebbc0ed28
Branch filter:

Location: NPO-Accounting/import2ledger/import2ledger/__main__.py

Brett Smith
hooks: run() return value controls processing of entry data.

This replaces the previous in-band signaling through the entry_data dict.
I don't know why I didn't think of this in the first place.
import collections
import contextlib
import logging
import pathlib
import sys

from . import config, errors, hooks, importers

logger = logging.getLogger('import2ledger')

class FileImporter:
    """Run the loaded importers and hooks over input files and render entries.

    Attributes:
        config: the configuration object passed to the constructor.
        importers: importer classes returned by ``importers.load_all()``.
        hooks: hook instances, each constructed with ``config``.
        stdout: stream written to when the configuration has no output path.
    """

    def __init__(self, config, stdout):
        self.config = config
        self.importers = list(importers.load_all())
        self.hooks = [hook(config) for hook in hooks.load_all()]
        self.stdout = stdout

    def _usable_importers(self, in_file):
        """Return (importer, template) pairs that accept in_file and have a non-empty template."""
        usable = []
        for importer in self.importers:
            # Each importer probes the file from the beginning.
            in_file.seek(0)
            if not importer.can_import(in_file):
                continue
            try:
                template = self.config.get_template(importer.TEMPLATE_KEY)
            except errors.UserInputConfigurationError as error:
                # An undefined template means "skip this importer";
                # any other configuration error is propagated.
                if error.strerror.startswith('template not defined '):
                    continue
                raise
            if not template.is_empty():
                usable.append((importer, template))
        return usable

    def import_file(self, in_file, in_path=None):
        """Render every entry of in_file through each applicable importer's template.

        Arguments:
            in_file: a seekable text stream; it is rewound before each importer
              probes and reads it.
            in_path: pathlib.Path for in_file.  Defaults to Path(in_file.name).
              Its absolute path, final component, and POSIX form are exposed to
              templates as source_abspath, source_name, and source_path.

        Each entry is passed through self.hooks in order.  A hook returning
        None leaves the entry unchanged, False drops it (later hooks are not
        run), and any other value replaces the entry data.  Rendered output is
        appended to the configured output path, or written to self.stdout
        when no output path is configured.

        Raises:
            errors.UserInputFileError: no importer accepts in_file with a
              non-empty template.
        """
        if in_path is None:
            in_path = pathlib.Path(in_file.name)
        usable = self._usable_importers(in_file)
        if not usable:
            # Report in_path rather than in_file.name: callers may supply an
            # explicit in_path for a stream that has no name attribute.
            raise errors.UserInputFileError("no importers available", str(in_path))
        source_vars = {
            'source_abspath': in_path.absolute().as_posix(),
            'source_name': in_path.name,
            'source_path': in_path.as_posix(),
        }
        with contextlib.ExitStack() as exit_stack:
            output_path = self.config.get_output_path()
            if output_path is None:
                out_file = self.stdout
            else:
                out_file = exit_stack.enter_context(output_path.open('a'))
            for importer, template in usable:
                # NOTE(review): default_date is not used below; the call is kept
                # so behavior (including any error it raises) is unchanged.
                default_date = self.config.get_default_date()
                in_file.seek(0)
                for entry_data in importer(in_file):
                    for hook in self.hooks:
                        hook_retval = hook.run(entry_data)
                        if hook_retval is False:
                            # A hook rejected this entry: skip later hooks and rendering.
                            break
                        elif hook_retval is not None:
                            entry_data = hook_retval
                    else:
                        # Entry values take precedence over the source_* variables.
                        render_vars = collections.ChainMap(entry_data, source_vars)
                        print(template.render(render_vars), file=out_file, end='')

    def import_path(self, in_path):
        """Open in_path (decoding errors replaced) and import it with import_file().

        Raises:
            errors.UserInputFileError: in_path is None (reported as '<stdin>'),
              or the opened file is not seekable.
            OSError: opening in_path failed.
        """
        if in_path is None:
            raise errors.UserInputFileError("only seekable files are supported", '<stdin>')
        with in_path.open(errors='replace') as in_file:
            if not in_file.seekable():
                raise errors.UserInputFileError("only seekable files are supported", in_path)
            return self.import_file(in_file, in_path)

    def import_paths(self, path_seq):
        """Import each path in path_seq, yielding (path, outcome) pairs.

        outcome is import_path()'s return value (None on success) or the
        OSError / errors.UserInputError it raised; other exceptions propagate.
        """
        for in_path in path_seq:
            try:
                retval = self.import_path(in_path)
            except (OSError, errors.UserInputError) as error:
                yield in_path, error
            else:
                yield in_path, retval


def setup_logger(logger, main_config, stream):
    """Attach a handler to logger that writes records to stream.

    Records are formatted as "<logger name>: <LEVELNAME>: <message>".
    main_config is accepted but not consulted, and the logger's level is
    left unchanged.
    """
    stream_handler = logging.StreamHandler(stream)
    stream_handler.setFormatter(
        logging.Formatter('%(name)s: %(levelname)s: %(message)s')
    )
    logger.addHandler(stream_handler)

def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
    """Import every configured input path and return a process exit status.

    Arguments:
        arglist: argument list passed to config.Configuration.
        stdout: stream for rendered entries when no output path is configured.
        stderr: stream for log messages and configuration errors.

    Returns 0 when every input imports successfully, 3 when the
    configuration cannot be built, and otherwise 10 plus the number of
    failed inputs, capped at 99.
    """
    try:
        my_config = config.Configuration(arglist)
    except errors.UserInputError as error:
        # The constructor raised, so my_config is unbound here; report the
        # problem directly instead of through my_config.error().
        print("import2ledger: error: {}: {!r}".format(error.strerror, error.user_input),
              file=stderr)
        return 3
    setup_logger(logger, my_config, stderr)
    importer = FileImporter(my_config, stdout)
    failures = 0
    for input_path, error in importer.import_paths(my_config.args.input_paths):
        if error is None:
            logger.info("%s: imported", input_path)
        else:
            # input_path is None for stdin; fall back to the path the error carries.
            # NOTE(review): assumes every yielded error exposes .path — confirm for
            # OSError, which only guarantees .filename.
            logger.warning("%s: failed to import: %s", input_path or error.path, error)
            failures += 1
    if failures == 0:
        return 0
    else:
        return min(10 + failures, 99)

if __name__ == '__main__':
    # Use sys.exit rather than the site-module exit(): exit() is meant for the
    # interactive shell and is not defined when Python runs with -S.
    sys.exit(main())