Changeset - f76fa35fad2a
[Not reviewed]
0 3 0
Brett Smith - 4 years ago 2020-06-11 14:46:06
brettcsmith@brettcsmith.org
reports: RelatedPostings.all_meta_links() returns an iterator.

This preserves order.
3 files changed with 21 insertions and 8 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/accrual.py
Show inline comments
 
#!/usr/bin/env python3
 
"""accrual-report - Status reports for accruals
 

	
 
accrual-report checks accruals (postings under Assets:Receivable and
 
Liabilities:Payable) for errors and metadata consistency, and reports any
 
problems on stderr. Then it writes a report about the status of those
 
accruals.
 

	
 
If you run it with no arguments, it will generate an aging report in ODS format
 
in the current directory.
 

	
 
Otherwise, the typical way to run it is to pass an RT ticket number or
 
invoice link as an argument, to report about accruals that match those
 
criteria::
 

	
 
    # Report all accruals associated with RT#1230:
 
    accrual-report 1230
 
    # Report all accruals with the invoice link rt:45/670.
 
    accrual-report 45/670
 
    # Report all accruals with the invoice link Invoice980.pdf.
 
    accrual-report Invoice980.pdf
 

	
 
By default, to stay fast, accrual-report only looks for postings from the
 
beginning of the last fiscal year. You can search further back in history
 
by passing the ``--since`` argument. The argument can be a fiscal year, or
 
a negative number of how many years back to search::
 

	
 
    # Search for accruals since 2016
 
    accrual-report --since 2016 [search terms …]
 
    # Search for accruals from the beginning of three fiscal years ago
 
    accrual-report --since -3 [search terms …]
 

	
 
If you want to further limit what accruals are reported, you can match on
 
other metadata by passing additional arguments in ``name=value`` format.
 
You can pass any number of search terms. For example::
 

	
 
    # Report accruals associated with RT#1230 and Jane Doe
 
    accrual-report 1230 entity=Doe-Jane
 

	
 
accrual-report will automatically decide what kind of report to generate
 
from the search terms you provide and the results they return. If you pass
 
no search terms, it generates an aging report. If your search terms match a
 
single outstanding payable, it writes an outgoing approval report.
 
Otherwise, it writes a basic balance report. You can specify what report
 
type you want with the ``--report-type`` option::
 

	
 
    # Write an outgoing approval report for all outstanding accruals for
 
    # Jane Doe, even if there's more than one
 
    accrual-report --report-type outgoing entity=Doe-Jane
 
    # Write an aging report for a specific project
 
    accrual-report --report-type aging project=ProjectName
 
"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import argparse
 
import collections
 
import datetime
 
import enum
 
import logging
 
import re
 
import sys
 
import urllib.parse as urlparse
 

	
 
from pathlib import Path
 

	
 
from typing import (
 
    cast,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Mapping,
 
    NamedTuple,
 
    Optional,
 
    Sequence,
 
    Set,
 
    TextIO,
 
    Tuple,
 
    TypeVar,
 
    Union,
 
)
 
from ..beancount_types import (
 
    Entries,
 
    Error,
 
    Errors,
 
    MetaKey,
 
    MetaValue,
 
    Transaction,
 
)
 

	
 
import odf.style  # type:ignore[import]
 
import odf.table  # type:ignore[import]
 
import rt
 

	
 
from beancount.parser import printer as bc_printer
 

	
 
from . import core
 
from .. import books
 
from .. import cliutil
 
from .. import config as configmod
 
from .. import data
 
from .. import filters
 
from .. import rtutil
 

	
 
PROGNAME = 'accrual-report'
 

	
 
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
 
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
 
RTObject = Mapping[str, str]
 
T = TypeVar('T')
 

	
 
logger = logging.getLogger('conservancy_beancount.reports.accrual')
 

	
 
class Sentinel:
 
    pass
 

	
 

	
 
class Account(NamedTuple):
 
    name: str
 
    aging_thresholds: Sequence[int]
 

	
 

	
 
class AccrualAccount(enum.Enum):
 
    # Note the aging report uses the same order accounts are defined here.
 
    # See AgingODS.start_spreadsheet().
 
    RECEIVABLE = Account('Assets:Receivable', [365, 120, 90, 60])
 
    PAYABLE = Account('Liabilities:Payable', [365, 90, 60, 30])
 

	
 
    @classmethod
 
    def account_names(cls) -> Iterator[str]:
 
        return (acct.value.name for acct in cls)
 

	
 
    @classmethod
 
    def by_account(cls, name: data.Account) -> 'AccrualAccount':
 
        for account in cls:
 
            if name.is_under(account.value.name):
 
                return account
 
        raise ValueError(f"unrecognized account {name!r}")
 

	
 
    @classmethod
 
    def classify(cls, related: core.RelatedPostings) -> 'AccrualAccount':
 
        for account in cls:
 
            account_name = account.value.name
 
            if all(post.account.is_under(account_name) for post in related):
 
                return account
 
        raise ValueError("unrecognized account set in related postings")
 

	
 
    @property
 
    def normalize_amount(self) -> Callable[[T], T]:
 
        return core.normalize_amount_func(self.value.name)
 

	
 

	
 
class AccrualPostings(core.RelatedPostings):
 
    __slots__ = (
 
        'accrual_type',
 
        'end_balance',
 
        'account',
 
        'entity',
 
        'invoice',
 
    )
 
    INCONSISTENT = Sentinel()
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        super().__init__(source, _can_own=_can_own)
 
        # The following type declarations tell mypy about values set in the for
 
        # loop that are important enough to be referenced directly elsewhere.
 
        self.account = self._single_item(post.account for post in self)
 
        if isinstance(self.account, Sentinel):
 
            self.accrual_type: Optional[AccrualAccount] = None
 
            norm_func: Callable[[T], T] = lambda x: x
 
            entity_pred: Callable[[data.Posting], bool] = bool
 
        else:
 
            self.accrual_type = AccrualAccount.by_account(self.account)
 
            norm_func = self.accrual_type.normalize_amount
 
            entity_pred = lambda post: norm_func(post.units).number > 0
 
        self.entity = self._single_item(self.entities(entity_pred))
 
        self.invoice = self._single_item(self.first_links('invoice'))
 
        self.end_balance = norm_func(self.balance_at_cost())
 

	
 
    def _single_item(self, seq: Iterable[T]) -> Union[T, Sentinel]:
 
        items = iter(seq)
 
        try:
 
            item1 = next(items)
 
        except StopIteration:
 
            all_same = False
 
        else:
 
            all_same = all(item == item1 for item in items)
 
        return item1 if all_same else self.INCONSISTENT
 

	
 
    def entities(self, pred: Callable[[data.Posting], bool]=bool) -> Iterator[MetaValue]:
 
        return filters.iter_unique(
 
            post.meta['entity']
 
            for post in self
 
            if pred(post) and 'entity' in post.meta
 
        )
 

	
 
    def first_links(self, key: MetaKey, default: Optional[str]=None) -> Iterator[Optional[str]]:
 
        return (post.meta.first_link(key, default) for post in self)
 

	
 
    def make_consistent(self) -> Iterator[Tuple[MetaValue, 'AccrualPostings']]:
 
        account_ok = isinstance(self.account, str)
 
        entity_ok = isinstance(self.entity, str)
 
        # `'/' in self.invoice` is just our heuristic to ensure that the
 
        # invoice metadata is "unique enough," and not just a placeholder
 
        # value like "FIXME". It can be refined if needed.
 
        invoice_ok = isinstance(self.invoice, str) and '/' in self.invoice
 
        if account_ok and entity_ok and invoice_ok:
 
            yield (self.invoice, self)
 
            return
 
        groups = collections.defaultdict(list)
 
        for post in self:
 
            post_invoice = self.invoice if invoice_ok else (
 
                post.meta.get('invoice') or 'BlankInvoice'
 
            )
 
            post_entity = self.entity if entity_ok else (
 
                post.meta.get('entity') or 'BlankEntity'
 
            )
 
            groups[f'{post.account} {post_invoice} {post_entity}'].append(post)
 
        type_self = type(self)
 
        for group_key, posts in groups.items():
 
            yield group_key, type_self(posts, _can_own=True)
 

	
 
    def is_paid(self, default: Optional[bool]=None) -> Optional[bool]:
 
        if self.accrual_type is None:
 
            return default
 
        else:
 
            return self.end_balance.le_zero()
 

	
 
    def is_zero(self, default: Optional[bool]=None) -> Optional[bool]:
 
        if self.accrual_type is None:
 
            return default
 
        else:
 
            return self.end_balance.is_zero()
 

	
 
    def since_last_nonzero(self) -> 'AccrualPostings':
 
        for index, (post, balance) in enumerate(self.iter_with_balance()):
 
            if balance.is_zero():
 
                start_index = index
 
        try:
 
            empty = start_index == index
 
        except NameError:
 
            empty = True
 
        return self if empty else self[start_index + 1:]
 

	
 

	
 
class BaseReport:
 
    def __init__(self, out_file: TextIO) -> None:
 
        self.out_file = out_file
 
        self.logger = logger.getChild(type(self).__name__)
 

	
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        raise NotImplementedError("BaseReport._report")
 

	
 
    def run(self, groups: PostGroups) -> None:
 
        for index, invoice in enumerate(groups):
 
            for line in self._report(groups[invoice], index):
 
                print(line, file=self.out_file)
 

	
 

	
 
class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
 
    COLUMNS = [
 
        'Date',
 
        'Entity',
 
        'Invoice Amount',
 
        'Booked Amount',
 
        'Project',
 
        'Ticket',
 
        'Invoice',
 
        'Approval',
 
        'Contract',
 
        'Purchase Order',
 
    ]
 
    COL_COUNT = len(COLUMNS)
 

	
 
    def __init__(self,
 
                 rt_client: rt.Rt,
 
                 date: datetime.date,
 
                 logger: logging.Logger,
 
    ) -> None:
 
        super().__init__()
 
        self.rt_client = rt_client
 
        self.rt_wrapper = rtutil.RT(self.rt_client)
 
        self.date = date
 
        self.logger = logger
 

	
 
    def init_styles(self) -> None:
 
        super().init_styles()
 
        self.style_widecol = self.replace_child(
 
            self.document.automaticstyles,
 
            odf.style.Style,
 
            name='WideCol',
 
        )
 
        self.style_widecol.setAttribute('family', 'table-column')
 
        self.style_widecol.addElement(odf.style.TableColumnProperties(
 
            columnwidth='1.25in',
 
        ))
 

	
 
    def section_key(self, row: AccrualPostings) -> Optional[data.Account]:
 
        if isinstance(row.account, str):
 
            return row.account
 
        else:
 
            return None
 

	
 
    def start_spreadsheet(self) -> None:
 
        for accrual_type in AccrualAccount:
 
            self.use_sheet(accrual_type.name.title())
 
            for index in range(self.COL_COUNT):
 
                stylename = self.style_widecol if index else ''
 
                self.sheet.addElement(odf.table.TableColumn(stylename=stylename))
 
            self.add_row(*(
 
                self.string_cell(name, stylename=self.style_bold)
 
                for name in self.COLUMNS
 
            ))
 
            self.lock_first_row()
 

	
 
    def start_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
 
            return
 
        self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds)
 
        self.age_balances = [core.MutableBalance() for _ in self.age_thresholds]
 
        accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1])
 
        acct_parts = key.slice_parts()
 
        self.use_sheet(acct_parts[1])
 
        self.add_row()
 
        self.add_row(self.string_cell(
 
            f"{' '.join(acct_parts[2:])} {acct_parts[1]} Aging Report"
 
            f" Accrued by {accrual_date.isoformat()} Unpaid by {self.date.isoformat()}",
 
            stylename=self.merge_styles(self.style_bold, self.style_centertext),
 
            numbercolumnsspanned=self.COL_COUNT,
 
        ))
 
        self.add_row()
 

	
 
    def end_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
 
            return
 
        total_balance = core.MutableBalance()
 
        text_style = self.merge_styles(self.style_bold, self.style_endtext)
 
        text_span = 4
 
        last_age_text: Optional[str] = None
 
        self.add_row()
 
        for threshold, balance in zip(self.age_thresholds, self.age_balances):
 
            years, days = divmod(threshold, 365)
 
            years_text = f"{years} {'Year' if years == 1 else 'Years'}"
 
            days_text = f"{days} Days"
 
            if years and days:
 
                age_text = f"{years_text} {days_text}"
 
            elif years:
 
                age_text = years_text
 
            else:
 
                age_text = days_text
 
            if last_age_text is None:
 
                age_range = f"Over {age_text}"
 
            else:
 
                age_range = f"{age_text}–{last_age_text}"
 
            self.add_row(
 
                self.string_cell(
 
                    f"Total Aged {age_range}: ",
 
                    stylename=text_style,
 
                    numbercolumnsspanned=text_span,
 
                ),
 
                *(odf.table.TableCell() for _ in range(1, text_span)),
 
                self.balance_cell(balance),
 
            )
 
            last_age_text = age_text
 
            total_balance += balance
 
        self.add_row(
 
            self.string_cell(
 
                "Total Unpaid: ",
 
                stylename=text_style,
 
                numbercolumnsspanned=text_span,
 
            ),
 
            *(odf.table.TableCell() for _ in range(1, text_span)),
 
            self.balance_cell(total_balance),
 
        )
 

	
 
    def _link_seq(self, row: AccrualPostings, key: MetaKey) -> Iterator[Tuple[str, str]]:
 
        for href in row.all_meta_links(key):
 
            text: Optional[str] = None
 
            rt_ids = self.rt_wrapper.parse(href)
 
            if rt_ids is not None:
 
                ticket_id, attachment_id = rt_ids
 
                if attachment_id is None:
 
                    text = f'RT#{ticket_id}'
 
                href = self.rt_wrapper.url(ticket_id, attachment_id) or href
 
            else:
 
                # '..' pops the ODS filename off the link path. In other words,
 
                # make the link relative to the directory the ODS is in.
 
                href = f'../{href}'
 
            if text is None:
 
                href_path = Path(urlparse.urlparse(href).path)
 
                text = urlparse.unquote(href_path.name)
 
            yield (href, text)
 

	
 
    def write_row(self, row: AccrualPostings) -> None:
 
        age = (self.date - row[0].meta.date).days
 
        if row.end_balance.ge_zero():
 
            for index, threshold in enumerate(self.age_thresholds):
 
                if age >= threshold:
 
                    self.age_balances[index] += row.end_balance
 
                    break
 
            else:
 
                return
 
        raw_balance = row.balance()
 
        if row.accrual_type is not None:
 
            raw_balance = row.accrual_type.normalize_amount(raw_balance)
 
        if raw_balance == row.end_balance:
 
            amount_cell = odf.table.TableCell()
 
        else:
 
            amount_cell = self.balance_cell(raw_balance)
 
        projects = {post.meta.get('project') or None for post in row}
 
        projects.discard(None)
 
        self.add_row(
 
            self.date_cell(row[0].meta.date),
 
            self.multiline_cell(row.entities()),
 
            amount_cell,
 
            self.balance_cell(row.end_balance),
 
            self.multiline_cell(sorted(projects)),
 
            self.multilink_cell(self._link_seq(row, 'rt-id')),
 
            self.multilink_cell(self._link_seq(row, 'invoice')),
 
            self.multilink_cell(self._link_seq(row, 'approval')),
 
            self.multilink_cell(self._link_seq(row, 'contract')),
 
            self.multilink_cell(self._link_seq(row, 'purchase-order')),
 
        )
 

	
 

	
 
class AgingReport(BaseReport):
 
    def __init__(self,
 
                 rt_client: rt.Rt,
 
                 out_file: BinaryIO,
 
                 date: Optional[datetime.date]=None,
 
    ) -> None:
 
        if date is None:
 
            date = datetime.date.today()
 
        self.out_bin = out_file
 
        self.logger = logger.getChild(type(self).__name__)
 
        self.ods = AgingODS(rt_client, date, self.logger)
 

	
 
    def run(self, groups: PostGroups) -> None:
 
        rows: List[AccrualPostings] = []
 
        for group in groups.values():
 
            if group.is_zero():
 
                # Cheap optimization: don't slice and dice groups we're not
 
                # going to report anyway.
 
                continue
 
            elif group.accrual_type is None:
 
                group = group.since_last_nonzero()
 
            else:
 
                # Filter out new accruals after the report date.
 
                # e.g., cover the case that the same invoices has multiple
 
                # postings over time, and we don't want to report too-recent
 
                # ones.
 
                cutoff_date = self.ods.date - datetime.timedelta(
 
                    days=group.accrual_type.value.aging_thresholds[-1],
 
                )
 
                group = AccrualPostings(
 
                    post for post in group.since_last_nonzero()
 
                    if post.meta.date <= cutoff_date
 
                    or group.accrual_type.normalize_amount(post.units.number) < 0
 
                )
 
            if group and not group.is_zero():
 
                rows.append(group)
 
        rows.sort(key=lambda related: (
 
            related.account,
 
            related[0].meta.date,
 
            ('\0'.join(related.entities())
 
             if related.entity is related.INCONSISTENT
 
             else related.entity),
 
        ))
 
        self.ods.write(rows)
 
        self.ods.save_file(self.out_bin)
 

	
 

	
 
class BalanceReport(BaseReport):
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        date_s = posts[0].meta.date.strftime('%Y-%m-%d')
 
        if index:
 
            yield ""
 
        yield f"{posts.invoice}:"
 
        yield f"  {posts.balance_at_cost()} outstanding since {date_s}"
 

	
 

	
 
class OutgoingReport(BaseReport):
 
    def __init__(self, rt_client: rt.Rt, out_file: TextIO) -> None:
 
        super().__init__(out_file)
 
        self.rt_client = rt_client
 
        self.rt_wrapper = rtutil.RT(rt_client)
 

	
 
    def _primary_rt_id(self, posts: AccrualPostings) -> rtutil.TicketAttachmentIds:
 
        rt_ids = {url for url in posts.first_links('rt-id') if url is not None}
 
        rt_ids_count = len(rt_ids)
 
        if rt_ids_count != 1:
 
            raise ValueError(f"{rt_ids_count} rt-id links found")
 
        parsed = rtutil.RT.parse(rt_ids.pop())
 
        if parsed is None:
 
            raise ValueError("rt-id is not a valid RT reference")
 
        else:
 
            return parsed
 

	
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        try:
 
            ticket_id, _ = self._primary_rt_id(posts)
 
            ticket = self.rt_client.get_ticket(ticket_id)
 
            # Note we only use this when ticket is None.
 
            errmsg = f"ticket {ticket_id} not found"
 
        except (ValueError, rt.RtError) as error:
 
            ticket = None
 
            errmsg = error.args[0]
 
        if ticket is None:
 
            self.logger.error(
 
                "can't generate outgoings report for %s because no RT ticket available: %s",
 
                posts.invoice, errmsg,
 
            )
 
            return
 

	
 
        try:
 
            rt_requestor = self.rt_client.get_user(ticket['Requestors'][0])
 
        except (IndexError, rt.RtError):
 
            rt_requestor = None
 
        if rt_requestor is None:
 
            requestor = ''
 
            requestor_name = ''
 
        else:
 
            requestor_name = (
 
                rt_requestor.get('RealName')
 
                or ticket.get('CF.{payment-to}')
 
                or ''
 
            )
 
            requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip()
 

	
 
        balance_s = posts.end_balance.format(None)
 
        raw_balance = -posts.balance()
 
        if raw_balance != posts.end_balance:
 
            balance_s = f'{raw_balance} ({balance_s})'
 

	
 
        contract_links = posts.all_meta_links('contract')
 
        contract_links = list(posts.all_meta_links('contract'))
 
        if contract_links:
 
            contract_s = ' , '.join(self.rt_wrapper.iter_urls(
 
                contract_links, missing_fmt='<BROKEN RT LINK: {}>',
 
            ))
 
        else:
 
            contract_s = "NO CONTRACT GOVERNS THIS TRANSACTION"
 
        projects = [v for v in posts.meta_values('project')
 
                    if isinstance(v, str)]
 

	
 
        yield "PAYMENT FOR APPROVAL:"
 
        yield f"REQUESTOR: {requestor}"
 
        yield f"TOTAL TO PAY: {balance_s}"
 
        yield f"AGREEMENT: {contract_s}"
 
        yield f"PAYMENT TO: {ticket.get('CF.{payment-to}') or requestor_name}"
 
        yield f"PAYMENT METHOD: {ticket.get('CF.{payment-method}', '')}"
 
        yield f"PROJECT: {', '.join(projects)}"
 
        yield "\nBEANCOUNT ENTRIES:\n"
 

	
 
        last_txn: Optional[Transaction] = None
 
        for post in posts:
 
            txn = post.meta.txn
 
            if txn is not last_txn:
 
                last_txn = txn
 
                txn = self.rt_wrapper.txn_with_urls(txn, '{}')
 
                yield bc_printer.format_entry(txn)
 

	
 

	
 
class ReportType(enum.Enum):
 
    AGING = AgingReport
 
    BALANCE = BalanceReport
 
    OUTGOING = OutgoingReport
 
    AGE = AGING
 
    BAL = BALANCE
 
    OUT = OUTGOING
 
    OUTGOINGS = OUTGOING
 

	
 
    @classmethod
 
    def by_name(cls, name: str) -> 'ReportType':
 
        try:
 
            return cls[name.upper()]
 
        except KeyError:
 
            raise ValueError(f"unknown report type {name!r}") from None
 

	
 
    @classmethod
 
    def default_for(cls, groups: PostGroups) -> 'ReportType':
 
        if len(groups) == 1 and all(
 
                group.accrual_type is AccrualAccount.PAYABLE
 
                and not group.is_paid()
 
                for group in groups.values()
 
        ):
 
            return cls.OUTGOING
 
        else:
 
            return cls.BALANCE
 

	
 

	
 
class ReturnFlag(enum.IntFlag):
 
    LOAD_ERRORS = 1
 
    REPORT_ERRORS = 4
 
    NOTHING_TO_REPORT = 8
 

	
 

	
 
def filter_search(postings: Iterable[data.Posting],
 
                  search_terms: Iterable[cliutil.SearchTerm],
 
) -> Iterable[data.Posting]:
 
    accounts = tuple(AccrualAccount.account_names())
 
    postings = (post for post in postings if post.account.is_under(*accounts))
 
    for query in search_terms:
 
        postings = query.filter_postings(postings)
 
    return postings
 

	
 
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 
    parser = argparse.ArgumentParser(prog=PROGNAME)
 
    cliutil.add_version_argument(parser)
 
    parser.add_argument(
 
        '--report-type', '-t',
 
        metavar='NAME',
 
        type=ReportType.by_name,
 
        help="""The type of report to generate, one of `aging`, `balance`, or
 
`outgoing`. If not specified, the default is `aging` when no search terms are
 
given, `outgoing` for search terms that return a single outstanding payable,
 
and `balance` any other time.
 
""")
 
    parser.add_argument(
 
        '--since',
 
        metavar='YEAR',
 
        type=int,
 
        default=-1,
 
        help="""How far back to search the books for related transactions.
 
You can either specify a fiscal year, or a negative offset from the current
 
fiscal year, to start loading entries from. The default is -1 (start from the
 
previous fiscal year).
 
""")
 
    parser.add_argument(
 
        '--output-file', '-O',
 
        metavar='PATH',
 
        type=Path,
 
        help="""Write the report to this file, or stdout when PATH is `-`.
 
The default is stdout for the balance and outgoing reports, and a generated
 
filename for other reports.
 
""")
 
    cliutil.add_loglevel_argument(parser)
 
    parser.add_argument(
 
        'search_terms',
 
        metavar='FILTER',
 
        type=cliutil.SearchTerm.arg_parser('invoice', 'rt-id'),
 
        nargs=argparse.ZERO_OR_MORE,
 
        help="""Report on accruals that match this criteria. The format is
 
NAME=TERM. TERM is a link or word that must exist in a posting's NAME
 
metadata to match. A single ticket number is a shortcut for
 
`rt-id=rt:NUMBER`. Any other link, including an RT attachment link in
 
`TIK/ATT` format, is a shortcut for `invoice=LINK`.
 
""")
 
    args = parser.parse_args(arglist)
 
    if args.report_type is None and not args.search_terms:
 
        args.report_type = ReportType.AGING
 
    return args
 

	
 
def main(arglist: Optional[Sequence[str]]=None,
 
         stdout: TextIO=sys.stdout,
 
         stderr: TextIO=sys.stderr,
 
         config: Optional[configmod.Config]=None,
 
) -> int:
 
    if cliutil.is_main_script(PROGNAME):
 
        global logger
 
        logger = logging.getLogger(PROGNAME)
 
        sys.excepthook = cliutil.ExceptHook(logger)
 
    args = parse_arguments(arglist)
 
    cliutil.setup_logger(logger, args.loglevel, stderr)
 
    if config is None:
 
        config = configmod.Config()
 
        config.load_file()
 

	
 
    books_loader = config.books_loader()
 
    if books_loader is None:
 
        entries, load_errors, _ = books.Loader.load_none(config.config_file_path())
 
    elif args.report_type is ReportType.AGING:
 
        entries, load_errors, _ = books_loader.load_all()
 
    else:
 
        entries, load_errors, _ = books_loader.load_all(args.since)
 
    filters.remove_opening_balance_txn(entries)
 

	
 
    returncode = 0
 
    postings = filter_search(data.Posting.from_entries(entries), args.search_terms)
 
    groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice'))
 
    for error in load_errors:
 
        bc_printer.print_error(error, file=stderr)
 
        returncode |= ReturnFlag.LOAD_ERRORS
 
    if not groups:
 
        logger.warning("no matching entries found to report")
 
        returncode |= ReturnFlag.NOTHING_TO_REPORT
 

	
 
    groups = {
 
        key: posts
 
        for source_posts in groups.values()
 
        for key, posts in source_posts.make_consistent()
 
    }
 
    if args.report_type is not ReportType.AGING:
 
        groups = {
 
            key: posts for key, posts in groups.items() if not posts.is_paid()
 
        } or groups
 

	
 
    if args.report_type is None:
 
        args.report_type = ReportType.default_for(groups)
 
    report: Optional[BaseReport] = None
 
    output_path: Optional[Path] = None
 
    if args.report_type is ReportType.AGING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate aging report: RT client is required")
 
        else:
 
            now = datetime.datetime.now()
 
            if args.output_file is None:
 
                args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
 
                logger.info("Writing report to %s", args.output_file)
 
            out_bin = cliutil.bytes_output(args.output_file, stdout)
 
            report = AgingReport(rt_client, out_bin)
 
    elif args.report_type is ReportType.OUTGOING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate outgoing report: RT client is required")
 
        else:
 
            out_file = cliutil.text_output(args.output_file, stdout)
 
            report = OutgoingReport(rt_client, out_file)
 
    else:
 
        out_file = cliutil.text_output(args.output_file, stdout)
 
        report = args.report_type.value(out_file)
 

	
 
    if report is None:
 
        returncode |= ReturnFlag.REPORT_ERRORS
 
    else:
 
        report.run(groups)
 
    return 0 if returncode == 0 else 16 + returncode
 

	
 
if __name__ == '__main__':
 
    exit(main())
conservancy_beancount/reports/core.py
Show inline comments
 
"""core.py - Common data classes for reporting functionality"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import abc
 
import collections
 
import datetime
 
import itertools
 
import operator
 
import re
 

	
 
import babel.core  # type:ignore[import]
 
import babel.numbers  # type:ignore[import]
 

	
 
import odf.config  # type:ignore[import]
 
import odf.element  # type:ignore[import]
 
import odf.number  # type:ignore[import]
 
import odf.opendocument  # type:ignore[import]
 
import odf.style  # type:ignore[import]
 
import odf.table  # type:ignore[import]
 
import odf.text  # type:ignore[import]
 

	
 
from decimal import Decimal
 
from pathlib import Path
 

	
 
from beancount.core import amount as bc_amount
 

	
 
from .. import data
 
from .. import filters
 

	
 
from typing import (
 
    cast,
 
    overload,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    DefaultDict,
 
    Dict,
 
    Generic,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Mapping,
 
    MutableMapping,
 
    Optional,
 
    Sequence,
 
    Set,
 
    Tuple,
 
    Type,
 
    TypeVar,
 
    Union,
 
)
 
from ..beancount_types import (
 
    MetaKey,
 
    MetaValue,
 
)
 

	
 
DecimalCompat = data.DecimalCompat
 
BalanceType = TypeVar('BalanceType', bound='Balance')
 
ElementType = Callable[..., odf.element.Element]
 
LinkType = Union[str, Tuple[str, Optional[str]]]
 
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
 
RT = TypeVar('RT', bound=Sequence)
 
ST = TypeVar('ST')
 
T = TypeVar('T')
 

	
 
class Balance(Mapping[str, data.Amount]):
 
    """A collection of amounts mapped by currency
 

	
 
    Each key is a Beancount currency string, and each value represents the
 
    balance in that currency.
 
    """
 
    __slots__ = ('_currency_map', 'tolerance')
 
    TOLERANCE = Decimal('0.01')
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Amount]=(),
 
                 tolerance: Optional[Decimal]=None,
 
    ) -> None:
 
        if tolerance is None:
 
            tolerance = self.TOLERANCE
 
        self.tolerance = tolerance
 
        self._currency_map: Dict[str, data.Amount] = {}
 
        for amount in source:
 
            self._add_amount(self._currency_map, amount)
 

	
 
    def _add_amount(self,
 
                    currency_map: MutableMapping[str, data.Amount],
 
                    amount: data.Amount,
 
    ) -> None:
 
        code = amount.currency
 
        try:
 
            current_number = currency_map[code].number
 
        except KeyError:
 
            current_number = Decimal(0)
 
        currency_map[code] = data.Amount(current_number + amount.number, code)
 

	
 
    def _add_other(self,
 
                   currency_map: MutableMapping[str, data.Amount],
 
                   other: Union[data.Amount, 'Balance'],
 
    ) -> None:
 
        if isinstance(other, Balance):
 
            for amount in other.values():
 
                self._add_amount(currency_map, amount)
 
        else:
 
            self._add_amount(currency_map, other)
 

	
 
    def __repr__(self) -> str:
 
        values = [repr(amt) for amt in self.values()]
 
        return f"{type(self).__name__}({values!r})"
 

	
 
    def __str__(self) -> str:
 
        return self.format()
 

	
 
    def __abs__(self: BalanceType) -> BalanceType:
 
        return type(self)(bc_amount.abs(amt) for amt in self.values())
 

	
 
    def __add__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
 
        retval_map = self._currency_map.copy()
 
        self._add_other(retval_map, other)
 
        return type(self)(retval_map.values())
 

	
 
    def __eq__(self, other: Any) -> bool:
 
        if isinstance(other, Balance):
 
            clean_self = self.clean_copy()
 
            clean_other = other.clean_copy()
 
            return len(clean_self) == len(clean_other) and all(
 
                clean_self[key] == clean_other.get(key) for key in clean_self
 
            )
 
        else:
 
            return super().__eq__(other)
 

	
 
    def __neg__(self: BalanceType) -> BalanceType:
 
        return type(self)(-amt for amt in self.values())
 

	
 
    def __pos__(self: BalanceType) -> BalanceType:
 
        return self
 

	
 
    def __getitem__(self, key: str) -> data.Amount:
 
        return self._currency_map[key]
 

	
 
    def __iter__(self) -> Iterator[str]:
 
        return iter(self._currency_map)
 

	
 
    def __len__(self) -> int:
 
        return len(self._currency_map)
 

	
 
    def _all_amounts(self,
 
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
 
                     operand: DecimalCompat,
 
    ) -> bool:
 
        return all(op_func(amt.number, operand) for amt in self.values())
 

	
 
    def copy(self: BalanceType) -> BalanceType:
 
        return type(self)(self.values())
 

	
 
    def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType:
 
        if tolerance is None:
 
            tolerance = self.tolerance
 
        return type(self)(
 
            amount for amount in self.values()
 
            if abs(amount.number) >= tolerance
 
        )
 

	
 
    @staticmethod
 
    def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
 
        dec = cast(Decimal, dec)
 
        return abs(dec) < tolerance
 

	
 
    def eq_zero(self) -> bool:
 
        """Returns true if all amounts in the balance == 0, within tolerance."""
 
        return self._all_amounts(self.within_tolerance, self.tolerance)
 

	
 
    is_zero = eq_zero
 

	
 
    def ge_zero(self) -> bool:
 
        """Returns true if all amounts in the balance >= 0, within tolerance."""
 
        op_func = operator.gt if self.tolerance else operator.ge
 
        return self._all_amounts(op_func, -self.tolerance)
 

	
 
    def le_zero(self) -> bool:
 
        """Returns true if all amounts in the balance <= 0, within tolerance."""
 
        op_func = operator.lt if self.tolerance else operator.le
 
        return self._all_amounts(op_func, self.tolerance)
 

	
 
    def format(self,
 
               fmt: Optional[str]='#,#00.00 ¤¤',
 
               sep: str=', ',
 
               empty: str="Zero balance",
 
               tolerance: Optional[Decimal]=None,
 
    ) -> str:
 
        """Formats the balance as a string with the given parameters
 

	
 
        If the balance is zero (within tolerance), returns ``empty``.
 
        Otherwise, returns a string with each amount in the balance formatted
 
        as ``fmt``, separated by ``sep``.
 

	
 
        If you set ``fmt`` to None, amounts will be formatted according to the
 
        user's locale. The default format is Beancount's input format.
 
        """
 
        amounts = list(self.clean_copy(tolerance).values())
 
        if not amounts:
 
            return empty
 
        amounts.sort(key=lambda amt: abs(amt.number), reverse=True)
 
        return sep.join(
 
            babel.numbers.format_currency(amt.number, amt.currency, fmt)
 
            for amt in amounts
 
        )
 

	
 

	
 
class MutableBalance(Balance):
 
    __slots__ = ()
 

	
 
    def __iadd__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType:
 
        self._add_other(self._currency_map, other)
 
        return self
 

	
 

	
 
class RelatedPostings(Sequence[data.Posting]):
 
    """Collect and query related postings
 

	
 
    This class provides common functionality for collecting related postings
 
    and running queries on them: iterating over them, tallying their balance,
 
    etc.
 

	
 
    This class doesn't know anything about how the postings are related. That's
 
    entirely up to the caller.
 

	
 
    A common pattern is to use this class with collections.defaultdict
 
    to organize postings based on some key. See the group_by_meta classmethod
 
    for an example.
 
    """
 
    __slots__ = ('_postings',)
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        self._postings: List[data.Posting]
 
        if _can_own and isinstance(source, list):
 
            self._postings = source
 
        else:
 
            self._postings = list(source)
 

	
 
    @classmethod
 
    def group_by_meta(cls: Type[RelatedType],
 
                      postings: Iterable[data.Posting],
 
                      key: MetaKey,
 
                      default: Optional[MetaValue]=None,
 
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
 
        """Relate postings by metadata value
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of post.meta.get(key, default).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same metadata value.
 
        """
 
        mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list)
 
        for post in postings:
 
            mapping[post.meta.get(key, default)].append(post)
 
        for value, posts in mapping.items():
 
            yield value, cls(posts, _can_own=True)
 

	
 
    def __repr__(self) -> str:
 
        return f'<{type(self).__name__} {self._postings!r}>'
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...
 

	
 
    def __getitem__(self: RelatedType,
 
                    index: Union[int, slice],
 
    ) -> Union[data.Posting, RelatedType]:
 
        if isinstance(index, slice):
 
            return type(self)(self._postings[index], _can_own=True)
 
        else:
 
            return self._postings[index]
 

	
 
    def __len__(self) -> int:
 
        return len(self._postings)
 

	
 
    def all_meta_links(self, key: MetaKey) -> Set[str]:
 
        retval: Set[str] = set()
 
    def _all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        for post in self:
 
            try:
 
                retval.update(post.meta.get_links(key))
 
                yield from post.meta.get_links(key)
 
            except TypeError:
 
                pass
 
        return retval
 

	
 
    def all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        return filters.iter_unique(self._all_meta_links(key))
 

	
 
    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
 
        balance = MutableBalance()
 
        for post in self:
 
            balance += post.units
 
            yield post, balance
 

	
 
    def balance(self) -> Balance:
 
        for _, balance in self.iter_with_balance():
 
            pass
 
        try:
 
            return balance
 
        except NameError:
 
            return Balance()
 

	
 
    def balance_at_cost(self) -> Balance:
 
        balance = MutableBalance()
 
        for post in self:
 
            if post.cost is None:
 
                balance += post.units
 
            else:
 
                number = post.units.number * post.cost.number
 
                balance += data.Amount(number, post.cost.currency)
 
        return balance
 

	
 
    def meta_values(self,
 
                    key: MetaKey,
 
                    default: Optional[MetaValue]=None,
 
    ) -> Set[Optional[MetaValue]]:
 
        return {post.meta.get(key, default) for post in self}
 

	
 

	
 
class BaseSpreadsheet(Generic[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to arbitrary output. It calls hooks when it starts writing the
 
    spreadsheet, starts a new "section" of rows, ends a section, and ends the
 
    spreadsheet.
 

	
 
    RT is the type of the input data rows. ST is the type of the section
 
    identifier that you create from each row. If you don't want to use the
 
    section logic at all, set ST to None and define section_key to return None.
 
    """
 

	
 
    @abc.abstractmethod
 
    def section_key(self, row: RT) -> ST:
 
        """Return the section a row belongs to
 

	
 
        Given a data row, this method should return some identifier for the
 
        "section" the row belongs to. The write method uses this to
 
        determine when to call start_section and end_section.
 

	
 
        If your spreadsheet doesn't need sections, define this to return None.
 
        """
 
        ...
 

	
 
    @abc.abstractmethod
 
    def write_row(self, row: RT) -> None:
 
        """Write a data row to the output spreadsheet
 

	
 
        This method is called once for each data row in the input.
 
        """
 
        ...
 

	
 
    # The next four methods are all called by the write method when the name
 
    # says. You may override them to output headers or sums, record
 
    # state, etc. The default implementations are all noops.
 

	
 
    def start_spreadsheet(self) -> None:
 
        pass
 

	
 
    def start_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_spreadsheet(self) -> None:
 
        pass
 

	
 
    def write(self, rows: Iterable[RT]) -> None:
 
        prev_section: Optional[ST] = None
 
        self.start_spreadsheet()
 
        for row in rows:
 
            section = self.section_key(row)
 
            if section != prev_section:
 
                if prev_section is not None:
 
                    self.end_section(prev_section)
 
                self.start_section(section)
 
                prev_section = section
 
            self.write_row(row)
 
        try:
 
            should_end = section is not None
 
        except NameError:
 
            should_end = False
 
        if should_end:
 
            self.end_section(section)
 
        self.end_spreadsheet()
 

	
 

	
 
class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write OpenDocument spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to an OpenDocument spreadsheet. It provides helper methods for
 
    building sheets, rows, and cells.
 

	
 
    See also the BaseSpreadsheet base class for additional documentation about
 
    methods you must and can define, the definition of RT and ST, etc.
 
    """
 
    def __init__(self) -> None:
 
        self.locale = babel.core.Locale.default('LC_MONETARY')
 
        self.currency_fmt_key = 'accounting'
 
        self._name_counter = itertools.count(1)
 
        self._currency_style_cache: MutableMapping[str, odf.style.Style] = {}
 
        self.document = odf.opendocument.OpenDocumentSpreadsheet()
 
        self.init_settings()
 
        self.init_styles()
 
        self.sheet = self.use_sheet("Report")
 

	
 
    ### Low-level document tree manipulation
 
    # The *intent* is that you only need to use these if you're adding new
 
    # methods to manipulate document settings or styles.
 

	
 
    def copy_element(self, elem: odf.element.Element) -> odf.element.Element:
 
        qattrs = dict(self.iter_qattributes(elem))
 
        retval = odf.element.Element(qname=elem.qname, qattributes=qattrs)
 
        try:
 
            orig_name = retval.getAttribute('name')
 
        except ValueError:
 
            orig_name = None
 
        if orig_name is not None:
 
            retval.setAttribute('name', f'{orig_name}{next(self._name_counter)}')
 
        return retval
 

	
 
    def ensure_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        if found_child is None:
 
            parent.addElement(new_child)
 
            return parent.lastChild
 
        else:
 
            return found_child
 

	
 
    def ensure_config_map_entry(self,
 
                                root: odf.element.Element,
 
                                map_name: str,
 
                                entry_name: str,
 
    ) -> odf.element.Element:
 
        """Return a ``ConfigItemMapEntry`` under ``root``
 

	
 
        This method ensures there's a ``ConfigItemMapNamed`` named ``map_name``
 
        under ``root``, and a ``ConfigItemMapEntry`` named ``entry_name`` under
 
        that. Return the ``ConfigItemMapEntry`` element.
 
        """
 
        config_map = self.ensure_child(root, odf.config.ConfigItemMapNamed, name=map_name)
 
        return self.ensure_child(config_map, odf.config.ConfigItemMapEntry, name=entry_name)
 

	
 
    def find_child(self,
 
                   parent: odf.element.Element,
 
                   child: odf.element.Element,
 
    ) -> Optional[odf.element.Element]:
 
        attrs = {k: v for k, v in self.iter_attributes(child)}
 
        if not attrs:
 
            return None
 
        for elem in parent.childNodes:
 
            if (elem.qname == child.qname
 
                and all(elem.getAttribute(k) == v for k, v in attrs.items())):
 
                return elem
 
        return None
 

	
 
    def iter_attributes(self, elem: odf.element.Element) -> Iterator[Tuple[str, str]]:
 
        for (_, key), value in self.iter_qattributes(elem):
 
            yield key.lower().replace('-', ''), value
 

	
 
    def iter_qattributes(self, elem: odf.element.Element) -> Iterator[Tuple[Tuple[str, str], str]]:
 
        if elem.attributes:
 
            yield from elem.attributes.items()
 

	
 
    def replace_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        parent.insertBefore(new_child, found_child)
 
        if found_child is not None:
 
            parent.removeChild(found_child)
 
        return new_child
 

	
 
    def set_config(self,
 
                   root: odf.element.Element,
 
                   name: str,
 
                   value: Union[bool, int, str],
 
                   config_type: Optional[str]=None,
 
    ) -> None:
 
        """Ensure ``root`` has a ``ConfigItem`` with the given name, type, and value"""
 
        value_s = str(value)
 
        if isinstance(value, bool):
 
            value_s = str(value).lower()
 
            default_type = 'boolean'
 
        elif isinstance(value, str):
 
            default_type = 'string'
 
        if config_type is None:
 
            try:
 
                config_type = default_type
 
            except NameError:
 
                raise ValueError(
 
                    f"need config_type for {type(value).__name__} value",
 
                ) from None
 
        item = self.replace_child(
 
            root, odf.config.ConfigItem, name=name, type=config_type,
 
        )
 
        item.addText(value_s)
 

	
 
    ### Styles
 

	
 
    def _build_currency_style(
 
            self,
 
            root: odf.element.Element,
 
            locale: babel.core.Locale,
 
            code: str,
 
            fmt_index: int,
 
            properties: Optional[odf.style.TextProperties]=None,
 
            *,
 
            fmt_key: Optional[str]=None,
 
            volatile: bool=False,
 
            minintegerdigits: int=1,
 
    ) -> odf.element.Element:
 
        if fmt_key is None:
 
            fmt_key = self.currency_fmt_key
 
        pattern = locale.currency_formats[fmt_key]
 
        fmts = pattern.pattern.split(';')
 
        try:
 
            fmt = fmts[fmt_index]
 
        except IndexError:
 
            fmt = fmts[0]
 
            grouping = pattern.grouping[0]
 
        else:
 
            grouping = pattern.grouping[fmt_index]
 
        zero_s = babel.numbers.format_currency(0, code, '##0.0', locale)
 
        try:
 
            decimal_index = zero_s.rindex('.') + 1
 
        except ValueError:
 
            decimalplaces = 0
 
        else:
 
            decimalplaces = len(zero_s) - decimal_index
 
        style = self.replace_child(
 
            root,
 
            odf.number.CurrencyStyle,
 
            name=f'{code}{next(self._name_counter)}',
 
        )
 
        style.setAttribute('volatile', 'true' if volatile else 'false')
 
        if properties is not None:
 
            style.addElement(properties)
 
        for part in re.split(r"(¤+|[#0,.]+|'[^']+')", fmt):
 
            if not part:
 
                pass
 
            elif not part.strip('#0,.'):
 
                style.addElement(odf.number.Number(
 
                    decimalplaces=str(decimalplaces),
 
                    grouping='true' if grouping else 'false',
 
                    minintegerdigits=str(minintegerdigits),
 
                ))
 
            elif part == '¤':
 
                style.addElement(odf.number.CurrencySymbol(
 
                    country=locale.territory,
 
                    language=locale.language,
 
                    text=babel.numbers.get_currency_symbol(code, locale),
 
                ))
 
            elif part == '¤¤':
 
                style.addElement(odf.number.Text(text=code))
 
            else:
 
                style.addElement(odf.number.Text(text=part.strip("'")))
 
        return style
 

	
 
    def currency_style(
 
            self,
 
            code: str,
 
            locale: Optional[babel.core.Locale]=None,
 
            negative_properties: Optional[odf.style.TextProperties]=None,
 
            positive_properties: Optional[odf.style.TextProperties]=None,
 
            root: odf.element.Element=None,
 
    ) -> odf.style.Style:
 
        """Create and return a spreadsheet style to format currency data
 

	
 
        Given a currency code and a locale, this method will create all the
 
        styles necessary to format the currency according to the locale's
 
        rules, including rendering of decimal points and negative values.
 

	
 
        You may optionally pass in TextProperties to use for negative and
 
        positive amounts, respectively. If you don't, negative values will
 
        automatically be rendered in red (text color #f00).
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
 
        """
 
        if locale is None:
 
            locale = self.locale
 
        if negative_properties is None:
 
            negative_properties = odf.style.TextProperties(color='#ff0000')
 
        if root is None:
 
            root = self.document.styles
 
        cache_parts = [str(id(root)), code, str(locale)]
 
        for key, value in self.iter_attributes(negative_properties):
 
            cache_parts.append(f'{key}={value}')
 
        if positive_properties is not None:
 
            cache_parts.append('')
 
            for key, value in self.iter_attributes(positive_properties):
 
                cache_parts.append(f'{key}={value}')
 
        cache_key = '\0'.join(cache_parts)
 
        try:
 
            style = self._currency_style_cache[cache_key]
 
        except KeyError:
 
            pos_style = self._build_currency_style(
 
                root, locale, code, 0, positive_properties, volatile=True,
 
            )
 
            curr_style = self._build_currency_style(
 
                root, locale, code, 1, negative_properties,
 
            )
 
            curr_style.addElement(odf.style.Map(
 
                condition='value()>=0', applystylename=pos_style,
 
            ))
 
            style = self.ensure_child(
 
                self.document.styles,
 
                odf.style.Style,
 
                name=f'{curr_style.getAttribute("name")}Cell',
 
                family='table-cell',
 
                datastylename=curr_style,
 
            )
 
            self._currency_style_cache[cache_key] = style
 
        return style
 

	
 
    def _merge_style_iter_names(
 
            self,
 
            styles: Sequence[Union[str, odf.style.Style, None]],
 
    ) -> Iterator[str]:
 
        for source in styles:
 
            if source is None:
 
                continue
 
            elif not isinstance(source, str):
 
                source = source.getAttribute('name')
 
            if source.startswith('Merge_'):
 
                orig_names = iter(source.split('_'))
 
                next(orig_names)
 
                yield from orig_names
 
            else:
 
                yield source
 

	
 
    def _merge_styles(self,
 
                      new_style: odf.style.Style,
 
                      sources: Iterable[odf.style.Style],
 
    ) -> None:
 
        for elem in sources:
 
            for key, new_value in self.iter_attributes(elem):
 
                old_value = new_style.getAttribute(key)
 
                if (key == 'name'
 
                    or key == 'displayname'
 
                    or old_value == new_value):
 
                    pass
 
                elif old_value is None:
 
                    new_style.setAttribute(key, new_value)
 
                else:
 
                    raise ValueError(f"cannot merge styles with conflicting {key}")
 
            for child in elem.childNodes:
 
                new_style.addElement(self.copy_element(child))
 

	
 
    def merge_styles(self,
 
                     *styles: Union[str, odf.style.Style, None],
 
    ) -> Optional[odf.style.Style]:
 
        """Create a new style from multiple existing styles
 

	
 
        Given any number of existing styles, create a new style that combines
 
        all of those styles' attributes and properties, add it to the document
 
        styles, and return it.
 

	
 
        Styles can be specified by name, or by passing in their Style element.
 
        For convenience, you can also pass in None as an argument; None will
 
        simply be skipped.
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
 

	
 
        If you pass in zero real style arguments, returns None.
 
        If you pass in one style argument, returns that style unchanged.
 
        If you pass in a style that doesn't already exist in the document,
 
        or if you pass in styles that can't be merged (because they have
 
        conflicting attributes), raises ValueError.
 
        """
 
        name_map: Dict[str, odf.style.Style] = {}
 
        for name in self._merge_style_iter_names(styles):
 
            source = odf.style.Style(name=name)
 
            found = self.find_child(self.document.styles, source)
 
            if found is None:
 
                raise ValueError(f"no style named {name!r}")
 
            name_map[name] = found
 
        if not name_map:
 
            retval = None
 
        elif len(name_map) == 1:
 
            _, retval = name_map.popitem()
 
        else:
 
            new_name = f'Merge_{"_".join(sorted(name_map))}'
 
            retval = self.ensure_child(
 
                self.document.styles, odf.style.Style, name=new_name,
 
            )
 
            if retval.firstChild is None:
 
                self._merge_styles(retval, name_map.values())
 
        return retval
 

	
 
    ### Sheets
 

	
 
    def lock_first_row(self, sheet: Optional[odf.table.Table]=None) -> None:
 
        """Lock the first row of cells under the given sheet
 

	
 
        This method sets all the appropriate settings to "lock" the first row
 
        of cells in a sheet, so it stays in view even as the viewer scrolls
 
        through rows. If a sheet is not given, works on ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        config_map = self.ensure_config_map_entry(
 
            self.view, 'Tables', sheet.getAttribute('name'),
 
        )
 
        self.set_config(config_map, 'PositionBottom', 1, 'int')
 
        self.set_config(config_map, 'VerticalSplitMode', 2, 'short')
 
        self.set_config(config_map, 'VerticalSplitPosition', 1, 'short')
 

	
 
    def use_sheet(self, name: str) -> odf.table.Table:
 
        """Switch the active sheet ``self.sheet`` to the one with the given name
 

	
 
        If there is no sheet with the given name, create it and append it to
 
        the spreadsheet first.
 

	
 
        If the current active sheet is empty when this method is called, it
 
        will be removed from the spreadsheet.
 
        """
 
        try:
 
            empty_sheet = not self.sheet.hasChildNodes()
 
        except AttributeError:
 
            empty_sheet = False
 
        if empty_sheet:
 
            self.document.spreadsheet.removeChild(self.sheet)
 
        self.sheet = self.ensure_child(
 
            self.document.spreadsheet, odf.table.Table, name=name,
 
        )
 
        return self.sheet
 

	
 
    ### Initialization hooks
 

	
 
    def init_settings(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.settings``. This implementation creates the barest
 
        skeleton structure necessary to support other methods, in particular
 
        ``lock_first_row``.
 
        """
 
        view_settings = self.ensure_child(
 
            self.document.settings, odf.config.ConfigItemSet, name='ooo:view-settings',
 
        )
 
        views = self.ensure_child(
 
            view_settings, odf.config.ConfigItemMapIndexed, name='Views',
 
        )
 
        self.view = self.ensure_child(views, odf.config.ConfigItemMapEntry)
 
        self.set_config(self.view, 'ViewId', 'view1')
 

	
 
    def init_styles(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.styles``. This implementation creates basic building
 
        block cell styles often used in financial reports.
 
        """
 
        styles = self.document.styles
 
        self.style_bold = self.ensure_child(
 
            styles, odf.style.Style, name='Bold', family='table-cell',
 
        )
 
        self.ensure_child(
 
            self.style_bold, odf.style.TextProperties, fontweight='bold',
 
        )
 
        self.style_starttext: odf.style.Style
 
        self.style_centertext: odf.style.Style
 
        self.style_endtext: odf.style.Style
 
        for textalign in ['start', 'center', 'end']:
 
            aligned_style = self.replace_child(
 
                styles, odf.style.Style, name=f'{textalign.title()}Text',
 
            )
 
            aligned_style.setAttribute('family', 'table-cell')
 
            aligned_style.addElement(odf.style.ParagraphProperties(textalign=textalign))
 
            setattr(self, f'style_{textalign}text', aligned_style)
 
        date_style = self.replace_child(styles, odf.number.DateStyle, name='ISODate')
 
        date_style.addElement(odf.number.Year(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Month(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Day(style='long'))
 
        self.style_date = self.ensure_child(
 
            styles,
 
            odf.style.Style,
 
            name=f'{date_style.getAttribute("name")}Cell',
 
            family='table-cell',
 
            datastylename=date_style,
 
        )
 
        self.style_dividerline = self.ensure_child(
 
            styles, odf.style.Style, name='DividerLine', family='table-cell',
 
        )
 
        self.ensure_child(
 
            self.style_dividerline,
 
            odf.style.TableCellProperties,
 
            borderbottom='1pt solid #0000ff',
 
        )
 

	
 
    ### Rows and cells
 

	
 
    def add_row(self, *cells: odf.table.TableCell, **attrs: Any) -> odf.table.TableRow:
 
        row = odf.table.TableRow(**attrs)
 
        for cell in cells:
 
            row.addElement(cell)
 
        self.sheet.addElement(row)
 
        return row
 

	
 
    def balance_cell(self, balance: Balance, **attrs: Any) -> odf.table.TableCell:
 
        if balance.is_zero():
 
            return self.float_cell(0, **attrs)
 
        elif len(balance) == 1:
 
            amount = next(iter(balance.values()))
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.currency_style(amount.currency),
 
            )
 
            return self.currency_cell(amount, **attrs)
 
        else:
 
            lines = [babel.numbers.format_currency(
 
                number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
            ) for number, currency in balance.values()]
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.style_endtext,
 
            )
 
            return self.multiline_cell(lines, **attrs)
 

	
 
    def currency_cell(self, amount: data.Amount, **attrs: Any) -> odf.table.TableCell:
 
        number, currency = amount
 
        cell = odf.table.TableCell(valuetype='currency', value=number, **attrs)
 
        cell.addElement(odf.text.P(text=babel.numbers.format_currency(
 
            number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
        )))
 
        return cell
 

	
 
    def date_cell(self, date: datetime.date, **attrs: Any) -> odf.table.TableCell:
 
        attrs.setdefault('stylename', self.style_date)
 
        cell = odf.table.TableCell(valuetype='date', datevalue=date, **attrs)
 
        cell.addElement(odf.text.P(text=date.isoformat()))
 
        return cell
 

	
 
    def float_cell(self, value: Union[int, float, Decimal], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='float', value=value, **attrs)
 
        cell.addElement(odf.text.P(text=str(value)))
 
        return cell
 

	
 
    def multiline_cell(self, lines: Iterable[Any], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for line in lines:
 
            cell.addElement(odf.text.P(text=str(line)))
 
        return cell
 

	
 
    def multilink_cell(self, links: Iterable[LinkType], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for link in links:
 
            if isinstance(link, tuple):
 
                href, text = link
 
            else:
 
                href = link
 
                text = None
 
            cell.addElement(odf.text.P())
 
            cell.lastChild.addElement(odf.text.A(
 
                type='simple', href=href, text=text,
 
            ))
 
        return cell
 

	
 
    def string_cell(self, text: str, **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        cell.addElement(odf.text.P(text=text))
 
        return cell
 

	
 
    def write_row(self, row: RT) -> None:
 
        """Write a single row of input data to the spreadsheet
 

	
 
        This default implementation adds a single row to the spreadsheet,
 
        with one cell per element of the row. The type of each element
 
        determines what kind of cell is created.
 

	
 
        This implementation will help get you started, but you'll probably
 
        want to override it to specify styles.
 
        """
 
        out_row = odf.table.TableRow()
 
        for cell_source in row:
 
            if isinstance(cell_source, (int, float, Decimal)):
 
                cell = self.float_cell(cell_source)
 
            else:
 
                cell = self.string_cell(cell_source)
 
            out_row.addElement(cell)
 
        self.sheet.addElement(out_row)
 

	
 
    def save_file(self, out_file: BinaryIO) -> None:
 
        self.document.write(out_file)
 

	
 
    def save_path(self, path: Path, mode: str='w') -> None:
 
        with path.open(f'{mode}b') as out_file:
 
            out_file = cast(BinaryIO, out_file)
 
            self.save_file(out_file)
 

	
 

	
 
def normalize_amount_func(account_name: str) -> Callable[[T], T]:
 
    """Get a function to normalize amounts for reporting
 

	
 
    Given an account name, return a function that can be used on "amounts"
 
    under that account (including numbers, Amount objects, and Balance objects)
 
    to normalize them for reporting. Right now that means make flipping the
 
    sign for accounts where "normal" postings are negative.
 
    """
 
    if account_name.startswith(('Assets:', 'Expenses:')):
 
        # We can't just return operator.pos because Beancount's Amount class
 
        # doesn't implement __pos__.
 
        return lambda amt: amt
 
    elif account_name.startswith(('Equity:', 'Income:', 'Liabilities:')):
 
        return operator.neg
 
    else:
 
        raise ValueError(f"unrecognized account name {account_name!r}")
tests/test_reports_related_postings.py
Show inline comments
 
"""test_reports_related_postings - Unit tests for RelatedPostings"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import collections
 
import datetime
 
import itertools
 

	
 
from decimal import Decimal
 

	
 
import pytest
 

	
 
from . import testutil
 

	
 
from conservancy_beancount import data
 
from conservancy_beancount.reports import core
 

	
 
def accruals_and_payments(acct, src_acct, dst_acct, start_date, *amounts):
 
    dates = testutil.date_seq(start_date)
 
    for amt, currency in amounts:
 
        post_meta = {'metanumber': amt, 'metacurrency': currency}
 
        yield testutil.Transaction(date=next(dates), postings=[
 
            (acct, amt, currency, post_meta),
 
            (dst_acct if amt < 0 else src_acct, -amt, currency, post_meta),
 
        ])
 

	
 
@pytest.fixture
 
def credit_card_cycle():
 
    return list(accruals_and_payments(
 
        'Liabilities:CreditCard',
 
        'Assets:Checking',
 
        'Expenses:Other',
 
        datetime.date(2020, 4, 1),
 
        (-110, 'USD'),
 
        (110, 'USD'),
 
        (-120, 'USD'),
 
        (120, 'USD'),
 
    ))
 

	
 
@pytest.fixture
 
def two_accruals_three_payments():
 
    return list(accruals_and_payments(
 
        'Assets:Receivable:Accounts',
 
        'Income:Donations',
 
        'Assets:Checking',
 
        datetime.date(2020, 4, 10),
 
        (440, 'USD'),
 
        (-230, 'USD'),
 
        (550, 'EUR'),
 
        (-210, 'USD'),
 
        (-550, 'EUR'),
 
    ))
 

	
 
def test_initialize_with_list(credit_card_cycle):
 
    related = core.RelatedPostings(credit_card_cycle[0].postings)
 
    assert len(related) == 2
 

	
 
def test_initialize_with_iterable(two_accruals_three_payments):
 
    related = core.RelatedPostings(
 
        post for txn in two_accruals_three_payments
 
        for post in txn.postings
 
        if post.account == 'Assets:Receivable:Accounts'
 
    )
 
    assert len(related) == 5
 

	
 
def test_balance_empty():
 
    balance = core.RelatedPostings().balance()
 
    assert not balance
 
    assert balance.is_zero()
 

	
 
@pytest.mark.parametrize('index,expected', enumerate([
 
    -110,
 
    0,
 
    -120,
 
    0,
 
]))
 
def test_balance_credit_card(credit_card_cycle, index, expected):
 
    related = core.RelatedPostings(
 
        txn.postings[0] for txn in credit_card_cycle[:index + 1]
 
    )
 
    assert related.balance() == {'USD': testutil.Amount(expected, 'USD')}
 

	
 
def check_iter_with_balance(entries):
 
    expect_posts = [txn.postings[0] for txn in entries]
 
    expect_balances = []
 
    balance_tally = collections.defaultdict(Decimal)
 
    for post in expect_posts:
 
        number, currency = post.units
 
        balance_tally[currency] += number
 
        expect_balances.append({code: testutil.Amount(number, code)
 
                                for code, number in balance_tally.items()})
 
    related = core.RelatedPostings(expect_posts)
 
    for (post, balance), exp_post, exp_balance in zip(
 
            related.iter_with_balance(),
 
            expect_posts,
 
            expect_balances,
 
    ):
 
        assert post is exp_post
 
        assert balance == exp_balance
 
    assert post is expect_posts[-1]
 
    assert related.balance() == expect_balances[-1]
 

	
 
def test_iter_with_balance_empty():
 
    assert not list(core.RelatedPostings().iter_with_balance())
 

	
 
def test_iter_with_balance_credit_card(credit_card_cycle):
 
    check_iter_with_balance(credit_card_cycle)
 

	
 
def test_iter_with_balance_two_acccruals(two_accruals_three_payments):
 
    check_iter_with_balance(two_accruals_three_payments)
 

	
 
def test_balance_at_cost_mixed():
 
    txn = testutil.Transaction(postings=[
 
        ('Expenses:Other', '22'),
 
        ('Expenses:Other', '30', 'EUR', ('1.1',)),
 
        ('Expenses:Other', '40', 'EUR'),
 
        ('Expenses:Other', '50', 'USD', ('1.1', 'EUR')),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(55, 'USD'), testutil.Amount(95, 'EUR')}
 

	
 
def test_balance_at_single_currency_cost():
 
    txn = testutil.Transaction(postings=[
 
        ('Expenses:Other', '22'),
 
        ('Expenses:Other', '30', 'EUR', ('1.1',)),
 
        ('Expenses:Other', '40', 'GBP', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(99)}
 

	
 
def test_balance_at_cost_zeroed_out():
 
    txn = testutil.Transaction(postings=[
 
        ('Income:Other', '-22'),
 
        ('Assets:Receivable:Accounts', '20', 'EUR', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    assert balance.is_zero()
 

	
 
def test_balance_at_cost_singleton():
 
    txn = testutil.Transaction(postings=[
 
        ('Assets:Receivable:Accounts', '20', 'EUR', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(22)}
 

	
 
def test_balance_at_cost_singleton_without_cost():
 
    txn = testutil.Transaction(postings=[
 
        ('Assets:Receivable:Accounts', '20'),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(20)}
 

	
 
def test_balance_at_cost_empty():
 
    related = core.RelatedPostings()
 
    balance = related.balance_at_cost()
 
    assert balance.is_zero()
 

	
 
def test_meta_values_empty():
 
    related = core.RelatedPostings()
 
    assert related.meta_values('key') == set()
 

	
 
def test_meta_values_no_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, metakey='metavalue'),
 
    ])
 
    assert related.meta_values('key') == {None}
 

	
 
def test_meta_values_no_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, metakey='metavalue'),
 
    ])
 
    assert related.meta_values('key', '') == {''}
 

	
 
def test_meta_values_one_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='metavalue'),
 
    ])
 
    assert related.meta_values('key') == {'metavalue'}
 

	
 
def test_meta_values_some_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, metakey='2'),
 
    ])
 
    assert related.meta_values('key') == {'1', None}
 

	
 
def test_meta_values_some_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, metakey='2'),
 
    ])
 
    assert related.meta_values('key', '') == {'1', ''}
 

	
 
def test_meta_values_all_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='2'),
 
    ])
 
    assert related.meta_values('key') == {'1', '2'}
 

	
 
def test_meta_values_all_match_one_value():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='1'),
 
    ])
 
    assert related.meta_values('key') == {'1'}
 

	
 
def test_meta_values_all_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='2'),
 
    ])
 
    assert related.meta_values('key', '') == {'1', '2'}
 

	
 
def test_meta_values_many_types():
 
    expected = {
 
        datetime.date(2020, 4, 1),
 
        Decimal(42),
 
        testutil.Amount(5),
 
        'rt:42',
 
    }
 
    related = core.RelatedPostings(
 
        testutil.Posting('Income:Donations', -index, key=value)
 
        for index, value in enumerate(expected)
 
    )
 
    assert related.meta_values('key') == expected
 

	
 
@pytest.mark.parametrize('count', range(3))
 
def test_all_meta_links_zero(count):
 
    postings = (
 
        testutil.Posting('Income:Donations', -n, testkey=str(n))
 
        for n in range(count)
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('approval') == set()
 
    assert next(related.all_meta_links('approval'), None) is None
 

	
 
def test_all_meta_links_singletons():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, statement=value)
 
        for value in itertools.chain(
 
            testutil.NON_LINK_METADATA_STRINGS,
 
            testutil.LINK_METADATA_STRINGS,
 
            testutil.NON_STRING_METADATA_VALUES,
 
        ))
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('statement') == testutil.LINK_METADATA_STRINGS
 
    assert set(related.all_meta_links('statement')) == testutil.LINK_METADATA_STRINGS
 

	
 
def test_all_meta_links_multiples():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, approval=' '.join(value))
 
        for value in itertools.permutations(testutil.LINK_METADATA_STRINGS, 2)
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('approval') == testutil.LINK_METADATA_STRINGS
 
    assert set(related.all_meta_links('approval')) == testutil.LINK_METADATA_STRINGS
 

	
 
def test_all_meta_links_preserves_order():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, approval=c)
 
        for c in '121323'
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert list(related.all_meta_links('approval')) == list('123')
 

	
 
def test_group_by_meta_zero():
 
    assert not list(core.RelatedPostings.group_by_meta([], 'metacurrency'))
 

	
 
def test_group_by_meta_one(credit_card_cycle):
 
    posting = next(post for post in data.Posting.from_entries(credit_card_cycle)
 
                   if post.account.is_credit_card())
 
    actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency')
 
    assert set(key for key, _ in actual) == {'USD'}
 

	
 
def test_group_by_meta_many(two_accruals_three_payments):
 
    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
 
                if post.account == 'Assets:Receivable:Accounts']
 
    actual = dict(core.RelatedPostings.group_by_meta(postings, 'metacurrency'))
 
    assert set(actual) == {'USD', 'EUR'}
 
    for key, group in actual.items():
 
        assert 2 <= len(group) <= 3
 
        assert group.balance().is_zero()
 

	
 
def test_group_by_meta_many_single_posts(two_accruals_three_payments):
 
    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
 
                if post.account == 'Assets:Receivable:Accounts']
 
    actual = dict(core.RelatedPostings.group_by_meta(postings, 'metanumber'))
 
    assert set(actual) == {post.units.number for post in postings}
 
    assert len(actual) == len(postings)
0 comments (0 inline, 0 general)