Changeset - e26dffa21428
[Not reviewed]
0 2 1
Brett Smith - 4 years ago 2020-06-09 13:04:27
brettcsmith@brettcsmith.org
reports: Add normalize_amount_func() function.
3 files changed with 94 insertions and 10 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/accrual.py
Show inline comments
...
 
@@ -32,290 +32,289 @@ a negative number of how many years back to search::
 

	
 
If you want to further limit what accruals are reported, you can match on
 
other metadata by passing additional arguments in ``name=value`` format.
 
You can pass any number of search terms. For example::
 

	
 
    # Report accruals associated with RT#1230 and Jane Doe
 
    accrual-report 1230 entity=Doe-Jane
 

	
 
accrual-report will automatically decide what kind of report to generate
 
from the search terms you provide and the results they return. If you pass
 
no search terms, it generates an aging report. If your search terms match a
 
single outstanding payable, it writes an outgoing approval report.
 
Otherwise, it writes a basic balance report. You can specify what report
 
type you want with the ``--report-type`` option::
 

	
 
    # Write an outgoing approval report for all outstanding accruals for
 
    # Jane Doe, even if there's more than one
 
    accrual-report --report-type outgoing entity=Doe-Jane
 
    # Write an aging report for a specific project
 
    accrual-report --report-type aging project=ProjectName
 
"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import argparse
 
import collections
 
import datetime
 
import enum
 
import logging
 
import operator
 
import re
 
import sys
 
import urllib.parse as urlparse
 

	
 
from pathlib import Path
 

	
 
from typing import (
 
    cast,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    Dict,
 
    Iterable,
 
    Iterator,
 
    FrozenSet,
 
    List,
 
    Mapping,
 
    NamedTuple,
 
    Optional,
 
    Sequence,
 
    Set,
 
    TextIO,
 
    Tuple,
 
    TypeVar,
 
    Union,
 
)
 
from ..beancount_types import (
 
    Entries,
 
    Error,
 
    Errors,
 
    MetaKey,
 
    MetaValue,
 
    Transaction,
 
)
 

	
 
import odf.style  # type:ignore[import]
 
import odf.table  # type:ignore[import]
 
import rt
 

	
 
from beancount.parser import printer as bc_printer
 

	
 
from . import core
 
from .. import cliutil
 
from .. import config as configmod
 
from .. import data
 
from .. import filters
 
from .. import rtutil
 

	
 
PROGNAME = 'accrual-report'
 
STANDARD_PATH = Path('-')
 

	
 
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
 
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
 
RTObject = Mapping[str, str]
 
T = TypeVar('T')
 

	
 
logger = logging.getLogger('conservancy_beancount.reports.accrual')
 

	
 
class Sentinel:
 
    pass
 

	
 

	
 
class Account(NamedTuple):
 
    name: str
 
    norm_func: Callable[[CompoundAmount], CompoundAmount]
 
    aging_thresholds: Sequence[int]
 

	
 

	
 
class AccrualAccount(enum.Enum):
 
    # Note the aging report uses the same order accounts are defined here.
 
    # See AgingODS.start_spreadsheet().
 
    RECEIVABLE = Account(
 
        'Assets:Receivable', lambda bal: bal, [365, 120, 90, 60],
 
    )
 
    PAYABLE = Account(
 
        'Liabilities:Payable', operator.neg, [365, 90, 60, 30],
 
    )
 
    RECEIVABLE = Account('Assets:Receivable', [365, 120, 90, 60])
 
    PAYABLE = Account('Liabilities:Payable', [365, 90, 60, 30])
 

	
 
    @classmethod
 
    def account_names(cls) -> Iterator[str]:
 
        return (acct.value.name for acct in cls)
 

	
 
    @classmethod
 
    def by_account(cls, name: data.Account) -> 'AccrualAccount':
 
        for account in cls:
 
            if name.is_under(account.value.name):
 
                return account
 
        raise ValueError(f"unrecognized account {name!r}")
 

	
 
    @classmethod
 
    def classify(cls, related: core.RelatedPostings) -> 'AccrualAccount':
 
        for account in cls:
 
            account_name = account.value.name
 
            if all(post.account.is_under(account_name) for post in related):
 
                return account
 
        raise ValueError("unrecognized account set in related postings")
 

	
 
    @property
 
    def normalize_amount(self) -> Callable[[T], T]:
 
        return core.normalize_amount_func(self.value.name)
 

	
 

	
 
class AccrualPostings(core.RelatedPostings):
 
    def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]:  # type:ignore[misc]
 
        def meta_getter(post: data.Posting) -> MetaValue:
 
            return post.meta.get(key)
 
        return meta_getter
 

	
 
    _FIELDS: Dict[str, Callable[[data.Posting], MetaValue]] = {
 
        'account': operator.attrgetter('account'),
 
        'contract': _meta_getter('contract'),
 
        'invoice': _meta_getter('invoice'),
 
        'purchase_order': _meta_getter('purchase-order'),
 
    }
 
    INCONSISTENT = Sentinel()
 
    __slots__ = (
 
        'accrual_type',
 
        'accrued_entities',
 
        'end_balance',
 
        'paid_entities',
 
        'account',
 
        'accounts',
 
        'contract',
 
        'contracts',
 
        'invoice',
 
        'invoices',
 
        'purchase_order',
 
        'purchase_orders',
 
    )
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        super().__init__(source, _can_own=_can_own)
 
        # The following type declarations tell mypy about values set in the for
 
        # loop that are important enough to be referenced directly elsewhere.
 
        self.account: Union[data.Account, Sentinel]
 
        self.invoice: Union[MetaValue, Sentinel]
 
        for name, get_func in self._FIELDS.items():
 
            values = frozenset(get_func(post) for post in self)
 
            setattr(self, f'{name}s', values)
 
            if len(values) == 1:
 
                one_value = next(iter(values))
 
            else:
 
                one_value = self.INCONSISTENT
 
            setattr(self, name, one_value)
 
        if self.account is self.INCONSISTENT:
 
            self.accrual_type: Optional[AccrualAccount] = None
 
            self.end_balance = self.balance_at_cost()
 
            self.accrued_entities = self._collect_entities()
 
            self.paid_entities = self.accrued_entities
 
        else:
 
            self.accrual_type = AccrualAccount.classify(self)
 
            accrual_acct: Account = self.accrual_type.value
 
            norm_func = accrual_acct.norm_func
 
            norm_func = self.accrual_type.normalize_amount
 
            self.end_balance = norm_func(self.balance_at_cost())
 
            self.accrued_entities = self._collect_entities(
 
                lambda post: norm_func(post.units).number > 0,
 
            )
 
            self.paid_entities = self._collect_entities(
 
                lambda post: norm_func(post.units).number < 0,
 
            )
 

	
 
    def _collect_entities(self,
 
                          pred: Callable[[data.Posting], bool]=bool,
 
                          default: str='<empty>',
 
    ) -> FrozenSet[MetaValue]:
 
        return frozenset(
 
            post.meta.get('entity') or default
 
            for post in self if pred(post)
 
        )
 

	
 
    def entities(self) -> Iterator[MetaValue]:
 
        yield from self.accrued_entities
 
        yield from self.paid_entities.difference(self.accrued_entities)
 

	
 
    def make_consistent(self) -> Iterator[Tuple[MetaValue, 'AccrualPostings']]:
 
        account_ok = isinstance(self.account, str)
 
        if len(self.accrued_entities) == 1:
 
            entity = next(iter(self.accrued_entities))
 
        else:
 
            entity = None
 
        # `'/' in self.invoice` is just our heuristic to ensure that the
 
        # invoice metadata is "unique enough," and not just a placeholder
 
        # value like "FIXME". It can be refined if needed.
 
        invoice_ok = isinstance(self.invoice, str) and '/' in self.invoice
 
        if account_ok and entity is not None and invoice_ok:
 
            yield (self.invoice, self)
 
            return
 
        groups = collections.defaultdict(list)
 
        for post in self:
 
            post_invoice = self.invoice if invoice_ok else (
 
                post.meta.get('invoice') or 'BlankInvoice'
 
            )
 
            post_entity = entity if entity is not None else (
 
                post.meta.get('entity') or 'BlankEntity'
 
            )
 
            groups[f'{post.account} {post_invoice} {post_entity}'].append(post)
 
        type_self = type(self)
 
        for group_key, posts in groups.items():
 
            yield group_key, type_self(posts, _can_own=True)
 

	
 
    def report_inconsistencies(self) -> Iterable[Error]:
 
        for field_name, get_func in self._FIELDS.items():
 
            if getattr(self, field_name) is self.INCONSISTENT:
 
                for post in self:
 
                    errmsg = 'inconsistent {} for invoice {}: {}'.format(
 
                        field_name.replace('_', '-'),
 
                        self.invoice or "<none>",
 
                        get_func(post),
 
                    )
 
                    yield Error(post.meta, errmsg, post.meta.txn)
 
        costs = collections.defaultdict(set)
 
        for post in self:
 
            costs[post.units.currency].add(post.cost)
 
        for code, currency_costs in costs.items():
 
            if len(currency_costs) > 1:
 
                for post in self:
 
                    if post.units.currency == code:
 
                        errmsg = 'inconsistent cost for invoice {}: {}'.format(
 
                            self.invoice or "<none>", post.cost,
 
                        )
 
                        yield Error(post.meta, errmsg, post.meta.txn)
 

	
 
    def is_paid(self, default: Optional[bool]=None) -> Optional[bool]:
 
        if self.accrual_type is None:
 
            return default
 
        else:
 
            return self.end_balance.le_zero()
 

	
 
    def is_zero(self, default: Optional[bool]=None) -> Optional[bool]:
 
        if self.accrual_type is None:
 
            return default
 
        else:
 
            return self.end_balance.is_zero()
 

	
 
    def since_last_nonzero(self) -> 'AccrualPostings':
 
        for index, (post, balance) in enumerate(self.iter_with_balance()):
 
            if balance.is_zero():
 
                start_index = index
 
        try:
 
            empty = start_index == index
 
        except NameError:
 
            empty = True
 
        return self if empty else self[start_index + 1:]
 

	
 

	
 
class BaseReport:
 
    def __init__(self, out_file: TextIO) -> None:
 
        self.out_file = out_file
 
        self.logger = logger.getChild(type(self).__name__)
...
 
@@ -360,193 +359,193 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
 
        )
 
        self.style_widecol.setAttribute('family', 'table-column')
 
        self.style_widecol.addElement(odf.style.TableColumnProperties(
 
            columnwidth='1.25in',
 
        ))
 

	
 
    def section_key(self, row: AccrualPostings) -> Optional[data.Account]:
 
        if isinstance(row.account, str):
 
            return row.account
 
        else:
 
            return None
 

	
 
    def start_spreadsheet(self) -> None:
 
        for accrual_type in AccrualAccount:
 
            self.use_sheet(accrual_type.name.title())
 
            for index in range(self.COL_COUNT):
 
                stylename = self.style_widecol if index else ''
 
                self.sheet.addElement(odf.table.TableColumn(stylename=stylename))
 
            self.add_row(*(
 
                self.string_cell(name, stylename=self.style_bold)
 
                for name in self.COLUMNS
 
            ))
 
            self.lock_first_row()
 

	
 
    def start_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
 
            return
 
        self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds)
 
        self.age_balances = [core.MutableBalance() for _ in self.age_thresholds]
 
        accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1])
 
        acct_parts = key.split(':')
 
        self.use_sheet(acct_parts[1])
 
        self.add_row()
 
        self.add_row(self.string_cell(
 
            f"{' '.join(acct_parts[2:])} {acct_parts[1]} Aging Report"
 
            f" Accrued by {accrual_date.isoformat()} Unpaid by {self.date.isoformat()}",
 
            stylename=self.merge_styles(self.style_bold, self.style_centertext),
 
            numbercolumnsspanned=self.COL_COUNT,
 
        ))
 
        self.add_row()
 

	
 
    def end_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
 
            return
 
        self.add_row()
 
        text_style = self.merge_styles(self.style_bold, self.style_endtext)
 
        text_span = self.COL_COUNT - 1
 
        for threshold, balance in zip(self.age_thresholds, self.age_balances):
 
            years, days = divmod(threshold, 365)
 
            years_text = f"{years} {'Year' if years == 1 else 'Years'}"
 
            days_text = f"{days} Days"
 
            if years and days:
 
                age_text = f"{years_text} {days_text}"
 
            elif years:
 
                age_text = years_text
 
            else:
 
                age_text = days_text
 
            self.add_row(
 
                self.string_cell(
 
                    f"Total Aged Over {age_text}: ",
 
                    stylename=text_style,
 
                    numbercolumnsspanned=text_span,
 
                ),
 
                *(odf.table.TableCell() for _ in range(1, text_span)),
 
                self.balance_cell(balance),
 
            )
 

	
 
    def _link_seq(self, row: AccrualPostings, key: MetaKey) -> Iterator[Tuple[str, str]]:
 
        for href in row.all_meta_links(key):
 
            text: Optional[str] = None
 
            rt_ids = self.rt_wrapper.parse(href)
 
            if rt_ids is not None:
 
                ticket_id, attachment_id = rt_ids
 
                if attachment_id is None:
 
                    text = f'RT#{ticket_id}'
 
                href = self.rt_wrapper.url(ticket_id, attachment_id) or href
 
            else:
 
                # '..' pops the ODS filename off the link path. In other words,
 
                # make the link relative to the directory the ODS is in.
 
                href = f'../{href}'
 
            if text is None:
 
                href_path = Path(urlparse.urlparse(href).path)
 
                text = urlparse.unquote(href_path.name)
 
            yield (href, text)
 

	
 
    def write_row(self, row: AccrualPostings) -> None:
 
        age = (self.date - row[0].meta.date).days
 
        if row.end_balance.ge_zero():
 
            for index, threshold in enumerate(self.age_thresholds):
 
                if age >= threshold:
 
                    self.age_balances[index] += row.end_balance
 
                    break
 
            else:
 
                return
 
        raw_balance = row.balance()
 
        if row.accrual_type is not None:
 
            raw_balance = row.accrual_type.value.norm_func(raw_balance)
 
            raw_balance = row.accrual_type.normalize_amount(raw_balance)
 
        if raw_balance == row.end_balance:
 
            amount_cell = odf.table.TableCell()
 
        else:
 
            amount_cell = self.balance_cell(raw_balance)
 
        self.add_row(
 
            self.date_cell(row[0].meta.date),
 
            self.multiline_cell(row.entities()),
 
            amount_cell,
 
            self.balance_cell(row.end_balance),
 
            self.multilink_cell(self._link_seq(row, 'rt-id')),
 
            self.multilink_cell(self._link_seq(row, 'invoice')),
 
        )
 

	
 

	
 
class AgingReport(BaseReport):
 
    def __init__(self,
 
                 rt_client: rt.Rt,
 
                 out_file: BinaryIO,
 
                 date: Optional[datetime.date]=None,
 
    ) -> None:
 
        if date is None:
 
            date = datetime.date.today()
 
        self.out_bin = out_file
 
        self.logger = logger.getChild(type(self).__name__)
 
        self.ods = AgingODS(rt_client, date, self.logger)
 

	
 
    def run(self, groups: PostGroups) -> None:
 
        rows = list(group for group in groups.values() if not group.is_zero())
 
        rows.sort(key=lambda related: (
 
            related.account,
 
            related[0].meta.date,
 
            min(related.entities()) if related.accrued_entities else '',
 
        ))
 
        self.ods.write(rows)
 
        self.ods.save_file(self.out_bin)
 

	
 

	
 
class BalanceReport(BaseReport):
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        date_s = posts[0].meta.date.strftime('%Y-%m-%d')
 
        if index:
 
            yield ""
 
        yield f"{posts.invoice}:"
 
        yield f"  {posts.balance_at_cost()} outstanding since {date_s}"
 

	
 

	
 
class OutgoingReport(BaseReport):
 
    def __init__(self, rt_client: rt.Rt, out_file: TextIO) -> None:
 
        super().__init__(out_file)
 
        self.rt_client = rt_client
 
        self.rt_wrapper = rtutil.RT(rt_client)
 

	
 
    def _primary_rt_id(self, posts: AccrualPostings) -> rtutil.TicketAttachmentIds:
 
        rt_ids: Set[str] = set()
 
        for post in posts:
 
            try:
 
                rt_ids.add(post.meta.get_links('rt-id')[0])
 
            except (IndexError, TypeError):
 
                pass
 
        rt_ids_count = len(rt_ids)
 
        if rt_ids_count != 1:
 
            raise ValueError(f"{rt_ids_count} rt-id links found")
 
        parsed = rtutil.RT.parse(rt_ids.pop())
 
        if parsed is None:
 
            raise ValueError("rt-id is not a valid RT reference")
 
        else:
 
            return parsed
 

	
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        try:
 
            ticket_id, _ = self._primary_rt_id(posts)
 
            ticket = self.rt_client.get_ticket(ticket_id)
 
            # Note we only use this when ticket is None.
 
            errmsg = f"ticket {ticket_id} not found"
 
        except (ValueError, rt.RtError) as error:
 
            ticket = None
 
            errmsg = error.args[0]
 
        if ticket is None:
 
            self.logger.error(
 
                "can't generate outgoings report for %s because no RT ticket available: %s",
 
                posts.invoice, errmsg,
 
            )
 
            return
 

	
 
        try:
 
            rt_requestor = self.rt_client.get_user(ticket['Requestors'][0])
 
        except (IndexError, rt.RtError):
 
            rt_requestor = None
 
        if rt_requestor is None:
 
            requestor = ''
 
            requestor_name = ''
 
        else:
 
            requestor_name = (
 
                rt_requestor.get('RealName')
conservancy_beancount/reports/core.py
Show inline comments
 
"""core.py - Common data classes for reporting functionality"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import abc
 
import collections
 
import datetime
 
import itertools
 
import operator
 
import re
 

	
 
import babel.core  # type:ignore[import]
 
import babel.numbers  # type:ignore[import]
 

	
 
import odf.config  # type:ignore[import]
 
import odf.element  # type:ignore[import]
 
import odf.number  # type:ignore[import]
 
import odf.opendocument  # type:ignore[import]
 
import odf.style  # type:ignore[import]
 
import odf.table  # type:ignore[import]
 
import odf.text  # type:ignore[import]
 

	
 
from decimal import Decimal
 
from pathlib import Path
 

	
 
from beancount.core import amount as bc_amount
 

	
 
from .. import data
 

	
 
from typing import (
 
    cast,
 
    overload,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    DefaultDict,
 
    Dict,
 
    Generic,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Mapping,
 
    MutableMapping,
 
    Optional,
 
    Sequence,
 
    Set,
 
    Tuple,
 
    Type,
 
    TypeVar,
 
    Union,
 
)
 
from ..beancount_types import (
 
    MetaKey,
 
    MetaValue,
 
)
 

	
 
DecimalCompat = data.DecimalCompat
 
BalanceType = TypeVar('BalanceType', bound='Balance')
 
ElementType = Callable[..., odf.element.Element]
 
LinkType = Union[str, Tuple[str, Optional[str]]]
 
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
 
RT = TypeVar('RT', bound=Sequence)
 
ST = TypeVar('ST')
 
T = TypeVar('T')
 

	
 
class Balance(Mapping[str, data.Amount]):
 
    """A collection of amounts mapped by currency
 

	
 
    Each key is a Beancount currency string, and each value represents the
 
    balance in that currency.
 
    """
 
    __slots__ = ('_currency_map', 'tolerance')
 
    TOLERANCE = Decimal('0.01')
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Amount]=(),
 
                 tolerance: Optional[Decimal]=None,
 
    ) -> None:
 
        if tolerance is None:
 
            tolerance = self.TOLERANCE
 
        self._currency_map = {amount.currency: amount for amount in source}
 
        self.tolerance = tolerance
 

	
 
    def _add_amount(self,
 
                    currency_map: MutableMapping[str, data.Amount],
 
                    amount: data.Amount,
 
    ) -> None:
 
        code = amount.currency
 
        try:
 
            current_number = currency_map[code].number
 
        except KeyError:
 
            current_number = Decimal(0)
 
        currency_map[code] = data.Amount(current_number + amount.number, code)
 

	
 
    def _add_other(self,
 
                   currency_map: MutableMapping[str, data.Amount],
 
                   other: Union[data.Amount, 'Balance'],
 
    ) -> None:
 
        if isinstance(other, Balance):
 
            for amount in other.values():
 
                self._add_amount(currency_map, amount)
 
        else:
 
            self._add_amount(currency_map, other)
 

	
 
    def __repr__(self) -> str:
 
        values = [repr(amt) for amt in self.values()]
 
        return f"{type(self).__name__}({values!r})"
 

	
 
    def __str__(self) -> str:
 
        return self.format()
 

	
 
    def __abs__(self: BalanceType) -> BalanceType:
 
        return type(self)(bc_amount.abs(amt) for amt in self.values())
 

	
 
    def __add__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
 
        retval_map = self._currency_map.copy()
 
        self._add_other(retval_map, other)
 
        return type(self)(retval_map.values())
 

	
 
    def __eq__(self, other: Any) -> bool:
 
        if (self.is_zero()
 
            and isinstance(other, Balance)
 
            and other.is_zero()):
 
            return True
 
        else:
 
            return super().__eq__(other)
 

	
 
    def __neg__(self: BalanceType) -> BalanceType:
 
        return type(self)(-amt for amt in self.values())
 

	
 
    def __getitem__(self, key: str) -> data.Amount:
 
        return self._currency_map[key]
 

	
 
    def __iter__(self) -> Iterator[str]:
 
        return iter(self._currency_map)
 

	
 
    def __len__(self) -> int:
 
        return len(self._currency_map)
 

	
 
    def _all_amounts(self,
 
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
 
                     operand: DecimalCompat,
 
    ) -> bool:
 
        return all(op_func(amt.number, operand) for amt in self.values())
 

	
 
    @staticmethod
 
    def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
 
        dec = cast(Decimal, dec)
 
        return abs(dec) < tolerance
 

	
 
    def eq_zero(self) -> bool:
 
        """Returns true if all amounts in the balance == 0, within tolerance."""
 
        return self._all_amounts(self.within_tolerance, self.tolerance)
 

	
 
    is_zero = eq_zero
 

	
 
    def ge_zero(self) -> bool:
 
        """Returns true if all amounts in the balance >= 0, within tolerance."""
 
        op_func = operator.gt if self.tolerance else operator.ge
 
        return self._all_amounts(op_func, -self.tolerance)
...
 
@@ -805,96 +806,114 @@ class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta):
 

	
 
    def add_row(self, *cells: odf.table.TableCell, **attrs: Any) -> odf.table.TableRow:
 
        row = odf.table.TableRow(**attrs)
 
        for cell in cells:
 
            row.addElement(cell)
 
        self.sheet.addElement(row)
 
        return row
 

	
 
    def balance_cell(self, balance: Balance, **attrs: Any) -> odf.table.TableCell:
 
        if balance.is_zero():
 
            return self.float_cell(0, **attrs)
 
        elif len(balance) == 1:
 
            amount = next(iter(balance.values()))
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.currency_style(amount.currency),
 
            )
 
            return self.currency_cell(amount, **attrs)
 
        else:
 
            lines = [babel.numbers.format_currency(
 
                number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
            ) for number, currency in balance.values()]
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.style_endtext,
 
            )
 
            return self.multiline_cell(lines, **attrs)
 

	
 
    def currency_cell(self, amount: data.Amount, **attrs: Any) -> odf.table.TableCell:
 
        number, currency = amount
 
        cell = odf.table.TableCell(valuetype='currency', value=number, **attrs)
 
        cell.addElement(odf.text.P(text=babel.numbers.format_currency(
 
            number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
        )))
 
        return cell
 

	
 
    def date_cell(self, date: datetime.date, **attrs: Any) -> odf.table.TableCell:
 
        attrs.setdefault('stylename', self.style_date)
 
        cell = odf.table.TableCell(valuetype='date', datevalue=date, **attrs)
 
        cell.addElement(odf.text.P(text=date.isoformat()))
 
        return cell
 

	
 
    def float_cell(self, value: Union[int, float, Decimal], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='float', value=value, **attrs)
 
        cell.addElement(odf.text.P(text=str(value)))
 
        return cell
 

	
 
    def multiline_cell(self, lines: Iterable[Any], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for line in lines:
 
            cell.addElement(odf.text.P(text=str(line)))
 
        return cell
 

	
 
    def multilink_cell(self, links: Iterable[LinkType], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for link in links:
 
            if isinstance(link, tuple):
 
                href, text = link
 
            else:
 
                href = link
 
                text = None
 
            cell.addElement(odf.text.P())
 
            cell.lastChild.addElement(odf.text.A(
 
                type='simple', href=href, text=text,
 
            ))
 
        return cell
 

	
 
    def string_cell(self, text: str, **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        cell.addElement(odf.text.P(text=text))
 
        return cell
 

	
 
    def write_row(self, row: RT) -> None:
 
        """Write a single row of input data to the spreadsheet
 

	
 
        This default implementation adds a single row to the spreadsheet,
 
        with one cell per element of the row. The type of each element
 
        determines what kind of cell is created.
 

	
 
        This implementation will help get you started, but you'll probably
 
        want to override it to specify styles.
 
        """
 
        out_row = odf.table.TableRow()
 
        for cell_source in row:
 
            if isinstance(cell_source, (int, float, Decimal)):
 
                cell = self.float_cell(cell_source)
 
            else:
 
                cell = self.string_cell(cell_source)
 
            out_row.addElement(cell)
 
        self.sheet.addElement(out_row)
 

	
 
    def save_file(self, out_file: BinaryIO) -> None:
 
        self.document.write(out_file)
 

	
 
    def save_path(self, path: Path, mode: str='w') -> None:
 
        with path.open(f'{mode}b') as out_file:
 
            out_file = cast(BinaryIO, out_file)
 
            self.save_file(out_file)
 

	
 

	
 
def normalize_amount_func(account_name: str) -> Callable[[T], T]:
 
    """Get a function to normalize amounts for reporting
 

	
 
    Given an account name, return a function that can be used on "amounts"
 
    under that account (including numbers, Amount objects, and Balance objects)
 
    to normalize them for reporting. Right now that means make flipping the
 
    sign for accounts where "normal" postings are negative.
 
    """
 
    if account_name.startswith(('Assets:', 'Expenses:')):
 
        # We can't just return operator.pos because Beancount's Amount class
 
        # doesn't implement __pos__.
 
        return lambda amt: amt
 
    elif account_name.startswith(('Equity:', 'Income:', 'Liabilities:')):
 
        return operator.neg
 
    else:
 
        raise ValueError(f"unrecognized account name {account_name!r}")
tests/test_reports_core.py
Show inline comments
 
new file 100644
 
"""test_reports_core - Unit tests for basic reports functions"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import pytest
 

	
 
from decimal import Decimal
 

	
 
from . import testutil
 

	
 
from conservancy_beancount.reports import core
 

	
 
AMOUNTS = [
 
    2,
 
    Decimal('4.40'),
 
    testutil.Amount('6.60', 'CHF'),
 
    core.Balance([testutil.Amount('8.80')]),
 
]
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Checking',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Expenses:FilingFees',
 
])
 
def test_normalize_amount_func_pos(acct_name):
 
    actual = core.normalize_amount_func(acct_name)
 
    for amount in AMOUNTS:
 
        assert actual(amount) == amount
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Equity:Funds:Restricted',
 
    'Equity:Realized:CurrencyConversion',
 
    'Income:Donations',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_normalize_amount_func_neg(acct_name):
 
    actual = core.normalize_amount_func(acct_name)
 
    for amount in AMOUNTS:
 
        assert actual(amount) == -amount
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    '',
 
    'Assets',
 
    'Equity',
 
    'Expenses',
 
    'Income',
 
    'Liabilities',
 
])
 
def test_normalize_amount_func_bad_acct_name(acct_name):
 
    with pytest.raises(ValueError):
 
        core.normalize_amount_func(acct_name)
0 comments (0 inline, 0 general)