Changeset - f76fa35fad2a
[Not reviewed]
0 3 0
Brett Smith - 4 years ago 2020-06-11 14:46:06
brettcsmith@brettcsmith.org
reports: RelatedPostings.all_meta_links() returns an iterator.

This preserves order.
3 files changed with 21 insertions and 8 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/accrual.py
Show inline comments
...
 
@@ -364,385 +364,385 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
 
            if years and days:
 
                age_text = f"{years_text} {days_text}"
 
            elif years:
 
                age_text = years_text
 
            else:
 
                age_text = days_text
 
            if last_age_text is None:
 
                age_range = f"Over {age_text}"
 
            else:
 
                age_range = f"{age_text}–{last_age_text}"
 
            self.add_row(
 
                self.string_cell(
 
                    f"Total Aged {age_range}: ",
 
                    stylename=text_style,
 
                    numbercolumnsspanned=text_span,
 
                ),
 
                *(odf.table.TableCell() for _ in range(1, text_span)),
 
                self.balance_cell(balance),
 
            )
 
            last_age_text = age_text
 
            total_balance += balance
 
        self.add_row(
 
            self.string_cell(
 
                "Total Unpaid: ",
 
                stylename=text_style,
 
                numbercolumnsspanned=text_span,
 
            ),
 
            *(odf.table.TableCell() for _ in range(1, text_span)),
 
            self.balance_cell(total_balance),
 
        )
 

	
 
    def _link_seq(self, row: AccrualPostings, key: MetaKey) -> Iterator[Tuple[str, str]]:
 
        for href in row.all_meta_links(key):
 
            text: Optional[str] = None
 
            rt_ids = self.rt_wrapper.parse(href)
 
            if rt_ids is not None:
 
                ticket_id, attachment_id = rt_ids
 
                if attachment_id is None:
 
                    text = f'RT#{ticket_id}'
 
                href = self.rt_wrapper.url(ticket_id, attachment_id) or href
 
            else:
 
                # '..' pops the ODS filename off the link path. In other words,
 
                # make the link relative to the directory the ODS is in.
 
                href = f'../{href}'
 
            if text is None:
 
                href_path = Path(urlparse.urlparse(href).path)
 
                text = urlparse.unquote(href_path.name)
 
            yield (href, text)
 

	
 
    def write_row(self, row: AccrualPostings) -> None:
 
        age = (self.date - row[0].meta.date).days
 
        if row.end_balance.ge_zero():
 
            for index, threshold in enumerate(self.age_thresholds):
 
                if age >= threshold:
 
                    self.age_balances[index] += row.end_balance
 
                    break
 
            else:
 
                return
 
        raw_balance = row.balance()
 
        if row.accrual_type is not None:
 
            raw_balance = row.accrual_type.normalize_amount(raw_balance)
 
        if raw_balance == row.end_balance:
 
            amount_cell = odf.table.TableCell()
 
        else:
 
            amount_cell = self.balance_cell(raw_balance)
 
        projects = {post.meta.get('project') or None for post in row}
 
        projects.discard(None)
 
        self.add_row(
 
            self.date_cell(row[0].meta.date),
 
            self.multiline_cell(row.entities()),
 
            amount_cell,
 
            self.balance_cell(row.end_balance),
 
            self.multiline_cell(sorted(projects)),
 
            self.multilink_cell(self._link_seq(row, 'rt-id')),
 
            self.multilink_cell(self._link_seq(row, 'invoice')),
 
            self.multilink_cell(self._link_seq(row, 'approval')),
 
            self.multilink_cell(self._link_seq(row, 'contract')),
 
            self.multilink_cell(self._link_seq(row, 'purchase-order')),
 
        )
 

	
 

	
 
class AgingReport(BaseReport):
 
    def __init__(self,
 
                 rt_client: rt.Rt,
 
                 out_file: BinaryIO,
 
                 date: Optional[datetime.date]=None,
 
    ) -> None:
 
        if date is None:
 
            date = datetime.date.today()
 
        self.out_bin = out_file
 
        self.logger = logger.getChild(type(self).__name__)
 
        self.ods = AgingODS(rt_client, date, self.logger)
 

	
 
    def run(self, groups: PostGroups) -> None:
 
        rows: List[AccrualPostings] = []
 
        for group in groups.values():
 
            if group.is_zero():
 
                # Cheap optimization: don't slice and dice groups we're not
 
                # going to report anyway.
 
                continue
 
            elif group.accrual_type is None:
 
                group = group.since_last_nonzero()
 
            else:
 
                # Filter out new accruals after the report date.
 
                # e.g., cover the case that the same invoices has multiple
 
                # postings over time, and we don't want to report too-recent
 
                # ones.
 
                cutoff_date = self.ods.date - datetime.timedelta(
 
                    days=group.accrual_type.value.aging_thresholds[-1],
 
                )
 
                group = AccrualPostings(
 
                    post for post in group.since_last_nonzero()
 
                    if post.meta.date <= cutoff_date
 
                    or group.accrual_type.normalize_amount(post.units.number) < 0
 
                )
 
            if group and not group.is_zero():
 
                rows.append(group)
 
        rows.sort(key=lambda related: (
 
            related.account,
 
            related[0].meta.date,
 
            ('\0'.join(related.entities())
 
             if related.entity is related.INCONSISTENT
 
             else related.entity),
 
        ))
 
        self.ods.write(rows)
 
        self.ods.save_file(self.out_bin)
 

	
 

	
 
class BalanceReport(BaseReport):
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        date_s = posts[0].meta.date.strftime('%Y-%m-%d')
 
        if index:
 
            yield ""
 
        yield f"{posts.invoice}:"
 
        yield f"  {posts.balance_at_cost()} outstanding since {date_s}"
 

	
 

	
 
class OutgoingReport(BaseReport):
 
    def __init__(self, rt_client: rt.Rt, out_file: TextIO) -> None:
 
        super().__init__(out_file)
 
        self.rt_client = rt_client
 
        self.rt_wrapper = rtutil.RT(rt_client)
 

	
 
    def _primary_rt_id(self, posts: AccrualPostings) -> rtutil.TicketAttachmentIds:
 
        rt_ids = {url for url in posts.first_links('rt-id') if url is not None}
 
        rt_ids_count = len(rt_ids)
 
        if rt_ids_count != 1:
 
            raise ValueError(f"{rt_ids_count} rt-id links found")
 
        parsed = rtutil.RT.parse(rt_ids.pop())
 
        if parsed is None:
 
            raise ValueError("rt-id is not a valid RT reference")
 
        else:
 
            return parsed
 

	
 
    def _report(self, posts: AccrualPostings, index: int) -> Iterable[str]:
 
        posts = posts.since_last_nonzero()
 
        try:
 
            ticket_id, _ = self._primary_rt_id(posts)
 
            ticket = self.rt_client.get_ticket(ticket_id)
 
            # Note we only use this when ticket is None.
 
            errmsg = f"ticket {ticket_id} not found"
 
        except (ValueError, rt.RtError) as error:
 
            ticket = None
 
            errmsg = error.args[0]
 
        if ticket is None:
 
            self.logger.error(
 
                "can't generate outgoings report for %s because no RT ticket available: %s",
 
                posts.invoice, errmsg,
 
            )
 
            return
 

	
 
        try:
 
            rt_requestor = self.rt_client.get_user(ticket['Requestors'][0])
 
        except (IndexError, rt.RtError):
 
            rt_requestor = None
 
        if rt_requestor is None:
 
            requestor = ''
 
            requestor_name = ''
 
        else:
 
            requestor_name = (
 
                rt_requestor.get('RealName')
 
                or ticket.get('CF.{payment-to}')
 
                or ''
 
            )
 
            requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip()
 

	
 
        balance_s = posts.end_balance.format(None)
 
        raw_balance = -posts.balance()
 
        if raw_balance != posts.end_balance:
 
            balance_s = f'{raw_balance} ({balance_s})'
 

	
 
        contract_links = posts.all_meta_links('contract')
 
        contract_links = list(posts.all_meta_links('contract'))
 
        if contract_links:
 
            contract_s = ' , '.join(self.rt_wrapper.iter_urls(
 
                contract_links, missing_fmt='<BROKEN RT LINK: {}>',
 
            ))
 
        else:
 
            contract_s = "NO CONTRACT GOVERNS THIS TRANSACTION"
 
        projects = [v for v in posts.meta_values('project')
 
                    if isinstance(v, str)]
 

	
 
        yield "PAYMENT FOR APPROVAL:"
 
        yield f"REQUESTOR: {requestor}"
 
        yield f"TOTAL TO PAY: {balance_s}"
 
        yield f"AGREEMENT: {contract_s}"
 
        yield f"PAYMENT TO: {ticket.get('CF.{payment-to}') or requestor_name}"
 
        yield f"PAYMENT METHOD: {ticket.get('CF.{payment-method}', '')}"
 
        yield f"PROJECT: {', '.join(projects)}"
 
        yield "\nBEANCOUNT ENTRIES:\n"
 

	
 
        last_txn: Optional[Transaction] = None
 
        for post in posts:
 
            txn = post.meta.txn
 
            if txn is not last_txn:
 
                last_txn = txn
 
                txn = self.rt_wrapper.txn_with_urls(txn, '{}')
 
                yield bc_printer.format_entry(txn)
 

	
 

	
 
class ReportType(enum.Enum):
 
    AGING = AgingReport
 
    BALANCE = BalanceReport
 
    OUTGOING = OutgoingReport
 
    AGE = AGING
 
    BAL = BALANCE
 
    OUT = OUTGOING
 
    OUTGOINGS = OUTGOING
 

	
 
    @classmethod
 
    def by_name(cls, name: str) -> 'ReportType':
 
        try:
 
            return cls[name.upper()]
 
        except KeyError:
 
            raise ValueError(f"unknown report type {name!r}") from None
 

	
 
    @classmethod
 
    def default_for(cls, groups: PostGroups) -> 'ReportType':
 
        if len(groups) == 1 and all(
 
                group.accrual_type is AccrualAccount.PAYABLE
 
                and not group.is_paid()
 
                for group in groups.values()
 
        ):
 
            return cls.OUTGOING
 
        else:
 
            return cls.BALANCE
 

	
 

	
 
class ReturnFlag(enum.IntFlag):
 
    LOAD_ERRORS = 1
 
    REPORT_ERRORS = 4
 
    NOTHING_TO_REPORT = 8
 

	
 

	
 
def filter_search(postings: Iterable[data.Posting],
 
                  search_terms: Iterable[cliutil.SearchTerm],
 
) -> Iterable[data.Posting]:
 
    accounts = tuple(AccrualAccount.account_names())
 
    postings = (post for post in postings if post.account.is_under(*accounts))
 
    for query in search_terms:
 
        postings = query.filter_postings(postings)
 
    return postings
 

	
 
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 
    parser = argparse.ArgumentParser(prog=PROGNAME)
 
    cliutil.add_version_argument(parser)
 
    parser.add_argument(
 
        '--report-type', '-t',
 
        metavar='NAME',
 
        type=ReportType.by_name,
 
        help="""The type of report to generate, one of `aging`, `balance`, or
 
`outgoing`. If not specified, the default is `aging` when no search terms are
 
given, `outgoing` for search terms that return a single outstanding payable,
 
and `balance` any other time.
 
""")
 
    parser.add_argument(
 
        '--since',
 
        metavar='YEAR',
 
        type=int,
 
        default=-1,
 
        help="""How far back to search the books for related transactions.
 
You can either specify a fiscal year, or a negative offset from the current
 
fiscal year, to start loading entries from. The default is -1 (start from the
 
previous fiscal year).
 
""")
 
    parser.add_argument(
 
        '--output-file', '-O',
 
        metavar='PATH',
 
        type=Path,
 
        help="""Write the report to this file, or stdout when PATH is `-`.
 
The default is stdout for the balance and outgoing reports, and a generated
 
filename for other reports.
 
""")
 
    cliutil.add_loglevel_argument(parser)
 
    parser.add_argument(
 
        'search_terms',
 
        metavar='FILTER',
 
        type=cliutil.SearchTerm.arg_parser('invoice', 'rt-id'),
 
        nargs=argparse.ZERO_OR_MORE,
 
        help="""Report on accruals that match this criteria. The format is
 
NAME=TERM. TERM is a link or word that must exist in a posting's NAME
 
metadata to match. A single ticket number is a shortcut for
 
`rt-id=rt:NUMBER`. Any other link, including an RT attachment link in
 
`TIK/ATT` format, is a shortcut for `invoice=LINK`.
 
""")
 
    args = parser.parse_args(arglist)
 
    if args.report_type is None and not args.search_terms:
 
        args.report_type = ReportType.AGING
 
    return args
 

	
 
def main(arglist: Optional[Sequence[str]]=None,
 
         stdout: TextIO=sys.stdout,
 
         stderr: TextIO=sys.stderr,
 
         config: Optional[configmod.Config]=None,
 
) -> int:
 
    if cliutil.is_main_script(PROGNAME):
 
        global logger
 
        logger = logging.getLogger(PROGNAME)
 
        sys.excepthook = cliutil.ExceptHook(logger)
 
    args = parse_arguments(arglist)
 
    cliutil.setup_logger(logger, args.loglevel, stderr)
 
    if config is None:
 
        config = configmod.Config()
 
        config.load_file()
 

	
 
    books_loader = config.books_loader()
 
    if books_loader is None:
 
        entries, load_errors, _ = books.Loader.load_none(config.config_file_path())
 
    elif args.report_type is ReportType.AGING:
 
        entries, load_errors, _ = books_loader.load_all()
 
    else:
 
        entries, load_errors, _ = books_loader.load_all(args.since)
 
    filters.remove_opening_balance_txn(entries)
 

	
 
    returncode = 0
 
    postings = filter_search(data.Posting.from_entries(entries), args.search_terms)
 
    groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice'))
 
    for error in load_errors:
 
        bc_printer.print_error(error, file=stderr)
 
        returncode |= ReturnFlag.LOAD_ERRORS
 
    if not groups:
 
        logger.warning("no matching entries found to report")
 
        returncode |= ReturnFlag.NOTHING_TO_REPORT
 

	
 
    groups = {
 
        key: posts
 
        for source_posts in groups.values()
 
        for key, posts in source_posts.make_consistent()
 
    }
 
    if args.report_type is not ReportType.AGING:
 
        groups = {
 
            key: posts for key, posts in groups.items() if not posts.is_paid()
 
        } or groups
 

	
 
    if args.report_type is None:
 
        args.report_type = ReportType.default_for(groups)
 
    report: Optional[BaseReport] = None
 
    output_path: Optional[Path] = None
 
    if args.report_type is ReportType.AGING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate aging report: RT client is required")
 
        else:
 
            now = datetime.datetime.now()
 
            if args.output_file is None:
 
                args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
 
                logger.info("Writing report to %s", args.output_file)
 
            out_bin = cliutil.bytes_output(args.output_file, stdout)
 
            report = AgingReport(rt_client, out_bin)
 
    elif args.report_type is ReportType.OUTGOING:
 
        rt_client = config.rt_client()
 
        if rt_client is None:
 
            logger.error("unable to generate outgoing report: RT client is required")
 
        else:
 
            out_file = cliutil.text_output(args.output_file, stdout)
 
            report = OutgoingReport(rt_client, out_file)
 
    else:
 
        out_file = cliutil.text_output(args.output_file, stdout)
 
        report = args.report_type.value(out_file)
 

	
 
    if report is None:
 
        returncode |= ReturnFlag.REPORT_ERRORS
 
    else:
 
        report.run(groups)
 
    return 0 if returncode == 0 else 16 + returncode
conservancy_beancount/reports/core.py
Show inline comments
 
"""core.py - Common data classes for reporting functionality"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import abc
 
import collections
 
import datetime
 
import itertools
 
import operator
 
import re
 

	
 
import babel.core  # type:ignore[import]
 
import babel.numbers  # type:ignore[import]
 

	
 
import odf.config  # type:ignore[import]
 
import odf.element  # type:ignore[import]
 
import odf.number  # type:ignore[import]
 
import odf.opendocument  # type:ignore[import]
 
import odf.style  # type:ignore[import]
 
import odf.table  # type:ignore[import]
 
import odf.text  # type:ignore[import]
 

	
 
from decimal import Decimal
 
from pathlib import Path
 

	
 
from beancount.core import amount as bc_amount
 

	
 
from .. import data
 
from .. import filters
 

	
 
from typing import (
 
    cast,
 
    overload,
 
    Any,
 
    BinaryIO,
 
    Callable,
 
    DefaultDict,
 
    Dict,
 
    Generic,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Mapping,
 
    MutableMapping,
 
    Optional,
 
    Sequence,
 
    Set,
 
    Tuple,
 
    Type,
 
    TypeVar,
 
    Union,
 
)
 
from ..beancount_types import (
 
    MetaKey,
 
    MetaValue,
 
)
 

	
 
DecimalCompat = data.DecimalCompat
 
BalanceType = TypeVar('BalanceType', bound='Balance')
 
ElementType = Callable[..., odf.element.Element]
 
LinkType = Union[str, Tuple[str, Optional[str]]]
 
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
 
RT = TypeVar('RT', bound=Sequence)
 
ST = TypeVar('ST')
 
T = TypeVar('T')
 

	
 
class Balance(Mapping[str, data.Amount]):
 
    """A collection of amounts mapped by currency
 

	
 
    Each key is a Beancount currency string, and each value represents the
 
    balance in that currency.
 
    """
 
    __slots__ = ('_currency_map', 'tolerance')
 
    TOLERANCE = Decimal('0.01')
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Amount]=(),
 
                 tolerance: Optional[Decimal]=None,
 
    ) -> None:
 
        if tolerance is None:
 
            tolerance = self.TOLERANCE
 
        self.tolerance = tolerance
 
        self._currency_map: Dict[str, data.Amount] = {}
 
        for amount in source:
 
            self._add_amount(self._currency_map, amount)
 

	
 
    def _add_amount(self,
 
                    currency_map: MutableMapping[str, data.Amount],
 
                    amount: data.Amount,
 
    ) -> None:
 
        code = amount.currency
 
        try:
 
            current_number = currency_map[code].number
 
        except KeyError:
 
            current_number = Decimal(0)
 
        currency_map[code] = data.Amount(current_number + amount.number, code)
 

	
 
    def _add_other(self,
 
                   currency_map: MutableMapping[str, data.Amount],
 
                   other: Union[data.Amount, 'Balance'],
 
    ) -> None:
 
        if isinstance(other, Balance):
 
            for amount in other.values():
 
                self._add_amount(currency_map, amount)
 
        else:
 
            self._add_amount(currency_map, other)
 

	
 
    def __repr__(self) -> str:
 
        values = [repr(amt) for amt in self.values()]
 
        return f"{type(self).__name__}({values!r})"
 

	
 
    def __str__(self) -> str:
 
        return self.format()
 

	
 
    def __abs__(self: BalanceType) -> BalanceType:
 
        return type(self)(bc_amount.abs(amt) for amt in self.values())
 

	
 
    def __add__(self: BalanceType, other: Union[data.Amount, 'Balance']) -> BalanceType:
 
        retval_map = self._currency_map.copy()
 
        self._add_other(retval_map, other)
 
        return type(self)(retval_map.values())
 

	
 
    def __eq__(self, other: Any) -> bool:
 
        if isinstance(other, Balance):
 
            clean_self = self.clean_copy()
 
            clean_other = other.clean_copy()
 
            return len(clean_self) == len(clean_other) and all(
 
                clean_self[key] == clean_other.get(key) for key in clean_self
 
            )
 
        else:
 
            return super().__eq__(other)
 

	
 
    def __neg__(self: BalanceType) -> BalanceType:
 
        return type(self)(-amt for amt in self.values())
 

	
 
    def __pos__(self: BalanceType) -> BalanceType:
 
        return self
 

	
 
    def __getitem__(self, key: str) -> data.Amount:
 
        return self._currency_map[key]
 

	
 
    def __iter__(self) -> Iterator[str]:
 
        return iter(self._currency_map)
 

	
 
    def __len__(self) -> int:
 
        return len(self._currency_map)
 

	
 
    def _all_amounts(self,
 
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
 
                     operand: DecimalCompat,
 
    ) -> bool:
 
        return all(op_func(amt.number, operand) for amt in self.values())
 

	
 
    def copy(self: BalanceType) -> BalanceType:
 
        return type(self)(self.values())
 

	
 
    def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType:
 
        if tolerance is None:
 
            tolerance = self.tolerance
 
        return type(self)(
 
            amount for amount in self.values()
 
            if abs(amount.number) >= tolerance
 
        )
 

	
 
    @staticmethod
 
    def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
 
        dec = cast(Decimal, dec)
 
        return abs(dec) < tolerance
 

	
 
    def eq_zero(self) -> bool:
 
        """Returns true if all amounts in the balance == 0, within tolerance."""
 
        return self._all_amounts(self.within_tolerance, self.tolerance)
 

	
 
    is_zero = eq_zero
 

	
 
    def ge_zero(self) -> bool:
 
        """Returns true if all amounts in the balance >= 0, within tolerance."""
 
        op_func = operator.gt if self.tolerance else operator.ge
 
        return self._all_amounts(op_func, -self.tolerance)
 

	
 
    def le_zero(self) -> bool:
 
        """Returns true if all amounts in the balance <= 0, within tolerance."""
 
        op_func = operator.lt if self.tolerance else operator.le
 
        return self._all_amounts(op_func, self.tolerance)
 

	
 
    def format(self,
 
               fmt: Optional[str]='#,#00.00 ¤¤',
 
               sep: str=', ',
 
               empty: str="Zero balance",
 
               tolerance: Optional[Decimal]=None,
 
    ) -> str:
 
        """Formats the balance as a string with the given parameters
 

	
 
        If the balance is zero (within tolerance), returns ``empty``.
 
        Otherwise, returns a string with each amount in the balance formatted
 
        as ``fmt``, separated by ``sep``.
 

	
 
        If you set ``fmt`` to None, amounts will be formatted according to the
 
        user's locale. The default format is Beancount's input format.
 
        """
 
        amounts = list(self.clean_copy(tolerance).values())
 
        if not amounts:
 
            return empty
 
        amounts.sort(key=lambda amt: abs(amt.number), reverse=True)
 
        return sep.join(
 
            babel.numbers.format_currency(amt.number, amt.currency, fmt)
 
            for amt in amounts
 
        )
 

	
 

	
 
class MutableBalance(Balance):
 
    __slots__ = ()
 

	
 
    def __iadd__(self: BalanceType, other: Union[data.Amount, Balance]) -> BalanceType:
 
        self._add_other(self._currency_map, other)
 
        return self
 

	
 

	
 
class RelatedPostings(Sequence[data.Posting]):
 
    """Collect and query related postings
 

	
 
    This class provides common functionality for collecting related postings
 
    and running queries on them: iterating over them, tallying their balance,
 
    etc.
 

	
 
    This class doesn't know anything about how the postings are related. That's
 
    entirely up to the caller.
 

	
 
    A common pattern is to use this class with collections.defaultdict
 
    to organize postings based on some key. See the group_by_meta classmethod
 
    for an example.
 
    """
 
    __slots__ = ('_postings',)
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        self._postings: List[data.Posting]
 
        if _can_own and isinstance(source, list):
 
            self._postings = source
 
        else:
 
            self._postings = list(source)
 

	
 
    @classmethod
 
    def group_by_meta(cls: Type[RelatedType],
 
                      postings: Iterable[data.Posting],
 
                      key: MetaKey,
 
                      default: Optional[MetaValue]=None,
 
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
 
        """Relate postings by metadata value
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of post.meta.get(key, default).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same metadata value.
 
        """
 
        mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list)
 
        for post in postings:
 
            mapping[post.meta.get(key, default)].append(post)
 
        for value, posts in mapping.items():
 
            yield value, cls(posts, _can_own=True)
 

	
 
    def __repr__(self) -> str:
 
        return f'<{type(self).__name__} {self._postings!r}>'
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...
 

	
 
    def __getitem__(self: RelatedType,
 
                    index: Union[int, slice],
 
    ) -> Union[data.Posting, RelatedType]:
 
        if isinstance(index, slice):
 
            return type(self)(self._postings[index], _can_own=True)
 
        else:
 
            return self._postings[index]
 

	
 
    def __len__(self) -> int:
 
        return len(self._postings)
 

	
 
    def all_meta_links(self, key: MetaKey) -> Set[str]:
 
        retval: Set[str] = set()
 
    def _all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        for post in self:
 
            try:
 
                retval.update(post.meta.get_links(key))
 
                yield from post.meta.get_links(key)
 
            except TypeError:
 
                pass
 
        return retval
 

	
 
    def all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        return filters.iter_unique(self._all_meta_links(key))
 

	
 
    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
 
        balance = MutableBalance()
 
        for post in self:
 
            balance += post.units
 
            yield post, balance
 

	
 
    def balance(self) -> Balance:
 
        for _, balance in self.iter_with_balance():
 
            pass
 
        try:
 
            return balance
 
        except NameError:
 
            return Balance()
 

	
 
    def balance_at_cost(self) -> Balance:
 
        balance = MutableBalance()
 
        for post in self:
 
            if post.cost is None:
 
                balance += post.units
 
            else:
 
                number = post.units.number * post.cost.number
 
                balance += data.Amount(number, post.cost.currency)
 
        return balance
 

	
 
    def meta_values(self,
 
                    key: MetaKey,
 
                    default: Optional[MetaValue]=None,
 
    ) -> Set[Optional[MetaValue]]:
 
        return {post.meta.get(key, default) for post in self}
 

	
 

	
 
class BaseSpreadsheet(Generic[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to arbitrary output. It calls hooks when it starts writing the
 
    spreadsheet, starts a new "section" of rows, ends a section, and ends the
 
    spreadsheet.
 

	
 
    RT is the type of the input data rows. ST is the type of the section
 
    identifier that you create from each row. If you don't want to use the
 
    section logic at all, set ST to None and define section_key to return None.
 
    """
 

	
 
    @abc.abstractmethod
 
    def section_key(self, row: RT) -> ST:
 
        """Return the section a row belongs to
 

	
 
        Given a data row, this method should return some identifier for the
 
        "section" the row belongs to. The write method uses this to
 
        determine when to call start_section and end_section.
 

	
 
        If your spreadsheet doesn't need sections, define this to return None.
 
        """
 
        ...
 

	
 
    @abc.abstractmethod
 
    def write_row(self, row: RT) -> None:
 
        """Write a data row to the output spreadsheet
 

	
 
        This method is called once for each data row in the input.
 
        """
 
        ...
 

	
 
    # The next four methods are all called by the write method when the name
 
    # says. You may override them to output headers or sums, record
 
    # state, etc. The default implementations are all noops.
 

	
 
    def start_spreadsheet(self) -> None:
 
        pass
 

	
 
    def start_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_spreadsheet(self) -> None:
 
        pass
 

	
 
    def write(self, rows: Iterable[RT]) -> None:
 
        prev_section: Optional[ST] = None
 
        self.start_spreadsheet()
 
        for row in rows:
 
            section = self.section_key(row)
 
            if section != prev_section:
 
                if prev_section is not None:
 
                    self.end_section(prev_section)
 
                self.start_section(section)
 
                prev_section = section
 
            self.write_row(row)
 
        try:
 
            should_end = section is not None
 
        except NameError:
 
            should_end = False
 
        if should_end:
 
            self.end_section(section)
 
        self.end_spreadsheet()
 

	
 

	
 
class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write OpenDocument spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to an OpenDocument spreadsheet. It provides helper methods for
 
    building sheets, rows, and cells.
 

	
 
    See also the BaseSpreadsheet base class for additional documentation about
 
    methods you must and can define, the definition of RT and ST, etc.
 
    """
 
    def __init__(self) -> None:
 
        self.locale = babel.core.Locale.default('LC_MONETARY')
 
        self.currency_fmt_key = 'accounting'
 
        self._name_counter = itertools.count(1)
 
        self._currency_style_cache: MutableMapping[str, odf.style.Style] = {}
 
        self.document = odf.opendocument.OpenDocumentSpreadsheet()
 
        self.init_settings()
 
        self.init_styles()
 
        self.sheet = self.use_sheet("Report")
 

	
 
    ### Low-level document tree manipulation
 
    # The *intent* is that you only need to use these if you're adding new
 
    # methods to manipulate document settings or styles.
 

	
 
    def copy_element(self, elem: odf.element.Element) -> odf.element.Element:
 
        qattrs = dict(self.iter_qattributes(elem))
 
        retval = odf.element.Element(qname=elem.qname, qattributes=qattrs)
 
        try:
 
            orig_name = retval.getAttribute('name')
 
        except ValueError:
 
            orig_name = None
 
        if orig_name is not None:
 
            retval.setAttribute('name', f'{orig_name}{next(self._name_counter)}')
 
        return retval
 

	
 
    def ensure_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        if found_child is None:
 
            parent.addElement(new_child)
 
            return parent.lastChild
 
        else:
 
            return found_child
 

	
 
    def ensure_config_map_entry(self,
 
                                root: odf.element.Element,
 
                                map_name: str,
 
                                entry_name: str,
 
    ) -> odf.element.Element:
 
        """Return a ``ConfigItemMapEntry`` under ``root``
 

	
 
        This method ensures there's a ``ConfigItemMapNamed`` named ``map_name``
 
        under ``root``, and a ``ConfigItemMapEntry`` named ``entry_name`` under
 
        that. Return the ``ConfigItemMapEntry`` element.
 
        """
 
        config_map = self.ensure_child(root, odf.config.ConfigItemMapNamed, name=map_name)
 
        return self.ensure_child(config_map, odf.config.ConfigItemMapEntry, name=entry_name)
 

	
 
    def find_child(self,
 
                   parent: odf.element.Element,
 
                   child: odf.element.Element,
 
    ) -> Optional[odf.element.Element]:
 
        attrs = {k: v for k, v in self.iter_attributes(child)}
 
        if not attrs:
 
            return None
 
        for elem in parent.childNodes:
 
            if (elem.qname == child.qname
 
                and all(elem.getAttribute(k) == v for k, v in attrs.items())):
 
                return elem
 
        return None
 

	
 
    def iter_attributes(self, elem: odf.element.Element) -> Iterator[Tuple[str, str]]:
 
        for (_, key), value in self.iter_qattributes(elem):
 
            yield key.lower().replace('-', ''), value
 

	
 
    def iter_qattributes(self, elem: odf.element.Element) -> Iterator[Tuple[Tuple[str, str], str]]:
 
        if elem.attributes:
 
            yield from elem.attributes.items()
 

	
 
    def replace_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        parent.insertBefore(new_child, found_child)
tests/test_reports_related_postings.py
Show inline comments
...
 
@@ -67,243 +67,254 @@ def test_initialize_with_list(credit_card_cycle):
 
    related = core.RelatedPostings(credit_card_cycle[0].postings)
 
    assert len(related) == 2
 

	
 
def test_initialize_with_iterable(two_accruals_three_payments):
 
    related = core.RelatedPostings(
 
        post for txn in two_accruals_three_payments
 
        for post in txn.postings
 
        if post.account == 'Assets:Receivable:Accounts'
 
    )
 
    assert len(related) == 5
 

	
 
def test_balance_empty():
 
    balance = core.RelatedPostings().balance()
 
    assert not balance
 
    assert balance.is_zero()
 

	
 
@pytest.mark.parametrize('index,expected', enumerate([
 
    -110,
 
    0,
 
    -120,
 
    0,
 
]))
 
def test_balance_credit_card(credit_card_cycle, index, expected):
 
    related = core.RelatedPostings(
 
        txn.postings[0] for txn in credit_card_cycle[:index + 1]
 
    )
 
    assert related.balance() == {'USD': testutil.Amount(expected, 'USD')}
 

	
 
def check_iter_with_balance(entries):
 
    expect_posts = [txn.postings[0] for txn in entries]
 
    expect_balances = []
 
    balance_tally = collections.defaultdict(Decimal)
 
    for post in expect_posts:
 
        number, currency = post.units
 
        balance_tally[currency] += number
 
        expect_balances.append({code: testutil.Amount(number, code)
 
                                for code, number in balance_tally.items()})
 
    related = core.RelatedPostings(expect_posts)
 
    for (post, balance), exp_post, exp_balance in zip(
 
            related.iter_with_balance(),
 
            expect_posts,
 
            expect_balances,
 
    ):
 
        assert post is exp_post
 
        assert balance == exp_balance
 
    assert post is expect_posts[-1]
 
    assert related.balance() == expect_balances[-1]
 

	
 
def test_iter_with_balance_empty():
 
    assert not list(core.RelatedPostings().iter_with_balance())
 

	
 
def test_iter_with_balance_credit_card(credit_card_cycle):
 
    check_iter_with_balance(credit_card_cycle)
 

	
 
def test_iter_with_balance_two_acccruals(two_accruals_three_payments):
 
    check_iter_with_balance(two_accruals_three_payments)
 

	
 
def test_balance_at_cost_mixed():
 
    txn = testutil.Transaction(postings=[
 
        ('Expenses:Other', '22'),
 
        ('Expenses:Other', '30', 'EUR', ('1.1',)),
 
        ('Expenses:Other', '40', 'EUR'),
 
        ('Expenses:Other', '50', 'USD', ('1.1', 'EUR')),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(55, 'USD'), testutil.Amount(95, 'EUR')}
 

	
 
def test_balance_at_single_currency_cost():
 
    txn = testutil.Transaction(postings=[
 
        ('Expenses:Other', '22'),
 
        ('Expenses:Other', '30', 'EUR', ('1.1',)),
 
        ('Expenses:Other', '40', 'GBP', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(99)}
 

	
 
def test_balance_at_cost_zeroed_out():
 
    txn = testutil.Transaction(postings=[
 
        ('Income:Other', '-22'),
 
        ('Assets:Receivable:Accounts', '20', 'EUR', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    assert balance.is_zero()
 

	
 
def test_balance_at_cost_singleton():
 
    txn = testutil.Transaction(postings=[
 
        ('Assets:Receivable:Accounts', '20', 'EUR', ('1.1',)),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(22)}
 

	
 
def test_balance_at_cost_singleton_without_cost():
 
    txn = testutil.Transaction(postings=[
 
        ('Assets:Receivable:Accounts', '20'),
 
    ])
 
    related = core.RelatedPostings(data.Posting.from_txn(txn))
 
    balance = related.balance_at_cost()
 
    amounts = set(balance.values())
 
    assert amounts == {testutil.Amount(20)}
 

	
 
def test_balance_at_cost_empty():
 
    related = core.RelatedPostings()
 
    balance = related.balance_at_cost()
 
    assert balance.is_zero()
 

	
 
def test_meta_values_empty():
 
    related = core.RelatedPostings()
 
    assert related.meta_values('key') == set()
 

	
 
def test_meta_values_no_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, metakey='metavalue'),
 
    ])
 
    assert related.meta_values('key') == {None}
 

	
 
def test_meta_values_no_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, metakey='metavalue'),
 
    ])
 
    assert related.meta_values('key', '') == {''}
 

	
 
def test_meta_values_one_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='metavalue'),
 
    ])
 
    assert related.meta_values('key') == {'metavalue'}
 

	
 
def test_meta_values_some_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, metakey='2'),
 
    ])
 
    assert related.meta_values('key') == {'1', None}
 

	
 
def test_meta_values_some_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, metakey='2'),
 
    ])
 
    assert related.meta_values('key', '') == {'1', ''}
 

	
 
def test_meta_values_all_match():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='2'),
 
    ])
 
    assert related.meta_values('key') == {'1', '2'}
 

	
 
def test_meta_values_all_match_one_value():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='1'),
 
    ])
 
    assert related.meta_values('key') == {'1'}
 

	
 
def test_meta_values_all_match_default_given():
 
    related = core.RelatedPostings([
 
        testutil.Posting('Income:Donations', -1, key='1'),
 
        testutil.Posting('Income:Donations', -2, key='2'),
 
    ])
 
    assert related.meta_values('key', '') == {'1', '2'}
 

	
 
def test_meta_values_many_types():
 
    expected = {
 
        datetime.date(2020, 4, 1),
 
        Decimal(42),
 
        testutil.Amount(5),
 
        'rt:42',
 
    }
 
    related = core.RelatedPostings(
 
        testutil.Posting('Income:Donations', -index, key=value)
 
        for index, value in enumerate(expected)
 
    )
 
    assert related.meta_values('key') == expected
 

	
 
@pytest.mark.parametrize('count', range(3))
 
def test_all_meta_links_zero(count):
 
    postings = (
 
        testutil.Posting('Income:Donations', -n, testkey=str(n))
 
        for n in range(count)
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('approval') == set()
 
    assert next(related.all_meta_links('approval'), None) is None
 

	
 
def test_all_meta_links_singletons():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, statement=value)
 
        for value in itertools.chain(
 
            testutil.NON_LINK_METADATA_STRINGS,
 
            testutil.LINK_METADATA_STRINGS,
 
            testutil.NON_STRING_METADATA_VALUES,
 
        ))
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('statement') == testutil.LINK_METADATA_STRINGS
 
    assert set(related.all_meta_links('statement')) == testutil.LINK_METADATA_STRINGS
 

	
 
def test_all_meta_links_multiples():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, approval=' '.join(value))
 
        for value in itertools.permutations(testutil.LINK_METADATA_STRINGS, 2)
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert related.all_meta_links('approval') == testutil.LINK_METADATA_STRINGS
 
    assert set(related.all_meta_links('approval')) == testutil.LINK_METADATA_STRINGS
 

	
 
def test_all_meta_links_preserves_order():
 
    postings = (
 
        testutil.Posting('Income:Donations', -10, approval=c)
 
        for c in '121323'
 
    )
 
    related = core.RelatedPostings(
 
        post._replace(meta=data.Metadata(post.meta))
 
        for post in postings
 
    )
 
    assert list(related.all_meta_links('approval')) == list('123')
 

	
 
def test_group_by_meta_zero():
 
    assert not list(core.RelatedPostings.group_by_meta([], 'metacurrency'))
 

	
 
def test_group_by_meta_one(credit_card_cycle):
 
    posting = next(post for post in data.Posting.from_entries(credit_card_cycle)
 
                   if post.account.is_credit_card())
 
    actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency')
 
    assert set(key for key, _ in actual) == {'USD'}
 

	
 
def test_group_by_meta_many(two_accruals_three_payments):
 
    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
 
                if post.account == 'Assets:Receivable:Accounts']
 
    actual = dict(core.RelatedPostings.group_by_meta(postings, 'metacurrency'))
 
    assert set(actual) == {'USD', 'EUR'}
 
    for key, group in actual.items():
 
        assert 2 <= len(group) <= 3
 
        assert group.balance().is_zero()
 

	
 
def test_group_by_meta_many_single_posts(two_accruals_three_payments):
 
    postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
 
                if post.account == 'Assets:Receivable:Accounts']
 
    actual = dict(core.RelatedPostings.group_by_meta(postings, 'metanumber'))
 
    assert set(actual) == {post.units.number for post in postings}
 
    assert len(actual) == len(postings)
0 comments (0 inline, 0 general)