Files @ cc2b3978011a
Branch filter:

Location: NPO-Accounting/import2ledger/import2ledger/hooks/ledger_entry.py

Brett Smith
tests: Pass amount to Template.render as string.

render is required to convert the string to Decimal for historical design
reasons. Passing the amount as a string verifies this behavior.
import collections
import datetime
import decimal
import functools
import io
import logging
import operator
import re
import tokenize

import babel.numbers

from . import HOOK_KINDS
from .. import errors, strparse

logger = logging.getLogger('import2ledger.hooks.ledger_entry')

class TokenTransformer:
    """Rewrite a Python token stream and convert it back to source text.

    Subclasses add ``transform_<TOKNAME>`` methods, where TOKNAME is a name
    from ``tokenize.tok_name``.  Each method takes ``(ttype, tvalue)`` and
    returns an iterable of ``(ttype, tvalue)`` pairs to emit in place of that
    token.  A token whose type has no matching method is rejected with
    ValueError.
    """

    def __init__(self, source):
        """Begin tokenizing *source*.

        *source* may be an object with a ``readline`` method or a
        readline-style callable.  Either way it is handed to
        ``tokenize.tokenize``, which expects lines of bytes.
        """
        try:
            source = source.readline
        except AttributeError:
            pass
        self.in_tokens = tokenize.tokenize(source)

    @classmethod
    def from_bytes(cls, b):
        """Return a transformer that reads source code from the bytes *b*."""
        return cls(io.BytesIO(b).readline)

    @classmethod
    def from_str(cls, s, encoding='utf-8'):
        """Return a transformer for the str *s*, encoded with *encoding*."""
        return cls.from_bytes(s.encode(encoding))

    def __iter__(self):
        """Yield the rewritten ``(ttype, tvalue)`` pairs for every input token.

        Raises:
            ValueError: a token's type has no ``transform_*`` method.  The
                message names the token type, e.g. "NAME token 'x' not
                supported".
        """
        for ttype, tvalue, _, _, _ in self.in_tokens:
            type_name = tokenize.tok_name[ttype]
            try:
                transformer = getattr(self, 'transform_' + type_name)
            except AttributeError:
                # Report the symbolic token type ("NAME", "COMMENT", ...)
                # rather than its numeric code.  Also keep the internal
                # AttributeError out of the user-facing traceback.
                raise ValueError("{} token {!r} not supported".format(
                    type_name, tvalue)) from None
            yield from transformer(ttype, tvalue)

    def _noop_transformer(self, ttype, tvalue):
        """Emit the token unchanged."""
        yield (ttype, tvalue)

    # Tokens that end every tokenized expression; pass them through.
    transform_ENDMARKER = _noop_transformer
    transform_NEWLINE = _noop_transformer

    def transform_ENCODING(self, ttype, tvalue):
        """Remember the source encoding so transform() can decode the output."""
        self.in_encoding = tvalue
        return self._noop_transformer(ttype, tvalue)

    def transform(self):
        """Return the rewritten source code as a str.

        The ENCODING token is passed through to ``tokenize.untokenize``, so
        its output is bytes.  It is decoded with that same encoding.
        """
        out_bytes = tokenize.untokenize(self)
        return out_bytes.decode(self.in_encoding)


class AmountTokenTransformer(TokenTransformer):
    """Sanitize an amount expression taken from a Ledger template.

    The rewritten expression may only contain:

    * the keywords in SUPPORTED_NAMES;
    * the operators in SUPPORTED_OPS;
    * string literals, which pass through unchanged;
    * numeric literals, rewritten as ``Decimal('<literal>')`` calls;
    * ``{name}`` references to all-lowercase names, rewritten as the bare
      name so it is looked up when the compiled expression is evaluated.

    Unsupported tokens raise ValueError.
    """
    SUPPORTED_NAMES = frozenset(('if', 'else', 'and', 'or', 'not', 'in'))
    SUPPORTED_OPS = frozenset((
        '(', ')',
        '+', '-', '*', '/',
        '==', '!=', '<', '<=', '>', '>=',
    ))

    def __iter__(self):
        """Yield the sanitized tokens, requiring at least one NAME token.

        Raises:
            ValueError: the rewritten stream contains no NAME token at all
                (for example, an empty expression).
        """
        pending = super().__iter__()
        saw_name = False
        for token in pending:
            yield token
            if token[0] == tokenize.NAME:
                saw_name = True
                break
        if not saw_name:
            raise ValueError("no amount in expression")
        yield from pending

    def transform_NAME(self, ttype, tvalue):
        """Let through only the keywords listed in SUPPORTED_NAMES."""
        if tvalue not in self.SUPPORTED_NAMES:
            raise ValueError("unsupported bare word {!r}".format(tvalue))
        yield from self._noop_transformer(ttype, tvalue)

    def transform_NUMBER(self, ttype, tvalue):
        """Replace a numeric literal with ``Decimal('<literal text>')``."""
        # Build the Decimal from the literal's text so no float is involved.
        replacement = (
            (tokenize.NAME, 'Decimal'),
            (tokenize.OP, '('),
            (tokenize.STRING, repr(tvalue)),
            (tokenize.OP, ')'),
        )
        yield from replacement

    def transform_OP(self, ttype, tvalue):
        """Allow SUPPORTED_OPS, and turn ``{name}`` into a bare variable name."""
        if tvalue != '{':
            if tvalue not in self.SUPPORTED_OPS:
                raise ValueError("unsupported operator {!r}".format(tvalue))
            yield from self._noop_transformer(ttype, tvalue)
            return
        # An opening brace must be followed by exactly a lowercase NAME and a
        # closing brace.  Both are consumed directly from the raw token
        # stream, bypassing the other transform_* methods.
        try:
            name_type, name_value, *_ = next(self.in_tokens)
            close_type, close_value, *_ = next(self.in_tokens)
        except StopIteration:
            raise ValueError("opening { does not name variable")
        names_variable = (
            name_type == tokenize.NAME
            and name_value == name_value.lower()
            and close_type == tokenize.OP
            and close_value == '}'
        )
        if not names_variable:
            raise ValueError("opening { does not name variable")
        yield (tokenize.NAME, name_value)

    transform_STRING = TokenTransformer._noop_transformer


class AccountSplitter:
    """Evaluate and format the posting lines of a Ledger entry.

    Each split added with add() pairs an account-name template with an
    amount expression.  The account name is formatted with
    ``str.format_map``.  The amount expression is sanitized by
    AmountTokenTransformer and compiled once.  At render time every
    expression is:

    * evaluated against the entry's template variables;
    * rounded to the currency's precision;
    * adjusted so the amounts balance;
    * emitted as one line per split (an empty string for zero amounts).
    """
    # Globals visible to amount expressions; AmountTokenTransformer rewrites
    # numeric literals into Decimal(...) calls.
    EVAL_GLOBALS = {
        'Decimal': decimal.Decimal,
    }
    TARGET_LINE_LEN = 78
    # -4 because that's how many spaces prefix an account line.
    TARGET_ACCTLINE_LEN = TARGET_LINE_LEN - 4
    # Locale used only when rounding amounts through babel before parsing
    # them back with Decimal().  Decimal() accepts '.' as the decimal point
    # and an ASCII '-', but not other locales' separators or minus signs.
    _DECIMAL_LOCALE = 'en_US'

    def __init__(self, signed_currencies, signed_currency_fmt, unsigned_currency_fmt,
                 template_name):
        """Create an empty splitter.

        Args:
            signed_currencies: currency codes formatted with
                *signed_currency_fmt*; all others use *unsigned_currency_fmt*.
            signed_currency_fmt, unsigned_currency_fmt: babel number patterns.
            template_name: name used in compile() and error messages.
        """
        self.splits = []
        self.metadata = []
        self.signed_currency_fmt = signed_currency_fmt
        self.unsigned_currency_fmt = unsigned_currency_fmt
        self.signed_currencies = set(signed_currencies)
        self.template_name = template_name
        # Sentinel that never matches a real mapping, so the first
        # render_next() call always starts a fresh split iterator.
        self._last_template_vars = object()

    def is_empty(self):
        """Return True if no splits have been added."""
        return not self.splits

    def add(self, account, amount_expr):
        """Add a split for *account* whose amount is *amount_expr*.

        Raises:
            errors.UserInputConfigurationError: the expression is rejected by
                AmountTokenTransformer or fails to tokenize/compile.
        """
        try:
            clean_expr = AmountTokenTransformer.from_str(amount_expr).transform()
            compiled_expr = compile(clean_expr, self.template_name, 'eval')
        except (SyntaxError, tokenize.TokenError, ValueError) as error:
            raise errors.UserInputConfigurationError(error.args[0], amount_expr)
        else:
            self.splits.append((account, compiled_expr))
            self.metadata.append('')

    def set_metadata(self, metadata_s):
        """Attach the metadata text *metadata_s* to the most recently added split."""
        self.metadata[-1] = metadata_s

    def _currency_decimal(self, amount, currency):
        """Round *amount* to the precision babel uses for *currency*, as a Decimal."""
        return decimal.Decimal(babel.numbers.format_currency(
            amount, currency, '###0.###', locale=self._DECIMAL_LOCALE))

    def _balance_amounts(self, amounts, to_amount):
        """Adjust one amount in place so same-signed amounts sum to *to_amount*.

        Amounts with the sign of *to_amount* (positive when it is > 0,
        otherwise negative) are summed.  If that sum misses *to_amount*, the
        last such amount absorbs the difference.  This only happens when the
        difference is under 10% of *to_amount*.

        Raises:
            errors.UserInputConfigurationError: the difference is too large.
                This includes any difference when *to_amount* is zero, where
                no relative size can be computed.
        """
        cmp_func = operator.lt if to_amount > 0 else operator.gt
        should_balance = functools.partial(cmp_func, 0)
        remainder = to_amount
        balance_index = None
        for index, (_, amount) in enumerate(amounts):
            if should_balance(amount):
                remainder -= amount
                balance_index = index
        if balance_index is None or remainder == 0:
            # No amount of this sign to adjust, or already balanced.
            return
        # Check to_amount first so a zero target cannot divide by zero.
        if to_amount == 0 or (abs(remainder) / abs(to_amount)) >= decimal.Decimal('.1'):
            raise errors.UserInputConfigurationError(
                "template can't balance amounts to {}".format(to_amount),
                self.template_name,
            )
        account_name, start_amount = amounts[balance_index]
        amounts[balance_index] = (account_name, start_amount + remainder)

    def _build_amounts(self, template_vars):
        """Evaluate every split, returning a list of (account, Decimal) pairs.

        Raises:
            errors.UserInputConfigurationError: an expression failed to
                evaluate, or the amounts could not be balanced.
        """
        try:
            amounts = [
                (account,
                 self._currency_decimal(eval(amount_expr, self.EVAL_GLOBALS, template_vars),
                                        template_vars['currency']),
                ) for account, amount_expr in self.splits
            ]
        except (ArithmeticError, NameError, TypeError, ValueError) as error:
            raise errors.UserInputConfigurationError(
                "{}: {}".format(type(error).__name__, error),
                "template {!r}".format(self.template_name)
            ) from error
        # Rounding can leave the entry slightly unbalanced; balance the
        # positive side against amount and the negative side against -amount.
        if sum(amt for _, amt in amounts) != 0:
            self._balance_amounts(amounts, template_vars['amount'])
            self._balance_amounts(amounts, -template_vars['amount'])
        return amounts

    def _iter_splits(self, template_vars):
        """Yield one rendered posting line per split.

        A split whose amount is zero yields an empty string.
        """
        amounts = self._build_amounts(template_vars)
        if template_vars['currency'] in self.signed_currencies:
            amt_fmt = self.signed_currency_fmt
        else:
            amt_fmt = self.unsigned_currency_fmt
        for (account, amount), metadata in zip(amounts, self.metadata):
            if amount == 0:
                yield ''
            else:
                account_s = account.format_map(template_vars)
                amount_s = babel.numbers.format_currency(amount, template_vars['currency'], amt_fmt)
                # Right-align the amount within the target width, but always
                # keep at least two spaces between account and amount.
                sep_len = max(2, self.TARGET_ACCTLINE_LEN - len(account_s) - len(amount_s))
                yield '\n    {}{}{}{}'.format(
                    account_s, ' ' * sep_len, amount_s,
                    metadata.format_map(template_vars),
                )

    def render_next(self, template_vars):
        """Return the next rendered split line for *template_vars*.

        Successive calls with the same mapping object return successive
        split lines.  Passing a different object (compared by identity)
        restarts from the first split.
        """
        if template_vars is not self._last_template_vars:
            self._split_iter = self._iter_splits(template_vars)
            self._last_template_vars = template_vars
        return next(self._split_iter)


class Template:
    """A Ledger entry template parsed from configuration text.

    The first non-blank line is the payee line.  If it starts with a
    ``{date}`` or ``{<prefix>_date}`` field followed by whitespace, that
    field is the entry's date field.  Otherwise a ``{date} {payee}`` line is
    generated, and the first line is treated as a body line.

    Body lines starting with ';' are metadata.  Every other body line is an
    account name and an amount expression, separated by a tab or by at
    least two spaces.
    """
    ACCOUNT_SPLIT_RE = re.compile(r'(?:\t|  )\s*')
    DATE_FMT = '%Y/%m/%d'
    # Matches '{date}' or '{<word chars ending in _>date}' followed by
    # whitespace.  A single optional prefix group (not a repeated one) avoids
    # exponential backtracking on long runs of word characters.
    PAYEE_LINE_RE = re.compile(r'^\{(?:\w*_)?date\}\s')
    SIGNED_CURRENCY_FMT = '¤#,##0.###;¤-#,##0.###'
    UNSIGNED_CURRENCY_FMT = '#,##0.### ¤¤'

    def __init__(self, template_s, signed_currencies=frozenset(),
                 date_fmt=DATE_FMT,
                 signed_currency_fmt=SIGNED_CURRENCY_FMT,
                 unsigned_currency_fmt=UNSIGNED_CURRENCY_FMT,
                 template_name='<template>'):
        """Parse *template_s* into a list of per-piece format functions.

        Raises:
            errors.UserInputError: a body line has no amount expression.
            errors.UserInputConfigurationError: an amount expression is
                invalid (raised by AccountSplitter.add).
        """
        self.date_fmt = date_fmt
        self.date_field = 'date'
        self.splitter = AccountSplitter(
            signed_currencies, signed_currency_fmt, unsigned_currency_fmt, template_name)

        lines = self._template_lines(template_s)
        self.format_funcs = []
        try:
            # Payee line; may also update self.date_field as a side effect
            # of advancing the _template_lines generator.
            self.format_funcs.append(next(lines).format_map)
        except StopIteration:
            return
        metadata = []
        for line in lines:
            if line.startswith(';'):
                metadata.append(line)
            else:
                # Metadata collected so far belongs to whatever precedes
                # this account line.
                self._add_str_func(metadata)
                metadata = []
                line = line.strip()
                match = self.ACCOUNT_SPLIT_RE.search(line)
                if match is None:
                    raise errors.UserInputError("no amount expression found", line)
                account = line[:match.start()]
                amount_expr = line[match.end():]
                self.splitter.add(account, amount_expr)
                self.format_funcs.append(self.splitter.render_next)
        self._add_str_func(metadata)
        self.format_funcs.append('\n'.format_map)

    def _nonblank_lines(self, s):
        """Yield each non-blank line of *s* with surrounding whitespace stripped."""
        for line in s.splitlines(True):
            line = line.strip()
            if line:
                yield line

    def _template_lines(self, template_s):
        """Yield the payee line (prefixed with a newline), then the body lines.

        Sets self.date_field when the first line starts with a date field.
        """
        lines = self._nonblank_lines(template_s)
        try:
            line1 = next(lines)
        except StopIteration:
            return
        match = self.PAYEE_LINE_RE.match(line1)
        if match:
            # Drop the leading '{' and the trailing '}' plus whitespace char.
            self.date_field = match.group(0)[1:-2]
            yield '\n' + line1
        else:
            yield '\n{date} {payee}'
            yield line1
        yield from lines

    def _add_str_func(self, str_seq):
        """Attach the metadata lines in *str_seq*, if any.

        Before any split exists they are rendered as literal text after the
        payee line.  Otherwise they are attached to the most recent split.
        """
        str_flat = ''.join('\n      ' + s for s in str_seq)
        if not str_flat:
            pass
        elif self.splitter.is_empty():
            self.format_funcs.append(str_flat.format_map)
        else:
            self.splitter.set_metadata(str_flat)

    def render(self, template_vars):
        """Render one entry from *template_vars* and return it as a str.

        'amount' is converted with strparse.currency_decimal.  Non-None values
        under 'date' or any '*_date' key are formatted with self.date_fmt.

        Raises:
            KeyError: 'currency' or 'payee' is missing.
            errors.UserInputConfigurationError: the date field is unset.
        """
        # template_vars must have these keys.  Raise a KeyError if not.
        template_vars['currency']
        template_vars['payee']
        if template_vars.get(self.date_field) is None:
            raise errors.UserInputConfigurationError(
                "entry needs {} field but that's not set by the importer".format(
                    self.date_field,
                ), self.splitter.template_name)
        render_vars = {
            'amount': strparse.currency_decimal(template_vars['amount']),
        }
        for key, value in template_vars.items():
            if value is not None and (key == 'date' or key.endswith('_date')):
                render_vars[key] = value.strftime(self.date_fmt)
        # A fresh mapping per call, so AccountSplitter.render_next restarts.
        all_vars = collections.ChainMap(render_vars, template_vars)
        return ''.join(f(all_vars) for f in self.format_funcs)

    def is_empty(self):
        """Return True if the template text contained no non-blank lines."""
        return not self.format_funcs


class LedgerEntryHook:
    """Output hook that renders each entry through a configured Ledger template."""
    KIND = HOOK_KINDS.OUTPUT

    def __init__(self, config):
        self.config = config

    @staticmethod
    @functools.lru_cache()
    def _load_template(config, section_name, config_key):
        """Build the Template stored under *config_key* in *section_name*.

        Results are memoized by lru_cache on the argument tuple.  This
        presumably requires *config* to be hashable (TODO confirm).

        Raises:
            errors.UserInputConfigurationError: the section has no
                *config_key* entry.
        """
        section_config = config.get_section(section_name)
        try:
            template_s = section_config[config_key]
        except KeyError:
            raise errors.UserInputConfigurationError(
                "Ledger template not defined in [{}]".format(section_name),
                config_key,
            )
        # Read the settings in the same order the Template keywords list them.
        date_fmt = section_config['date_format']
        currency_codes = section_config['signed_currencies'].split(',')
        signed_currencies = [code.strip().upper() for code in currency_codes]
        signed_fmt = section_config['signed_currency_format']
        unsigned_fmt = section_config['unsigned_currency_format']
        return Template(
            template_s,
            date_fmt=date_fmt,
            signed_currencies=signed_currencies,
            signed_currency_fmt=signed_fmt,
            unsigned_currency_fmt=unsigned_fmt,
            template_name=config_key,
        )

    def run(self, entry_data):
        """Render *entry_data* to the configured output file.

        The template key is entry_data['ledger template'] when present.
        Otherwise it is built from the last dotted word of 'importer_module'
        and 'importer_class' minus its final 8 characters, lowercased
        (presumably an "Importer" suffix).  A warning is logged instead when
        no non-empty template exists under that key.
        """
        try:
            template_key = entry_data['ledger template']
        except KeyError:
            module_tail = strparse.rslice_words(entry_data['importer_module'], -1, '.', 1)
            class_stem = entry_data['importer_class'][:-8].lower()
            template_key = '{} {} ledger entry'.format(module_tail, class_stem)
        template = None
        try:
            template = self._load_template(self.config, None, template_key)
        except errors.UserInputConfigurationError as error:
            # Only a missing template is tolerated; other config errors propagate.
            if not error.strerror.startswith('Ledger template not defined '):
                raise
        if template is None or template.is_empty():
            logger.warning("no Ledger template defined as %r", template_key)
            return
        with self.config.open_output_file() as out_file:
            print(template.render(entry_data), file=out_file, end='')