Changeset - 6fa12789661e
Brett Smith - 2021-03-09 15:39:12
brettcsmith@brettcsmith.org
query: Improve formatting of ODS output.

* Provide dedicated formatting for more Beancount types.
* Improve code to determine when we're looking up link metadata
and should format output as links.
3 files changed with 194 insertions and 76 deletions:
conservancy_beancount/reports/core.py
...
 
@@ -347,1270 +347,1270 @@ class Balances:
 
            classification = self._get_classification(post)
 
            post_meta = post.meta.get(post_meta_key)
 
            key = BalanceKey(post.account, classification, period, fund, post_meta)
 
            self.balances[key] += post.at_cost()
 

	
 
    def _get_classification(self, post: data.Posting) -> data.Account:
 
        try:
 
            return self._get_meta_account(post.account.meta, 'classification')
 
        except (KeyError, TypeError):
 
            return post.account
 

	
 
    def _get_meta_account(self, meta: Mapping[MetaKey, MetaValue], key: MetaKey) -> data.Account:
 
        value = meta[key]
 
        if isinstance(value, str):
 
            return data.Account(value)
 
        else:
 
            raise TypeError(f"{key!r} is not a string but a {type(value).__name__}")
 

	
 
    def total(self,
 
              account: Union[None, str, Collection[str]]=None,
 
              classification: Optional[str]=None,
 
              period: int=Period.ANY,
 
              fund: int=Fund.ANY,
 
              post_meta: Optional[str]=None,
 
              *,
 
              account_exact: bool=False,
 
    ) -> Balance:
 
        """Return the balance of postings that match given criteria
 

	
 
        Given ``account`` and/or ``classification`` criteria, returns the total
 
        balance of postings *under* that account and/or classification. If you
 
        pass ``account_exact=True``, the postings must have exactly the
 
        ``account`` you specify instead.
 

	
 
        Given ``period``, ``fund``, or ``post_meta`` criteria, limits to
 
        reporting the balance of postings that match that reporting period,
 
        fund type, or metadata value, respectively.
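
        For example, given a ``Balances`` instance ``balances``, to total
        postings under a hypothetical ``Expenses`` account during the
        reporting period::

          period_expenses = balances.total('Expenses', period=Period.PERIOD)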
 
        """
 
        if isinstance(account, str):
 
            account = (account,)
 
        acct_pred: Callable[[data.Account], bool]
 
        if account is None:
 
            acct_pred = lambda acct: True
 
        elif account_exact:
 
            # At this point, between this isinstance() above and the earlier
 
            # `account is None` check, we've collapsed the type of `account` to
 
            # `Collection[str]`. Unfortunately the logic is too involved for
 
            # mypy to follow, so ignore the type problem.
 
            acct_pred = lambda acct: acct in account  # type:ignore[operator]
 
        else:
 
            acct_pred = lambda acct: acct.is_under(*account) is not None  # type:ignore[misc]
 
        retval = MutableBalance()
 
        for key, balance in self.balances.items():
 
            if not acct_pred(key.account):
 
                pass
 
            elif not (classification is None
 
                      or key.classification.is_under(classification)):
 
                pass
 
            elif not period & key.period:
 
                pass
 
            elif not fund & key.fund:
 
                pass
 
            elif not (post_meta is None or post_meta == key.post_meta):
 
                pass
 
            else:
 
                retval += balance
 
        return retval
 

	
 
    def classifications(self,
 
                        account: str,
 
                        sort_period: Optional[int]=None,
 
    ) -> Sequence[data.Account]:
 
        """Return a sequence of seen account classifications
 

	
 
        Given an account name, returns a sequence of all the account
 
        classifications seen in the postings under that part of the account
 
        hierarchy. The classifications are sorted in descending order by the
 
        balance of postings under them for the ``sort_period`` time period.
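
        For example, with a hypothetical ``Balances`` instance ``balances``::

          for classification in balances.classifications('Expenses'):
              ...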
 
        """
 
        if sort_period is None:
 
            if account in data.EQUITY_ACCOUNTS:
 
                sort_period = Period.PERIOD
 
            else:
 
                sort_period = Period.ANY
 
        class_bals: Mapping[data.Account, MutableBalance] \
 
            = collections.defaultdict(MutableBalance)
 
        for key, balance in self.balances.items():
 
            if not key.account.is_under(account):
 
                pass
 
            elif key.period & sort_period:
 
                class_bals[key.classification] += balance
 
            else:
 
                # Ensure the balance exists in the mapping
 
                class_bals[key.classification]
 
        norm_func = normalize_amount_func(f'{account}:RootsOK')
 
        def sortkey(acct: data.Account) -> Sortable:
 
            prefix, _, _ = acct.rpartition(':')
 
            balance = norm_func(class_bals[acct])
 
            try:
 
                max_bal = max(amount.number for amount in balance.values())
 
            except ValueError:
 
                max_bal = Decimal(0)
 
            return prefix, -max_bal
 
        return sorted(class_bals, key=sortkey)
 

	
 
    def iter_accounts(self, root: Optional[str]=None) -> Sequence[data.Account]:
 
        """Return a sequence of accounts open during the reporting period
 

	
 
        The sequence is sorted by account name.
 
        """
 
        start_date = self.period_range.start
 
        stop_date = self.period_range.stop
 
        return sorted(
 
            account
 
            for account in data.Account.iter_accounts(root)
 
            if account.meta.open_date < stop_date
 
            and (account.meta.close_date is None
 
                 or account.meta.close_date > start_date)
 
        )
 

	
 
    def meta_values(self) -> Set[str]:
 
        retval = {key.post_meta for key in self.balances}
 
        retval.discard(None)
 
        # discarding None ensures we return the desired type.
 
        return retval  # type:ignore[return-value]
 

	
 

	
 
class RelatedPostings(Sequence[data.Posting]):
 
    """Collect and query related postings
 

	
 
    This class provides common functionality for collecting related postings
 
    and running queries on them: iterating over them, tallying their balance,
 
    etc.
 

	
 
    This class doesn't know anything about how the postings are related. That's
 
    entirely up to the caller.
 

	
 
    A common pattern is to use this class with collections.defaultdict
 
    to organize postings based on some key. See the group_by_meta classmethod
 
    for an example.
 
    """
 
    __slots__ = ('_postings',)
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        self._postings: Sequence[data.Posting]
 
        if _can_own and isinstance(source, Sequence):
 
            self._postings = source
 
        else:
 
            self._postings = list(source)
 

	
 
    @classmethod
 
    def _group_by(cls: Type[RelatedType],
 
                  postings: Iterable[data.Posting],
 
                  key: Callable[[data.Posting], T],
 
    ) -> Iterator[Tuple[T, RelatedType]]:
 
        mapping: Dict[T, List[data.Posting]] = collections.defaultdict(list)
 
        for post in postings:
 
            mapping[key(post)].append(post)
 
        for value, posts in mapping.items():
 
            yield value, cls(posts, _can_own=True)
 

	
 
    @classmethod
 
    def group_by_account(cls: Type[RelatedType],
 
                         postings: Iterable[data.Posting],
 
    ) -> Iterator[Tuple[data.Account, RelatedType]]:
 
        return cls._group_by(postings, operator.attrgetter('account'))
 

	
 
    @classmethod
 
    def group_by_meta(cls: Type[RelatedType],
 
                      postings: Iterable[data.Posting],
 
                      key: MetaKey,
 
                      default: Optional[MetaValue]=None,
 
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
 
        """Relate postings by metadata value
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of post.meta.get(key, default).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same metadata value.
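
        For example, to organize a hypothetical iterable of ``postings`` by a
        ``project`` metadata key and total each group::

          for project, posts in RelatedPostings.group_by_meta(postings, 'project'):
              print(project, posts.balance())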
 
        """
 
        def key_func(post: data.Posting) -> Optional[MetaValue]:
 
            return post.meta.get(key, default)
 
        return cls._group_by(postings, key_func)
 

	
 
    @classmethod
 
    def group_by_first_meta_link(
 
            cls: Type[RelatedType],
 
            postings: Iterable[data.Posting],
 
            key: MetaKey,
 
    ) -> Iterator[Tuple[Optional[str], RelatedType]]:
 
        """Relate postings by the first link in metadata
 

	
 
        This method takes an iterable of postings and returns a mapping.
 
        The keys of the mapping are the values of
 
        post.meta.first_link(key, None).
 
        The values are RelatedPostings instances that contain all the postings
 
        that had that same first metadata link.
 
        """
 
        def key_func(post: data.Posting) -> Optional[MetaValue]:
 
            return post.meta.first_link(key, None)
 
        return cls._group_by(postings, key_func)
 

	
 
    def __repr__(self) -> str:
 
        return f'<{type(self).__name__} {self._postings!r}>'
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...
 

	
 
    @overload
 
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...
 

	
 
    def __getitem__(self: RelatedType,
 
                    index: Union[int, slice],
 
    ) -> Union[data.Posting, RelatedType]:
 
        if isinstance(index, slice):
 
            return type(self)(self._postings[index], _can_own=True)
 
        else:
 
            return self._postings[index]
 

	
 
    def __len__(self) -> int:
 
        return len(self._postings)
 

	
 
    def all_meta_links(self, key: MetaKey) -> Iterator[str]:
 
        return filters.iter_unique(
 
            link for post in self for link in post.meta.report_links(key)
 
        )
 

	
 
    @overload
 
    def first_meta_links(self, key: MetaKey, default: str='') -> Iterator[str]: ...
 

	
 
    @overload
 
    def first_meta_links(self, key: MetaKey, default: None) -> Iterator[Optional[str]]: ...
 

	
 
    def first_meta_links(self,
 
                         key: MetaKey,
 
                         default: Optional[str]='',
 
    ) -> Iterator[Optional[str]]:
 
        retval = filters.iter_unique(
 
            post.meta.first_link(key, default) for post in self
 
        )
 
        if default == '':
 
            retval = (s for s in retval if s)
 
        return retval
 

	
 
    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
 
        balance = MutableBalance()
 
        for post in self:
 
            balance += post.units
 
            yield post, balance
 

	
 
    def balance(self) -> Balance:
 
        return Balance(post.units for post in self)
 

	
 
    def balance_at_cost(self) -> Balance:
 
        return Balance(post.at_cost() for post in self)
 

	
 
    def balance_at_cost_by_date(self, date: datetime.date) -> Balance:
 
        for index, post in enumerate(self):
 
            if post.meta.date >= date:
 
                break
 
        else:
 
            index += 1
 
        return Balance(post.at_cost() for post in self._postings[:index])
 

	
 
    def meta_values(self,
 
                    key: MetaKey,
 
                    default: Optional[MetaValue]=None,
 
    ) -> Set[Optional[MetaValue]]:
 
        return {post.meta.get(key, default) for post in self}
 

	
 

	
 
class PeriodPostings(RelatedPostings):
 
    """Postings filtered and balanced over a date range
 

	
 
    Create a subclass with ``PeriodPostings.with_start_date(date)``.
 
    Note that there is no explicit stop date. The expectation is that the
 
    caller has already filtered out posts past the stop date from the input.
 

	
 
    Instances of that subclass will have three Balance attributes:
 

	
 
    * ``start_bal`` is the balance at cost of postings before your start date
 
    * ``period_bal`` is the balance at cost of postings on or after your start date
 
    * ``stop_bal`` is the balance at cost of all postings
 

	
 
    Use this subclass when your report includes a lot of balances over time to
 
    help you get the math right.
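
    For example, with hypothetical ``postings`` and a hypothetical start date::

      AprilPostings = PeriodPostings.with_start_date(datetime.date(2021, 4, 1))
      for account, posts in AprilPostings.group_by_account(postings):
          print(account, posts.start_bal, posts.period_bal, posts.stop_bal)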
 
    """
 
    __slots__ = (
 
        'begin_bal',
 
        'end_bal',
 
        'period_bal',
 
        'start_bal',
 
        'stop_bal',
 
    )
 
    START_DATE = datetime.date(datetime.MINYEAR, 1, 1)
 

	
 
    def __init__(self,
 
                 source: Iterable[data.Posting]=(),
 
                 *,
 
                 _can_own: bool=False,
 
    ) -> None:
 
        start_posts: List[data.Posting] = []
 
        period_posts: List[data.Posting] = []
 
        for post in source:
 
            if post.meta.date < self.START_DATE:
 
                start_posts.append(post)
 
            else:
 
                period_posts.append(post)
 
        super().__init__(period_posts, _can_own=True)
 
        self.start_bal = RelatedPostings(start_posts, _can_own=True).balance_at_cost()
 
        self.period_bal = self.balance_at_cost()
 
        self.stop_bal = self.start_bal + self.period_bal
 
        # Convenience aliases
 
        self.begin_bal = self.start_bal
 
        self.end_bal = self.stop_bal
 

	
 
    @classmethod
 
    def with_start_date(cls: Type[RelatedType], start_date: datetime.date) -> Type[RelatedType]:
 
        name = f'BalancePostings{start_date.strftime("%Y%m%d")}'
 
        return type(name, (cls,), {'START_DATE': start_date})
 

	
 

	
 
class BaseSpreadsheet(Generic[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to arbitrary output. It calls hooks when it starts writing the
 
    spreadsheet, starts a new "section" of rows, ends a section, and ends the
 
    spreadsheet.
 

	
 
    RT is the type of the input data rows. ST is the type of the section
 
    identifier that you create from each row. If you don't want to use the
 
    section logic at all, set ST to None and define section_key to return None.
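
    A minimal subclass, sketched here with a hypothetical ``PostingSpreadsheet``
    that sections postings by account, might look like::

      class PostingSpreadsheet(BaseSpreadsheet[data.Posting, data.Account]):
          def section_key(self, row: data.Posting) -> data.Account:
              return row.account

          def write_row(self, row: data.Posting) -> None:
              ...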
 
    """
 

	
 
    @abc.abstractmethod
 
    def section_key(self, row: RT) -> ST:
 
        """Return the section a row belongs to
 

	
 
        Given a data row, this method should return some identifier for the
 
        "section" the row belongs to. The write method uses this to
 
        determine when to call start_section and end_section.
 

	
 
        If your spreadsheet doesn't need sections, define this to return None.
 
        """
 
        ...
 

	
 
    @abc.abstractmethod
 
    def write_row(self, row: RT) -> None:
 
        """Write a data row to the output spreadsheet
 

	
 
        This method is called once for each data row in the input.
 
        """
 
        ...
 

	
 
    # The next four methods are all called by the write method at the points
 
    # their names indicate. You may override them to output headers or sums,
 
    # record state, etc. The default implementations are all no-ops.
 

	
 
    def start_spreadsheet(self) -> None:
 
        pass
 

	
 
    def start_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_section(self, key: ST) -> None:
 
        pass
 

	
 
    def end_spreadsheet(self) -> None:
 
        pass
 

	
 
    def write(self, rows: Iterable[RT]) -> None:
 
        prev_section: Optional[ST] = None
 
        self.start_spreadsheet()
 
        for row in rows:
 
            section = self.section_key(row)
 
            if section != prev_section:
 
                if prev_section is not None:
 
                    self.end_section(prev_section)
 
                self.start_section(section)
 
                prev_section = section
 
            self.write_row(row)
 
        try:
 
            should_end = section is not None
 
        except NameError:
 
            should_end = False
 
        if should_end:
 
            self.end_section(section)
 
        self.end_spreadsheet()
 

	
 

	
 
class Border(enum.IntFlag):
 
    TOP = 1
 
    RIGHT = 2
 
    BOTTOM = 4
 
    LEFT = 8
 
    # in CSS order, clockwise from top
 

	
 

	
 
class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta):
 
    """Abstract base class to help write OpenDocument spreadsheets
 

	
 
    This class provides the very core logic to write an arbitrary set of data
 
    rows to an OpenDocument spreadsheet. It provides helper methods for
 
    building sheets, rows, and cells.
 

	
 
    See also the BaseSpreadsheet base class for additional documentation about
 
    methods you must and can define, the definition of RT and ST, etc.
 
    """
 
    # Defined in the XSL spec, "Definitions of Units of Measure"
 
    MEASUREMENT_UNITS = frozenset([
 
        'cm',
 
        'em',
 
        'in',
 
        'mm',
 
        'pc',
 
        'pt',
 
        'px',
 
    ])
 
    MEASUREMENT_RE = re.compile(
 
        r'([-+]?(?:\d+\.?|\.\d+|\d+\.\d+))({})'.format('|'.join(MEASUREMENT_UNITS)),
 
        re.ASCII,
 
    )
 

	
 
    def __init__(self, rt_wrapper: Optional[rtutil.RT]=None) -> None:
 
        self.rt_wrapper = rt_wrapper
 
        self.locale = babel.core.Locale.default('LC_MONETARY')
 
        self.currency_fmt_key = 'accounting'
 
        self._name_counter = itertools.count(1)
 
        self._style_cache: MutableMapping[str, odf.style.Style] = {}
 
        self.document = odf.opendocument.OpenDocumentSpreadsheet()
 
        self.init_settings()
 
        self.init_styles()
 
        self.set_properties()
 
        self.sheet = self.use_sheet("Report")
 

	
 
    ### Low-level document tree manipulation
 
    # The *intent* is that you only need to use these if you're adding new
 
    # methods to manipulate document settings or styles.
 

	
 
    def copy_element(self, elem: odf.element.Element) -> odf.element.Element:
 
        retval = odf.element.Element(
 
            qname=elem.qname,
 
            qattributes=copy.copy(elem.attributes),
 
        )
 
        try:
 
            orig_name = retval.getAttribute('name')
 
        except ValueError:
 
            orig_name = None
 
        if orig_name is not None:
 
            retval.setAttribute('name', f'{orig_name}{next(self._name_counter)}')
 
        for child in elem.childNodes:
 
            # Order is important: need to check the deepest subclasses first.
 
            if isinstance(child, odf.element.CDATASection):
 
                retval.addCDATA(child.data)
 
            elif isinstance(child, odf.element.Text):
 
                retval.addText(child.data)
 
            else:
 
                retval.addElement(self.copy_element(child))
 
        return retval
 

	
 
    def ensure_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        if found_child is None:
 
            parent.addElement(new_child)
 
            return parent.lastChild
 
        else:
 
            return found_child
 

	
 
    def ensure_config_map_entry(self,
 
                                root: odf.element.Element,
 
                                map_name: str,
 
                                entry_name: str,
 
    ) -> odf.element.Element:
 
        """Return a ``ConfigItemMapEntry`` under ``root``
 

	
 
        This method ensures there's a ``ConfigItemMapNamed`` named ``map_name``
 
        under ``root``, and a ``ConfigItemMapEntry`` named ``entry_name`` under
 
        that. Return the ``ConfigItemMapEntry`` element.
 
        """
 
        config_map = self.ensure_child(root, odf.config.ConfigItemMapNamed, name=map_name)
 
        return self.ensure_child(config_map, odf.config.ConfigItemMapEntry, name=entry_name)
 

	
 
    def find_child(self,
 
                   parent: odf.element.Element,
 
                   child: odf.element.Element,
 
    ) -> Optional[odf.element.Element]:
 
        attrs = {k: v for k, v in self.iter_attributes(child)}
 
        if not attrs:
 
            return None
 
        for elem in parent.childNodes:
 
            if (elem.qname == child.qname
 
                and all(elem.getAttribute(k) == v for k, v in attrs.items())):
 
                return elem
 
        return None
 

	
 
    def iter_attributes(self, elem: odf.element.Element) -> Iterator[Tuple[str, str]]:
 
        for (_, key), value in self.iter_qattributes(elem):
 
            yield key.lower().replace('-', ''), value
 

	
 
    def iter_qattributes(self, elem: odf.element.Element) -> Iterator[Tuple[Tuple[str, str], str]]:
 
        if elem.attributes:
 
            yield from elem.attributes.items()
 

	
 
    def replace_child(self,
 
                     parent: odf.element.Element,
 
                     child_type: ElementType,
 
                     **kwargs: Any,
 
    ) -> odf.element.Element:
 
        new_child = child_type(**kwargs)
 
        found_child = self.find_child(parent, new_child)
 
        parent.insertBefore(new_child, found_child)
 
        if found_child is not None:
 
            parent.removeChild(found_child)
 
        return new_child
 

	
 
    def set_config(self,
 
                   root: odf.element.Element,
 
                   name: str,
 
                   value: Union[bool, int, str],
 
                   config_type: Optional[str]=None,
 
    ) -> None:
 
        """Ensure ``root`` has a ``ConfigItem`` with the given name, type, and value"""
 
        value_s = str(value)
 
        if isinstance(value, bool):
 
            value_s = str(value).lower()
 
            default_type = 'boolean'
 
        elif isinstance(value, str):
 
            default_type = 'string'
 
        if config_type is None:
 
            try:
 
                config_type = default_type
 
            except NameError:
 
                raise ValueError(
 
                    f"need config_type for {type(value).__name__} value",
 
                ) from None
 
        item = self.replace_child(
 
            root, odf.config.ConfigItem, name=name, type=config_type,
 
        )
 
        item.addText(value_s)
 

	
 
    ### Styles
 

	
 
    def bgcolor_style(self, color: str) -> odf.style.Style:
 
        key = f'BGColor{color.lstrip("#")}'
 
        try:
 
            retval = self._style_cache[key]
 
        except KeyError:
 
            props = odf.style.TableCellProperties(backgroundcolor=color)
 
            retval = odf.style.Style(name=key, family='table-cell')
 
            retval.addElement(props)
 
            self.document.styles.addElement(retval)
 
            self._style_cache[key] = retval
 
        return retval
 

	
 
    def border_style(self,
 
                     edges: int,
 
                     width: str='1px',
 
                     style: str='solid',
 
                     color: str='#000000',
 
    ) -> odf.style.Style:
 
        flags = [edge for edge in Border if edges & edge]
 
        if not flags:
 
            raise ValueError(f"no valid edges in {edges!r}")
 
        border_attr = f'{width} {style} {color}'
 
        key = f'{",".join(f.name for f in flags)} {border_attr}'
 
        try:
 
            retval = self._style_cache[key]
 
        except KeyError:
 
            props = odf.style.TableCellProperties()
 
            for flag in flags:
 
                props.setAttribute(f'border{flag.name.lower()}', border_attr)
 
            retval = odf.style.Style(
 
                name=f'Border{next(self._name_counter)}',
 
                family='table-cell',
 
            )
 
            retval.addElement(props)
 
            self.document.styles.addElement(retval)
 
            self._style_cache[key] = retval
 
        return retval
 

	
 
    def column_style(self, width: Union[float, str], **attrs: Any) -> odf.style.Style:
 
        if not isinstance(width, str) or (width and not width[-1].isalpha()):
 
            width = f'{width}in'
 
        match = self.MEASUREMENT_RE.fullmatch(width)
 
        if match is None:
 
            raise ValueError(f"invalid width {width!r}")
 
        width_float = float(match.group(1))
 
        if width_float <= 0:
 
            # Per the OpenDocument spec, column-width is a positiveLength.
 
            raise ValueError(f"width {width!r} must be positive")
 
        width = '{:.3g}{}'.format(width_float, match.group(2))
 
        retval = self.ensure_child(
 
            self.document.automaticstyles,
 
            odf.style.Style,
 
            name=f'col_{width.replace(".", "_")}'
 
        )
 
        retval.setAttribute('family', 'table-column')
 
        if retval.firstChild is None:
 
            retval.addElement(odf.style.TableColumnProperties(
 
                columnwidth=width, **attrs
 
            ))
 
        return retval
 

	
 
    def _build_currency_style(
 
            self,
 
            root: odf.element.Element,
 
            locale: babel.core.Locale,
 
            code: str,
 
            amount: DecimalCompat=0,
 
            properties: Optional[odf.style.TextProperties]=None,
 
            *,
 
            fmt_key: Optional[str]=None,
 
            volatile: bool=False,
 
    ) -> odf.element.Element:
 
        if fmt_key is None:
 
            fmt_key = self.currency_fmt_key
 
        pattern = locale.currency_formats[fmt_key]
 
        fmt = get_commodity_format(locale, code, amount, fmt_key)
 
        style = self.replace_child(
 
            root,
 
            odf.number.CurrencyStyle,
 
            name=f'{code}{next(self._name_counter)}',
 
        )
 
        style.setAttribute('volatile', 'true' if volatile else 'false')
 
        if properties is not None:
 
            style.addElement(properties)
 
        for part in re.split(r"(¤+|[#0,.]+|'[^']+')", fmt):
 
            if not part:
 
                pass
 
            elif not part.strip('#0,.'):
 
                style.addElement(odf.number.Number(
 
                    decimalplaces=str(pattern.frac_prec[0]),
 
                    grouping='true' if pattern.grouping[0] else 'false',
 
                    minintegerdigits=str(pattern.int_prec[0]),
 
                ))
 
            elif part == '¤':
 
                style.addElement(odf.number.CurrencySymbol(
 
                    country=locale.territory,
 
                    language=locale.language,
 
                    text=babel.numbers.get_currency_symbol(code, locale),
 
                ))
 
            elif part == '¤¤':
 
                style.addElement(odf.number.Text(text=code))
 
            else:
 
                style.addElement(odf.number.Text(text=part.strip("'")))
 
        return style
 

	
 
    def currency_style(
 
            self,
 
            code: str,
 
            locale: Optional[babel.core.Locale]=None,
 
            negative_properties: Optional[odf.style.TextProperties]=None,
 
            positive_properties: Optional[odf.style.TextProperties]=None,
 
            root: odf.element.Element=None,
 
    ) -> odf.style.Style:
 
        """Create and return a spreadsheet style to format currency data
 

	
 
        Given a currency code and a locale, this method will create all the
 
        styles necessary to format the currency according to the locale's
 
        rules, including rendering of decimal points and negative values.
 

	
 
        You may optionally pass in TextProperties to use for negative and
 
        positive amounts, respectively. If you don't, negative values will
 
        automatically be rendered in red (text color #f00).
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
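
        For example, to render negative amounts of a hypothetical USD column
        in bold red::

          negative_bold = odf.style.TextProperties(color='#ff0000', fontweight='bold')
          usd_style = self.currency_style('USD', negative_properties=negative_bold)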
 
        """
 
        if locale is None:
 
            locale = self.locale
 
        if negative_properties is None:
 
            negative_properties = odf.style.TextProperties(color='#ff0000')
 
        if root is None:
 
            root = self.document.styles
 
        cache_parts = [str(id(root)), code, str(locale)]
 
        for key, value in self.iter_attributes(negative_properties):
 
            cache_parts.append(f'{key}={value}')
 
        if positive_properties is not None:
 
            cache_parts.append('')
 
            for key, value in self.iter_attributes(positive_properties):
 
                cache_parts.append(f'{key}={value}')
 
        cache_key = '\0'.join(cache_parts)
 
        try:
 
            style = self._style_cache[cache_key]
 
        except KeyError:
 
            pos_style = self._build_currency_style(
 
                root, locale, code, 0, positive_properties, volatile=True,
 
            )
 
            curr_style = self._build_currency_style(
 
                root, locale, code, -1, negative_properties,
 
            )
 
            curr_style.addElement(odf.style.Map(
 
                condition='value()>=0', applystylename=pos_style,
 
            ))
 
            style = self.ensure_child(
 
                self.document.styles,
 
                odf.style.Style,
 
                name=f'{curr_style.getAttribute("name")}Cell',
 
                family='table-cell',
 
                datastylename=curr_style,
 
            )
 
            self._style_cache[cache_key] = style
 
        return style
 

	
 
    def _merge_style_iter_names(
 
            self,
 
            styles: Sequence[Union[str, odf.style.Style, None]],
 
    ) -> Iterator[str]:
 
        for source in styles:
 
            if source is None:
 
                continue
 
            elif not isinstance(source, str):
 
                source = source.getAttribute('name')
 
            if source.startswith('Merge_'):
 
                orig_names = iter(source.split('_'))
 
                next(orig_names)
 
                yield from orig_names
 
            else:
 
                yield source
 

	
 
    def _merge_styles(self,
 
                      new_style: odf.style.Style,
 
                      sources: Iterable[odf.style.Style],
 
    ) -> None:
 
        for elem in sources:
 
            for key, new_value in self.iter_attributes(elem):
 
                old_value = new_style.getAttribute(key)
 
                if (key == 'name'
 
                    or key == 'displayname'
 
                    or old_value == new_value):
 
                    pass
 
                elif old_value is None:
 
                    new_style.setAttribute(key, new_value)
 
                else:
 
                    raise ValueError(f"cannot merge styles with conflicting {key}")
 
            for child in elem.childNodes:
 
                new_style.addElement(self.copy_element(child))
 

	
 
    def merge_styles(self,
 
                     *styles: Union[str, odf.style.Style, None],
 
    ) -> Optional[odf.style.Style]:
 
        """Create a new style from multiple existing styles
 

	
 
        Given any number of existing styles, create a new style that combines
 
        all of those styles' attributes and properties, add it to the document
 
        styles, and return it.
 

	
 
        Styles can be specified by name, or by passing in their Style element.
 
        For convenience, you can also pass in None as an argument; None will
 
        simply be skipped.
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
 

	
 
        If you pass in zero real style arguments, returns None.
 
        If you pass in one style argument, returns that style unchanged.
 
        If you pass in a style that doesn't already exist in the document,
 
        or if you pass in styles that can't be merged (because they have
 
        conflicting attributes), raises ValueError.
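
        For example, to get a cell style that is both bold and end-aligned,
        using the building block styles created in ``init_styles``::

          style = self.merge_styles(self.style_bold, self.style_endtext)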
 
        """
 
        name_map: Dict[str, odf.style.Style] = {}
 
        for name in self._merge_style_iter_names(styles):
 
            source = odf.style.Style(name=name)
 
            found = self.find_child(self.document.styles, source)
 
            if found is None:
 
                raise ValueError(f"no style named {name!r}")
 
            name_map[name] = found
 
        if not name_map:
 
            retval = None
 
        elif len(name_map) == 1:
 
            _, retval = name_map.popitem()
 
        else:
 
            new_name = f'Merge_{"_".join(sorted(name_map))}'
 
            retval = self.ensure_child(
 
                self.document.styles, odf.style.Style, name=new_name,
 
            )
 
            if retval.firstChild is None:
 
                self._merge_styles(retval, name_map.values())
 
        return retval
 

	
 
    ### Sheets
 

	
 
    def lock_first_column(self, sheet: Optional[odf.table.Table]=None) -> None:
 
        """Lock the first column of cells under the given sheet
 

	
 
        This method sets all the appropriate settings to "lock" the first column
 
        of cells in a sheet, so it stays in view even as the viewer scrolls
 
        across the sheet. If a sheet is not given, works on ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        config_map = self.ensure_config_map_entry(
 
            self.view, 'Tables', sheet.getAttribute('name'),
 
        )
 
        self.set_config(config_map, 'PositionRight', 1, 'int')
 
        self.set_config(config_map, 'HorizontalSplitMode', 2, 'short')
 
        self.set_config(config_map, 'HorizontalSplitPosition', 1, 'short')
 

	
 
    def lock_first_row(self, sheet: Optional[odf.table.Table]=None) -> None:
 
        """Lock the first row of cells under the given sheet
 

	
 
        This method sets all the appropriate settings to "lock" the first row
 
        of cells in a sheet, so it stays in view even as the viewer scrolls
 
        through rows. If a sheet is not given, works on ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        config_map = self.ensure_config_map_entry(
 
            self.view, 'Tables', sheet.getAttribute('name'),
 
        )
 
        self.set_config(config_map, 'PositionBottom', 1, 'int')
 
        self.set_config(config_map, 'VerticalSplitMode', 2, 'short')
 
        self.set_config(config_map, 'VerticalSplitPosition', 1, 'short')
 

	
 
    def set_open_sheet(self, sheet: Union[str, odf.table.Table, None]=None) -> None:
 
        """Set which sheet is open in the document
 

	
 
        When the user first opens the spreadsheet, their view will be on this
 
        sheet. You can provide a sheet name string or sheet object. With no
 
        argument, defaults to ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        if not isinstance(sheet, str):
 
            sheet = sheet.getAttribute('name')
 
            if not isinstance(sheet, str):
 
                raise ValueError("sheet argument has no name for setting")
 
        self.set_config(self.view, 'ActiveTable', sheet, 'string')
 

	
 
    def use_sheet(self, name: str) -> odf.table.Table:
 
        """Switch the active sheet ``self.sheet`` to the one with the given name
 

	
 
        If there is no sheet with the given name, create it and append it to
 
        the spreadsheet first.
 

	
 
        If the current active sheet is empty when this method is called, it
 
        will be removed from the spreadsheet.
 
        """
 
        try:
 
            empty_sheet = not self.sheet.hasChildNodes()
 
        except AttributeError:
 
            empty_sheet = False
 
        if empty_sheet:
 
            self.document.spreadsheet.removeChild(self.sheet)
 
        self.sheet = self.ensure_child(
 
            self.document.spreadsheet, odf.table.Table, name=name,
 
        )
 
        return self.sheet
 

	
 
    ### Initialization hooks
 

	
 
    def init_settings(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.settings``. This implementation creates the barest
 
        skeleton structure necessary to support other methods, in particular
 
        ``lock_first_row``.
 
        """
 
        view_settings = self.ensure_child(
 
            self.document.settings, odf.config.ConfigItemSet, name='ooo:view-settings',
 
        )
 
        views = self.ensure_child(
 
            view_settings, odf.config.ConfigItemMapIndexed, name='Views',
 
        )
 
        self.view = self.ensure_child(views, odf.config.ConfigItemMapEntry)
 
        self.set_config(self.view, 'ViewId', 'view1')
 

	
 
    def init_styles(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.styles``. This implementation creates basic building
 
        block cell styles often used in financial reports.
 
        """
 
        styles = self.document.styles
 
        self.style_bold = self.ensure_child(
 
            styles, odf.style.Style, name='Bold', family='table-cell',
 
        )
 
        self.ensure_child(
 
            self.style_bold, odf.style.TextProperties, fontweight='bold',
 
        )
 

	
 
        date_style = self.replace_child(styles, odf.number.DateStyle, name='ISODate')
 
        date_style.addElement(odf.number.Year(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Month(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Day(style='long'))
 
        self.style_date = self.ensure_child(
 
            styles,
 
            odf.style.Style,
 
            name=f'{date_style.getAttribute("name")}Cell',
 
            family='table-cell',
 
            datastylename=date_style,
 
        )
 

	
 
        self.style_starttext: odf.style.Style
 
        self.style_centertext: odf.style.Style
 
        self.style_endtext: odf.style.Style
 
        for textalign in ['start', 'center', 'end']:
 
            aligned_style = self.replace_child(
 
                styles, odf.style.Style, name=f'{textalign.title()}Text',
 
            )
 
            aligned_style.setAttribute('family', 'table-cell')
 
            aligned_style.addElement(odf.style.ParagraphProperties(textalign=textalign))
 
            setattr(self, f'style_{textalign}text', aligned_style)
 

	
 
        self.style_total = self.border_style(Border.TOP, '1pt')
 
        self.style_endtotal = self.border_style(Border.TOP | Border.BOTTOM, '1pt')
 
        self.style_bottomline = self.merge_styles(
 
            self.style_total,
 
            self.border_style(Border.BOTTOM, '2pt', 'double'),
 
        )
 

	
 
    ### Properties
 

	
 
    def set_common_properties(self,
 
                              repo: Optional[git.Repo]=None,
 
                              command: Optional[Sequence[str]]=sys.argv,
 
    ) -> None:
 
        if repo is None:
 
            git_shahex = '<none>'
 
            git_dirty = True
 
        else:
 
            git_shahex = repo.head.commit.hexsha
 
            git_dirty = repo.is_dirty()
 
        self.set_custom_property('GitSHA', git_shahex)
 
        self.set_custom_property('GitDirty', git_dirty, 'boolean')
 
        if command is not None:
 
            command_s = ' '.join(shlex.quote(s) for s in command)
 
            self.set_custom_property('ReportCommand', command_s)
 

	
 
    def set_custom_property(self,
 
                            name: str,
 
                            value: Any,
 
                            valuetype: Optional[str]=None,
 
    ) -> odf.meta.UserDefined:
 
        if valuetype is None:
 
            if isinstance(value, bool):
 
                valuetype = 'boolean'
 
            elif isinstance(value, (datetime.date, datetime.datetime)):
 
                valuetype = 'date'
 
            elif isinstance(value, (int, float, Decimal)):
 
                valuetype = 'float'
 
        if not isinstance(value, str):
 
            if valuetype == 'boolean':
 
                value = 'true' if value else 'false'
 
            elif valuetype == 'date':
 
                value = value.isoformat()
 
            else:
 
                value = str(value)
 
        retval = self.ensure_child(self.document.meta, odf.meta.UserDefined, name=name)
 
        if valuetype is None:
 
            try:
 
                retval.removeAttribute('valuetype')
 
            except KeyError:
 
                pass
 
        else:
 
            retval.setAttribute('valuetype', valuetype)
 
        retval.childNodes.clear()
 
        retval.addText(value)
 
        return retval
 

	
 
    def set_properties(self, *,
 
                       created: Optional[datetime.datetime]=None,
 
                       generator: str='conservancy_beancount',
 
    ) -> None:
 
        if created is None:
 
            created = datetime.datetime.now()
 
        created_elem = self.ensure_child(self.document.meta, odf.meta.CreationDate)
 
        created_elem.childNodes.clear()
 
        created_elem.addText(created.isoformat())
 
        generator_elem = self.ensure_child(self.document.meta, odf.meta.Generator)
 
        generator_elem.childNodes.clear()
 
        generator_elem.addText(f'{generator}/{cliutil.VERSION} {TOOLSVERSION}')
 

	
 
    ### Rows and cells
 

	
 
    def add_row(self, *cells: odf.table.TableCell, **attrs: Any) -> odf.table.TableRow:
 
        row = odf.table.TableRow(**attrs)
 
        for cell in cells:
 
            row.addElement(cell)
 
        self.sheet.addElement(row)
 
        return row
 

	
 
    def row_count(self, sheet: Optional[odf.table.Table]=None) -> int:
 
        if sheet is None:
 
            sheet = self.sheet
 
        TableRow = odf.table.TableRow
 
        return sum(1 for cell in sheet.childNodes if cell.isInstanceOf(TableRow))
 

	
 
    def balance_cell(self, balance: Balance, **attrs: Any) -> odf.table.TableCell:
 
        balance = balance.clean_copy() or balance
 
        balance_currency_count = len(balance)
 
        if balance_currency_count == 0:
 
            return self.float_cell(0, **attrs)
 
        elif balance_currency_count == 1:
 
            amount = next(iter(balance.values()))
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.currency_style(amount.currency),
 
            )
 
            return self.currency_cell(amount, **attrs)
 
        else:
 
            lines = [babel.numbers.format_currency(number, currency, get_commodity_format(
 
                self.locale, currency, None, self.currency_fmt_key,
 
            )) for number, currency in balance.values()]
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.style_endtext,
 
            )
 
            return self.multiline_cell(lines, **attrs)
 

	
 
    def currency_cell(self, amount: bc_amount._Amount, **attrs: Any) -> odf.table.TableCell:
 
        if 'stylename' not in attrs:
 
            attrs['stylename'] = self.currency_style(amount.currency)
 
        number, currency = amount
 
        cell = odf.table.TableCell(valuetype='currency', value=number, **attrs)
 
        cell.addElement(odf.text.P(text=babel.numbers.format_currency(
 
            number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
        )))
 
        return cell
 

	
 
    def date_cell(self, date: datetime.date, **attrs: Any) -> odf.table.TableCell:
 
        attrs.setdefault('stylename', self.style_date)
 
        cell = odf.table.TableCell(valuetype='date', datevalue=date, **attrs)
 
        cell.addElement(odf.text.P(text=date.isoformat()))
 
        return cell
 

	
 
    def float_cell(self, value: Union[int, float, Decimal], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='float', value=value, **attrs)
 
        cell.addElement(odf.text.P(text=str(value)))
 
        return cell
 

	
 
    def _meta_link_pairs(self, links: Iterable[Optional[str]]) -> Iterator[Tuple[str, str]]:
 
        for href in links:
 
            if href is None:
 
                continue
 
            elif self.rt_wrapper is not None:
 
                rt_ids = self.rt_wrapper.parse(href)
 
                rt_href = rt_ids and self.rt_wrapper.url(*rt_ids)
 
            else:
 
                rt_ids = None
 
                rt_href = None
 
            if rt_ids is None or rt_href is None:
 
                # '..' pops the ODS filename off the link path. In other words,
 
                # make the link relative to the directory the ODS is in.
 
                href_path = Path('..', href)
 
                href = urlparse.quote(str(href_path))
 
                text = href_path.name
 
            else:
 
                rt_path = urlparse.urlparse(rt_href).path
 
                if rt_path.endswith('/Ticket/Display.html'):
 
                    text = rtutil.RT.unparse(*rt_ids)
 
                else:
 
                    text = urlparse.unquote(Path(rt_path).name)
 
                href = rt_href
 
            yield (href, text)
 

	
 
    def meta_links_cell(self, links: Iterable[Optional[str]], **attrs: Any) -> odf.table.TableCell:
 
        return self.multilink_cell(self._meta_link_pairs(links), **attrs)
 

	
 
    def multiline_cell(self, lines: Iterable[Any], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for line in lines:
 
            cell.addElement(odf.text.P(text=str(line)))
 
        return cell
 

	
 
    def multilink_cell(self, links: Iterable[LinkType], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for link in links:
 
            if isinstance(link, tuple):
 
                href, text = link
 
            else:
 
                href = link
 
                text = None
 
            cell.addElement(odf.text.P())
 
            cell.lastChild.addElement(odf.text.A(
 
                type='simple', href=href, text=text or href,
 
            ))
 
        return cell
 

	
 
    def string_cell(self, text: str, **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        cell.addElement(odf.text.P(text=text))
 
        return cell
 

	
 
    def write_row(self, row: RT) -> None:
 
        """Write a single row of input data to the spreadsheet
 

	
 
        This default implementation adds a single row to the spreadsheet,
 
        with one cell per element of the row. The type of each element
 
        determines what kind of cell is created.
 

	
 
        This implementation will help get you started, but you'll probably
 
        want to override it to specify styles.
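
        For example, an override for a hypothetical row type whose last
        element is a date might look like::

          def write_row(self, row: RT) -> None:
              *rest, date = row
              cells = [self.string_cell(str(value)) for value in rest]
              cells.append(self.date_cell(date))
              self.add_row(*cells)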
 
        """
 
        out_row = odf.table.TableRow()
 
        for cell_source in row:
 
            if isinstance(cell_source, (int, float, Decimal)):
 
                cell = self.float_cell(cell_source)
 
            else:
 
                cell = self.string_cell(cell_source)
 
            out_row.addElement(cell)
 
        self.sheet.addElement(out_row)
 

	
 
    def save_file(self, out_file: BinaryIO) -> None:
 
        self.document.write(out_file)
 

	
 
    def save_path(self, path: Path, mode: str='w') -> None:
 
        with path.open(f'{mode}b') as out_file:
 
            out_file = cast(BinaryIO, out_file)
 
            self.save_file(out_file)
 

	
 

	
 
def account_balances(
 
        groups: Mapping[data.Account, PeriodPostings],
 
        order: Optional[Sequence[str]]=None,
 
) -> Iterator[Tuple[str, Balance]]:
 
    """Iterate account balances over a date range
 

	
 
    1. ``subclass = PeriodPostings.with_start_date(start_date)``
 
    2. ``groups = dict(subclass.group_by_account(postings))``
 
    3. ``for acct, bal in account_balances(groups, [optional ordering]): ...``
 

	
 
    This function returns an iterator of 2-tuples ``(account, balance)``
 
    that you can use to generate a report in the style of ``ledger balance``.
 
    The accounts are accounts in ``groups`` that appeared under one of the
 
    account name strings in ``order``. ``balance`` is the corresponding
 
    balance over the time period (``groups[key].period_bal``). Accounts are
 
    iterated in the order provided by ``sort_and_filter_accounts()``.
 

	
 
    The first 2-tuple is ``(OPENING_BALANCE_NAME, balance)`` with the balance of
 
    all these accounts as of ``start_date``.
 
    The final 2-tuple is ``(ENDING_BALANCE_NAME, balance)`` with the final
 
    balance of all these accounts after all postings (the end of the reporting period).
 
    The iterator will always yield these special 2-tuples, even when there are
 
    no accounts in the input or to report.
 
    """
 
    if order is None:
 
        order = ['Equity', 'Income', 'Expenses']
 
    acct_seq = [account for _, account in sort_and_filter_accounts(groups, order)]
 
    yield (OPENING_BALANCE_NAME, sum(
 
        (groups[key].start_bal for key in acct_seq),
 
        MutableBalance(),
 
    ))
 
    for key in acct_seq:
 
        postings = groups[key]
 
        try:
 
            in_date_range = postings[-1].meta.date >= postings.START_DATE
 
        except IndexError:
 
            in_date_range = False
 
        if in_date_range:
 
            yield (key, groups[key].period_bal)
 
    yield (ENDING_BALANCE_NAME, sum(
 
        (groups[key].stop_bal for key in acct_seq),
 
        MutableBalance(),
 
    ))
 

	
 
def get_commodity_format(locale: babel.core.Locale,
 
                         code: str,
 
                         amount: Optional[DecimalCompat]=None,
 
                         format_type: str='accounting',
 
) -> str:
 
    """Return a format string for a commodity
 

	
 
    Typical use looks like::
 

	
 
      number, code = post.units
 
      fmt = get_commodity_format(locale, code)
 
      units_s = babel.numbers.format_currency(number, code, fmt)
 

	
 
    When the commodity code refers to a real currency, you get the same format
 
    string provided by Babel.
 

	
 
    For other commodities like stock, you get a format string built from the
 
    locale's currency unit pattern.
 

	
 
    If ``amount`` is defined, the format string will be specifically for that
 
    number, whether positive or negative. Otherwise, the format string may
 
    define both positive and negative formats.
 
    """
 
    fmt: str = locale.currency_formats[format_type].pattern
 
    if amount is not None:
 
        fmt, _, neg_fmt = fmt.partition(';')
 
        if amount < 0 and neg_fmt:
 
            fmt = neg_fmt
 
    symbol = babel.numbers.get_currency_symbol(code, locale)
 
    if symbol != code:
 
        return fmt
 
    else:
 
        long_fmt: str = babel.numbers.get_currency_unit_pattern(code, locale=locale)
 
        return re.sub(
 
            r'[#0,.\s¤]+',
 
            lambda match: long_fmt.format(
 
                match.group(0).replace('¤', '').strip(), '¤¤',
 
            ),
 
            fmt,
 
        )
 

	
 
def normalize_amount_func(account_name: str) -> Callable[[T], T]:
 
    """Get a function to normalize amounts for reporting
 

	
 
    Given an account name, return a function that can be used on "amounts"
 
    under that account (including numbers, Amount objects, and Balance objects)
 
    to normalize them for reporting. Right now that means flipping the
 
    sign for accounts where "normal" postings are negative.
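
    For example, with a hypothetical ``Income:Donations`` account and a
    ``balance`` under it::

      norm_func = normalize_amount_func('Income:Donations')
      norm_balance = norm_func(balance)  # sign flipped for reporting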
 
    """
 
    if account_name.startswith(('Assets:', 'Expenses:')):
 
        # We can't just return operator.pos because Beancount's Amount class
 
        # doesn't implement __pos__.
 
        return lambda amt: amt
 
    elif account_name.startswith(('Equity:', 'Income:', 'Liabilities:')):
 
        return operator.neg
 
    else:
 
        raise ValueError(f"unrecognized account name {account_name!r}")
 

	
 
def sort_and_filter_accounts(
 
        accounts: Iterable[data.Account],
 
        order: Sequence[str],
 
) -> Iterator[Tuple[int, data.Account]]:
 
    """Reorganize accounts based on an ordered set of names
 

	
 
    This function takes an iterable of Account objects and a sequence of
 
    account names. Usually the account names are higher parts of the account
 
    hierarchy like Income, Equity, or Assets:Receivable.
 

	
 
    It returns an iterator of 2-tuples, ``(index, account)`` where ``index`` is
 
    an index into the ordering sequence, and ``account`` is one of the input
 
    Account objects that's under the account name ``order[index]``. Tuples are
 
    sorted, so ``index`` increases monotonically, and Account objects using the
 
    same index are yielded sorted by name.
 

	
 
    For example, if your order is
 
    ``['Liabilities:Payable', 'Assets:Receivable']``, the return value will
 
    first yield zero or more results with index 0 and an account under
 
    Liabilities:Payable, then zero or more results with index 1 and an account
 
    under Assets:Receivable.
 

	
 
    Input Accounts that are not under any of the account names in ``order`` do
 
    not appear in the output iterator. That's the filtering part.
 

	
 
    Note that if none of the input Accounts are under one of the ordering
 
    sequence accounts, its index will never appear in the results. This is why
 
    the 2-tuples include an index rather than the original account name string,
 
    to make it easier for callers to know when this happens and do something
 
    with unused ordering accounts.
 
    """
 
    index_map = {s: ii for ii, s in enumerate(order)}
 
    retval: Mapping[int, List[data.Account]] = collections.defaultdict(list)
 
    for account in accounts:
 
        acct_key = account.is_under(*order)
 
        if acct_key is not None:
 
            retval[index_map[acct_key]].append(account)
 
    return (
 
        (key, account)
 
        for key in sorted(retval)
 
        for account in sorted(retval[key])
 
    )
conservancy_beancount/reports/query.py
 
"""query.py - Report arbitrary queries with advanced loading and formatting"""
 
# Copyright © 2021  Brett Smith
 
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
 
#
 
# Full copyright and licensing details can be found at toplevel file
 
# LICENSE.txt in the repository.
 

	
 
import argparse
 
import contextlib
 
import datetime
 
import enum
 
import itertools
 
import logging
 
import re
 
import sys
 

	
 
from typing import (
 
    cast,
 
    AbstractSet,
 
    Any,
 
    Callable,
 
    Dict,
 
    Iterable,
 
    Iterator,
 
    List,
 
    Mapping,
 
    NamedTuple,
 
    Optional,
 
    Sequence,
 
    TextIO,
 
    Tuple,
 
    Type,
 
    Union,
 
)
 
from ..beancount_types import (
 
    MetaKey,
 
    MetaValue,
 
    Posting,
 
    Transaction,
 
)
 

	
 
from decimal import Decimal
 
from pathlib import Path
 
from beancount.core.amount import _Amount as BeancountAmount
 
from beancount.core.inventory import Inventory
 
from beancount.core.position import _Position as Position
 

	
 
import beancount.query.numberify as bc_query_numberify
 
import beancount.query.query_compile as bc_query_compile
 
import beancount.query.query_env as bc_query_env
 
import beancount.query.query_execute as bc_query_execute
 
import beancount.query.query_parser as bc_query_parser
 
import beancount.query.query_render as bc_query_render
 
import beancount.query.shell as bc_query_shell
 
import odf.table  # type:ignore[import]
 

	
 
from . import core
 
from . import rewrite
 
from .. import books
 
from .. import cliutil
 
from .. import config as configmod
 
from .. import data
 
from .. import rtutil
 

	
 
BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
 
    bc_query_env.TargetsEnvironment.columns,  # type:ignore[has-type]
 
    bc_query_env.TargetsEnvironment.functions,  # type:ignore[has-type]
 
))
 
PROGNAME = 'query-report'
 
logger = logging.getLogger('conservancy_beancount.reports.query')
 

	
 
CellFunc = Callable[[Any], odf.table.TableCell]
 
EnvironmentFunctions = Dict[
 
    # The real key type is something like:
 
    #   Union[str, Tuple[str, Type, ...]]
 
    # but there are two issues with that. One, you can't use Ellipsis in a Tuple like
 
    # that, so there's no short way to declare this. Second, Beancount doesn't
 
    # declare it anyway, and mypy infers it as Sequence[object]. So just use
 
    # that.
 
    Sequence[object],
 
    Type[bc_query_compile.EvalFunction],
 
]
 
RowTypes = Sequence[Tuple[str, Type]]
 
Rows = Sequence[NamedTuple]
 
Store = List[Any]
 
QueryExpression = Union[
 
    bc_query_parser.Column,
 
    bc_query_parser.Constant,
 
    bc_query_parser.Function,
 
    bc_query_parser.UnaryOp,
 
]
 
QueryStatement = Union[
 
    bc_query_parser.Balances,
 
    bc_query_parser.Journal,
 
    bc_query_parser.Select,
 
]
 

	
 
class BooksLoader:
 
    """Closure to load books with a zero-argument callable
 

	
 
    This matches the load interface that BQLShell expects.
 
    """
 
    def __init__(
 
            self,
 
            books_loader: Optional[books.Loader],
 
            start_date: Optional[datetime.date]=None,
 
            stop_date: Optional[datetime.date]=None,
 
            rewrite_rules: Sequence[rewrite.RewriteRuleset]=(),
 
    ) -> None:
 
        self.books_loader = books_loader
 
        self.start_date = start_date
 
        self.stop_date = stop_date
 
        self.rewrite_rules = rewrite_rules
 

	
 
    def __call__(self) -> books.LoadResult:
 
        logger.debug("BooksLoader called")
 
        result = books.Loader.dispatch(self.books_loader, self.start_date, self.stop_date)
 
        logger.debug("books loaded from Beancount")
 
        if self.rewrite_rules:
 
            for index, entry in enumerate(result.entries):
 
                # entry might not be a Transaction; we catch that later.
 
                # The type ignores are because the underlying Beancount type isn't
 
                # type-checkable.
 
                postings = data.Posting.from_txn(entry)  # type:ignore[arg-type]
 
                for ruleset in self.rewrite_rules:
 
                    postings = ruleset.rewrite(postings)
 
                try:
 
                    result.entries[index] = entry._replace(postings=list(postings))  # type:ignore[call-arg]
 
                except AttributeError:
 
                    pass
 
            logger.debug("rewrite rules applied")
 
        return result
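# Usage sketch (mirrors main() at the bottom of this module): a BooksLoader
# instance is passed to BQLShell as its ``loadfun``, so the books are loaded,
# date-filtered, and rewritten only when the shell reloads (e.g. on_Reload()).
#
#   load_func = BooksLoader(
#       config.books_loader(),
#       start_date=None,   # or a datetime.date to limit the fiscal years loaded
#       stop_date=None,
#       rewrite_rules=[],  # e.g. [rewrite.RewriteRuleset.from_yaml(path)]
#   )
#   result = load_func()   # books.LoadResult with .entries and .errors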
 

	
 

	
 
class QueryODS(core.BaseODS[NamedTuple, None]):
 
    META_FNAMES = frozenset([
 
        'any_meta',
 
        'entry_meta',
 
        'meta',
 
        'meta_docs',
 
        'str_meta',
 
    ])
 

	
 
    def is_empty(self) -> bool:
 
        return not self.sheet.childNodes
 

	
 
    def section_key(self, row: NamedTuple) -> None:
 
        return None
 

	
 
    def _generic_cell(self, value: Any) -> odf.table.TableCell:
 
        if isinstance(value, Iterable) and not isinstance(value, (str, tuple)):
 
            return self.multiline_cell(value)
 
        else:
 
            return self.string_cell('' if value is None else str(value))
 

	
 
    def _inventory_cell(self, value: Inventory) -> odf.table.TableCell:
 
        return self.balance_cell(core.Balance(pos.units for pos in value))
 

	
 
    def _link_string_cell(self, value: str) -> odf.table.TableCell:
 
        return self.meta_links_cell(value.split())
 

	
 
    def _metadata_cell(self, value: MetaValue) -> odf.table.TableCell:
 
        return self._cell_type(type(value))(value)
 

	
 
    def _position_cell(self, value: Position) -> odf.table.TableCell:
 
        return self.currency_cell(value.units)
 

	
 
    def _cell_type(self, row_type: Type) -> CellFunc:
 
        if issubclass(row_type, BeancountAmount):
 
        """Return a function to create a cell, for non-metadata row types."""
 
        if issubclass(row_type, Inventory):
 
            return self._inventory_cell
 
        elif issubclass(row_type, Position):
 
            return self._position_cell
 
        elif issubclass(row_type, BeancountAmount):
 
            return self.currency_cell
 
        elif issubclass(row_type, (int, float, Decimal)):
 
            return self.float_cell
 
        elif issubclass(row_type, datetime.date):
 
            return self.date_cell
 
        elif issubclass(row_type, str):
 
            return self.string_cell
 
        else:
 
            return self._generic_cell
 

	
 
    def _generic_cell(self, value: Any) -> odf.table.TableCell:
 
        return self.string_cell('' if value is None else str(value))
 

	
 
    def _link_cell(self, value: MetaValue) -> odf.table.TableCell:
 
        if isinstance(value, str):
 
            return self.meta_links_cell(value.split())
 
    def _link_cell_type(self, row_type: Type) -> CellFunc:
 
        """Return a function to create a cell from metadata with documentation links."""
 
        if issubclass(row_type, str):
 
            return self._link_string_cell
 
        elif issubclass(row_type, tuple):
 
            return self._generic_cell
 
        elif issubclass(row_type, Iterable):
 
            return self.meta_links_cell
 
        else:
 
            return self._generic_cell(value)
 

	
 
    def _metadata_cell(self, value: MetaValue) -> odf.table.TableCell:
 
        return self._cell_type(type(value))(value)
 
            return self._generic_cell
 

	
 
    def _cell_types(self, row_types: RowTypes) -> Iterator[CellFunc]:
 
        for name, row_type in row_types:
 
            if row_type is object:
 
                if name.replace('_', '-') in data.LINK_METADATA:
 
                    yield self._link_cell
 
                else:
 
                    yield self._metadata_cell
 
            else:
 
    def _meta_target(self, target: QueryExpression) -> Optional[MetaKey]:
 
        """Return the metadata key looked up by this target, if any
 

	
 
        This function takes a parsed target (i.e., what we're SELECTing) and
 
        recurses it to see whether it's looking up any metadata. If so, it
 
        returns the key of that metadata. Otherwise it returns None.
 
        """
 
        if isinstance(target, bc_query_parser.UnaryOp):
 
            return self._meta_target(target.operand)
 
        elif not isinstance(target, bc_query_parser.Function):
 
            return None
 
        try:
 
            operand = target.operands[0]
 
        except IndexError:
 
            return None
 
        if (target.fname in self.META_FNAMES
 
            and isinstance(operand, bc_query_parser.Constant)):
 
            return operand.value  # type:ignore[no-any-return]
 
        else:
 
            for operand in target.operands:
 
                retval = self._meta_target(operand)
 
                if retval is not None:
 
                    break
 
            return retval
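    # Informal example: a target parsed from ``ANY_META("rt-id")`` comes back
    # as roughly
    #   Function(fname='any_meta', operands=[Constant(value='rt-id')])
    # so this method returns 'rt-id'. A plain column target such as ``date``
    # is neither a UnaryOp nor a Function, so it returns None; a unary
    # operator wrapping the call recurses into its operand first.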
 

	
 
    def _cell_types(self, statement: QueryStatement, row_types: RowTypes) -> Iterator[CellFunc]:
 
        """Return functions to create table cells from result rows
 

	
 
        Given a parsed query and the types of return rows, yields a function
 
        to create a cell for each column in the row, in order. The returned
 
        functions vary in order to provide the best available formatting for
 
        different data types.
 
        """
 
        if (isinstance(statement, bc_query_parser.Select)
 
            and isinstance(statement.targets, Sequence)):
 
            targets = [t.expression for t in statement.targets]
 
        else:
 
            # Synthesize something that makes clear we're not loading metadata.
 
            targets = [bc_query_parser.Column(name) for name, _ in row_types]
 
        for target, (_, row_type) in zip(targets, row_types):
 
            meta_key = self._meta_target(target)
 
            if meta_key is None:
 
                yield self._cell_type(row_type)
 
            elif meta_key in data.LINK_METADATA:
 
                yield self._link_cell_type(row_type)
 
            else:
 
                yield self._metadata_cell
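    # Minimal sketch of the dispatch above: for
    #   SELECT date, ANY_META("rt-id") AS ticket, UNITS(position)
    # this yields date_cell for the first column, a cell function chosen by
    # _link_cell_type() for the second (ANY_META looks up "rt-id", which is in
    # data.LINK_METADATA), and currency_cell for the Amount from UNITS().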
 

	
 
    def write_query(self, row_types: RowTypes, rows: Rows) -> None:
 
    def write_query(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
 
        if self.is_empty():
 
            self.sheet.setAttribute('name', "Query 1")
 
        else:
 
            self.use_sheet(f"Query {len(self.document.spreadsheet.childNodes) + 1}")
 
        for name, row_type in row_types:
 
            if row_type is object or issubclass(row_type, str):
 
                col_width = 2.0
 
            elif issubclass(row_type, BeancountAmount):
 
            if issubclass(row_type, datetime.date):
 
                col_width = 1.0
 
            elif issubclass(row_type, (BeancountAmount, Inventory, Position)):
 
                col_width = 1.5
 
            else:
 
                col_width = 1.0
 
                col_width = 2.0
 
            col_style = self.column_style(col_width)
 
            self.sheet.addElement(odf.table.TableColumn(stylename=col_style))
 
        self.add_row(*(
 
            self.string_cell(data.Metadata.human_name(name.replace('_', '-')),
 
                             stylename=self.style_bold)
 
            self.string_cell(data.Metadata.human_name(name), stylename=self.style_bold)
 
            for name, _ in row_types
 
        ))
 
        self.lock_first_row()
 
        cell_funcs = list(self._cell_types(row_types))
 
        cell_funcs = list(self._cell_types(statement, row_types))
 
        for row in rows:
 
            self.add_row(*(
 
                cell_func(value)
 
                for cell_func, value in zip(cell_funcs, row)
 
            ))
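    # Usage sketch (the tests below exercise this in full): each call adds one
    # sheet named "Query 1", "Query 2", ... with a bold, locked header row and
    # one row of formatted cells per result:
    #
    #   ods = QueryODS(rt_wrapper)   # the RT wrapper argument is optional
    #   ods.write_query(statement, row_types, rows)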
 

	
 

	
 
# This class mostly supports type checking. Beancount code dynamically sets the
 
# ``store`` attribute, in bc_query_execute.execute_query().
 
class Context(bc_query_execute.RowContext):
 
    store: Store
 

	
 

	
 
class MetaDocs(bc_query_env.AnyMeta):
 
    """Return a list of document links from metadata."""
 
    def __init__(self, operands: List[str]) -> None:
 
        super(bc_query_env.AnyMeta, self).__init__(operands, list)
 

	
 
    def __call__(self, context: Context) -> List[str]:
 
        raw_value = super().__call__(context)
 
        if isinstance(raw_value, str):
 
            return raw_value.split()
 
        else:
 
            return []
 

	
 

	
 
class StrMeta(bc_query_env.AnyMeta):
 
    """Looks up metadata like AnyMeta, then always returns a string."""
 
    def __init__(self, operands: List[str]) -> None:
 
        super(bc_query_env.AnyMeta, self).__init__(operands, str)
 

	
 
    def __call__(self, context: Context) -> str:
 
        raw_value = super().__call__(context)
 
        if raw_value is None:
 
            return ''
 
        else:
 
            return str(raw_value)
 

	
 

	
 
class AggregateSet(bc_query_compile.EvalAggregator):
 
    __intypes__ = [object]
 

	
 
    def __init__(self, operands: List[str]) -> None:
 
        super().__init__(operands, set)
 

	
 
    def allocate(self, allocator: bc_query_execute.Allocator) -> None:
 
        self.handle = allocator.allocate()
 

	
 
    def initialize(self, store: Store) -> None:
 
        store[self.handle] = self.dtype()
 

	
 
    def update(self, store: Store, context: Context) -> None:
 
        value, = self.eval_args(context)
 
        if isinstance(value, Sequence) and not isinstance(value, str):
 
        if isinstance(value, Sequence) and not isinstance(value, (str, tuple)):
 
            store[self.handle].update(value)
 
        else:
 
            store[self.handle].add(value)
 

	
 
    def __call__(self, context: Context) -> set:
 
        return context.store[self.handle]  # type:ignore[no-any-return]
 

	
 

	
 
class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment):
 
    functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy()  # type:ignore[assignment]
 
    functions['meta_docs'] = MetaDocs
 
    functions['str_meta'] = StrMeta
 

	
 

	
 
class TargetsEnvironment(bc_query_env.TargetsEnvironment):
 
    functions = FilterPostingsEnvironment.functions.copy()
 
    functions.update(bc_query_env.AGGREGATOR_FUNCTIONS)
 
    functions['set'] = AggregateSet
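# Example queries exercising the custom functions registered above (the same
# patterns appear in tests/test_reports_query.py):
#
#   SELECT STR_META("entity") AS entity, META_DOCS("rt-id") AS docs
#   SELECT account, SET(narration), SUM(UNITS(position)) GROUP BY account
#
# META_DOCS splits the metadata value into a list of link strings, STR_META
# always returns a string ('' when the metadata is unset), and SET aggregates
# a column's values into a Python set.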
 

	
 

	
 
class BQLShell(bc_query_shell.BQLShell):
 
    def __init__(
 
            self,
 
            is_interactive: bool,
 
            loadfun: Callable[[], books.LoadResult],
 
            outfile: TextIO,
 
            default_format: str='text',
 
            do_numberify: bool=False,
 
            rt_wrapper: Optional[rtutil.RT]=None,
 
    ) -> None:
 
        super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify)
 
        self.env_postings = FilterPostingsEnvironment()
 
        self.env_targets = TargetsEnvironment()
 
        self.ods = QueryODS(rt_wrapper)
 

	
 
    def on_Select(self, statement: QueryStatement) -> None:
 
        output_format: str = self.vars['format']
 
        try:
 
            render_func = getattr(self, f'_render_{output_format}')
 
        except AttributeError:
 
            logger.error("unknown output format %r", output_format)
 
            return
 

	
 
        try:
 
            logger.debug("compiling query")
 
            compiled_query = bc_query_compile.compile(
 
                statement, self.env_targets, self.env_postings, self.env_entries,
 
            )
 
            logger.debug("executing query")
 
            row_types, rows = bc_query_execute.execute_query(
 
                compiled_query, self.entries, self.options_map,
 
            )
 
            if self.vars['numberify']:
 
                logger.debug("numberifying query")
 
                row_types, rows = bc_query_numberify.numberify_results(
 
                    row_types, rows, self.options_map['dcontext'].build(),
 
                )
 
        except Exception as error:
 
            logger.error(str(error), exc_info=logger.isEnabledFor(logging.DEBUG))
 
            return
 

	
 
        if not rows and output_format != 'ods':
 
            print("(empty)", file=self.outfile)
 
        else:
 
            logger.debug("rendering query as %s", output_format)
 
            render_func(row_types, rows)
 
            render_func(statement, row_types, rows)
 

	
 
    def _render_csv(self, row_types: RowTypes, rows: Rows) -> None:
 
    def _render_csv(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
 
        bc_query_render.render_csv(
 
            row_types,
 
            rows,
 
            self.options_map['dcontext'],
 
            self.outfile,
 
            self.vars['expand'],
 
        )
 

	
 
    def _render_ods(self, row_types: RowTypes, rows: Rows) -> None:
 
        self.ods.write_query(row_types, rows)
 
        logger.info("results saved in sheet %s", self.ods.sheet.getAttribute('name'))
 
    def _render_ods(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
 
        self.ods.write_query(statement, row_types, rows)
 
        logger.info(
 
            "%s rows of results saved in sheet %s",
 
            len(rows),
 
            self.ods.sheet.getAttribute('name'),
 
        )
 

	
 
    def _render_text(self, row_types: RowTypes, rows: Rows) -> None:
 
    def _render_text(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
 
        with contextlib.ExitStack() as stack:
 
            if self.is_interactive:
 
                output = stack.enter_context(self.get_pager())
 
            else:
 
                output = self.outfile
 
            bc_query_render.render_text(
 
                row_types,
 
                rows,
 
                self.options_map['dcontext'],
 
                output,
 
                self.vars['expand'],
 
                self.vars['boxed'],
 
                self.vars['spaced'],
 
            )
 

	
 

	
 
class ReportFormat(enum.Enum):
 
    TEXT = 'text'
 
    TXT = TEXT
 
    CSV = 'csv'
 
    ODS = 'ods'
 

	
 

	
 
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 
    parser = argparse.ArgumentParser(prog=PROGNAME)
 
    cliutil.add_version_argument(parser)
 
    cliutil.add_loglevel_argument(parser)
 
    parser.add_argument(
 
        '--begin', '--start', '-b',
 
        dest='start_date',
 
        metavar='YEAR',
 
        type=cliutil.year_or_date_arg,
 
        help="""Begin loading entries from this fiscal year. You can specify a
 
full date, and %(prog)s will use the fiscal year for that date.
 
""")
 
    parser.add_argument(
 
        '--end', '--stop', '-e',
 
        dest='stop_date',
 
        metavar='YEAR',
 
        type=cliutil.year_or_date_arg,
 
        help="""End loading entries at this fiscal year. You can specify a
 
full date, and %(prog)s will use the fiscal year for that date.
 
""")
 
    cliutil.add_rewrite_rules_argument(parser)
 
    format_arg = cliutil.EnumArgument(ReportFormat)
 
    parser.add_argument(
 
        '--report-type', '--format', '-t', '-f',
 
        metavar='TYPE',
 
        type=format_arg.enum_type,
 
        help="""Format of report to generate. Choices are
 
{format_arg.choices_str()}. Default is guessed from your output filename
 
extension. If that fails, default is 'text' for interactive output, and 'ods'
 
otherwise.
 
""")
 
    parser.add_argument(
 
        '--numberify', '-m',
 
        action='store_true',
 
        help="""Separate currency amounts into numeric columns by currency
 
""")
 
    parser.add_argument(
 
        '--output-file', '-O', '-o',
 
        metavar='PATH',
 
        type=Path,
 
        help="""Write the report to this file, or stdout when PATH is `-`.
 
The default is stdout for text and CSV reports, and a generated filename for
 
ODS reports.
 
""")
 
    parser.add_argument(
 
        'query',
 
        nargs=argparse.ZERO_OR_MORE,
 
        default=[],
 
        help="""Query to run non-interactively. If none is provided, and
 
standard input is not a terminal, reads the query from stdin instead.
 
""")
 

	
 
    args = parser.parse_args(arglist)
 
    return args
 
    return parser.parse_args(arglist)
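# Example invocations (the query may be given as one or several shell words):
#
#   query-report -f txt 'SELECT date, narration, account, position'
#   query-report -O report.ods 'SELECT account, SUM(UNITS(position)) GROUP BY account'
#
# The second call omits --report-type, so main() below guesses ODS from the
# ``.ods`` output filename extension.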
 

	
 
def main(arglist: Optional[Sequence[str]]=None,
 
         stdout: TextIO=sys.stdout,
 
         stderr: TextIO=sys.stderr,
 
         config: Optional[configmod.Config]=None,
 
) -> int:
 
    args = parse_arguments(arglist)
 
    cliutil.set_loglevel(logger, args.loglevel)
 
    if config is None:
 
        config = configmod.Config()
 
        config.load_file()
 

	
 
    query = ' '.join(args.query).strip()
 
    if not query and not sys.stdin.isatty():
 
        query = sys.stdin.read().strip()
 
    if args.report_type is None:
 
        try:
 
            args.report_type = ReportFormat[args.output_file.suffix[1:].upper()]
 
        except (AttributeError, KeyError):
 
            args.report_type = ReportFormat.ODS if query else ReportFormat.TEXT
 

	
 
    load_func = BooksLoader(
 
        config.books_loader(),
 
        args.start_date,
 
        args.stop_date,
 
        [rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules],
 
    )
 
    shell = BQLShell(
 
        not query,
 
        load_func,
 
        stdout,
 
        args.report_type.value,
 
        args.numberify,
 
        config.rt_wrapper(),
 
    )
 
    shell.on_Reload()
 
    if query:
 
        shell.onecmd(query)
 
    else:
 
        shell.cmdloop()
 

	
 
    if not shell.ods.is_empty():
 
        shell.ods.set_common_properties(config.books_repo())
 
        shell.ods.set_custom_property('BeanQuery', query or '<interactive>')
 
        if args.output_file is None:
 
            out_dir_path = config.repository_path() or Path()
 
            args.output_file = out_dir_path / 'QueryResults_{}.ods'.format(
 
                datetime.datetime.now().isoformat(timespec='seconds'),
 
            )
 
            logger.info("Writing spreadsheet to %s", args.output_file)
 
        ods_file = cliutil.bytes_output(args.output_file, stdout)
 
        shell.ods.save_file(ods_file)
 

	
 
    return cliutil.ExitCode.OK
 

	
 
entry_point = cliutil.make_entry_point(__name__, PROGNAME)
 

	
 
if __name__ == '__main__':
 
    exit(entry_point())
tests/test_reports_query.py
Show inline comments
 
"""test_reports_query.py - Unit tests for query report"""
 
# Copyright © 2021  Brett Smith
 
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
 
#
 
# Full copyright and licensing details can be found at toplevel file
 
# LICENSE.txt in the repository.
 

	
 
import argparse
 
import collections
 
import copy
 
import csv
 
import datetime
 
import io
 
import itertools
 
import re
 

	
 
import odf.table
 
import odf.text
 
import pytest
 

	
 
from . import testutil
 

	
 
from beancount.core import data as bc_data
 
from beancount.query import query_parser as bc_query_parser
 
from conservancy_beancount.books import FiscalYear
 
from conservancy_beancount.reports import query as qmod
 
from conservancy_beancount import rtutil
 

	
 
from decimal import Decimal
 

	
 
class MockRewriteRuleset:
 
    def __init__(self, multiplier=2):
 
        self.multiplier = multiplier
 

	
 
    def rewrite(self, posts):
 
        for post in posts:
 
            number, currency = post.units
 
            number *= self.multiplier
 
            yield post._replace(units=testutil.Amount(number, currency))
 

	
 

	
 
@pytest.fixture(scope='module')
 
def qparser():
 
    return bc_query_parser.Parser()
 

	
 
@pytest.fixture(scope='module')
 
def rt():
 
    return rtutil.RT(testutil.RTClient())
 

	
 
def pipe_main(arglist, config, stdout_type=io.StringIO):
 
    stdout = stdout_type()
 
    stderr = io.StringIO()
 
    returncode = qmod.main(arglist, stdout, stderr, config)
 
    return returncode, stdout, stderr
 

	
 
def test_books_loader_empty():
 
    result = qmod.BooksLoader(None)()
 
    assert not result.entries
 
    assert len(result.errors) == 1
 

	
 
def test_books_loader_plain():
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    loader = testutil.TestBooksLoader(books_path)
 
    result = qmod.BooksLoader(loader)()
 
    assert not result.errors
 
    assert result.entries
 
    min_date = datetime.date(2018, 3, 1)
 
    assert all(ent.date >= min_date for ent in result.entries)
 

	
 
def test_books_loader_rewrites():
 
    rewrites = [MockRewriteRuleset()]
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    loader = testutil.TestBooksLoader(books_path)
 
    result = qmod.BooksLoader(loader, None, None, rewrites)()
 
    assert not result.errors
 
    assert result.entries
 
    numbers = frozenset(
 
        abs(post.units.number)
 
        for entry in result.entries
 
        for post in getattr(entry, 'postings', ())
 
    )
 
    assert numbers
 
    assert all(abs(number) >= 40 for number in numbers)
 

	
 
@pytest.mark.parametrize('arglist,fy', testutil.combine_values(
 
    [['--report-type', 'text'], ['--format=text'], ['-f', 'txt']],
 
    range(2018, 2021),
 
))
 
def test_text_query(arglist, fy):
 
    books_path = testutil.test_path(f'books/books/{fy}.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist += ['select', 'date,', 'narration,', 'account,', 'position']
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    lines = iter(stdout)
 
    next(lines); next(lines)  # Skip header
 
    for count, line in enumerate(lines, 1):
 
        assert re.match(rf'^{fy}-\d\d-\d\d\s+{fy} ', line)
 
    assert count >= 2
 

	
 
@pytest.mark.parametrize('arglist,fy', testutil.combine_values(
 
    [['--format=csv'], ['-f', 'csv'], ['-t', 'csv']],
 
    range(2018, 2021),
 
))
 
def test_csv_query(arglist, fy):
 
    books_path = testutil.test_path(f'books/books/{fy}.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist += ['select', 'date,', 'narration,', 'account,', 'position']
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    for count, row in enumerate(csv.DictReader(stdout), 1):
 
        assert re.fullmatch(rf'{fy}-\d\d-\d\d', row['date'])
 
        assert row['narration'].startswith(f'{fy} ')
 
    assert count >= 2
 

	
 
@pytest.mark.parametrize('end_index', range(3))
 
def test_rewrite_query(end_index):
 
    books_path = testutil.test_path(f'books/books/2018.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    accounts = ['Assets', 'Income']
 
    expected = frozenset(accounts[:end_index])
 
    rewrite_paths = [
 
        testutil.test_path(f'userconfig/Rewrite{s}.yml')
 
        for s in expected
 
    ]
 
    arglist = [f'--rewrite-rules={path}' for path in rewrite_paths]
 
    arglist.append('--format=txt')
 
    arglist.append('select any_meta("root") as root')
 
    returncode, stdout, stderr = pipe_main(arglist, config)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    actual = frozenset(line.rstrip('\n') for line in stdout)
 
    assert expected.issubset(actual)
 
    assert frozenset(accounts).difference(expected).isdisjoint(actual)
 

	
 
def test_ods_amount_formatting():
 
def test_ods_amount_formatting(qparser):
 
    statement = qparser.parse('SELECT UNITS(position)')
 
    row_types = [('amount', bc_data.Amount)]
 
    row_source = [(testutil.Amount(12),), (testutil.Amount(1480, 'JPY'),)]
 
    ods = qmod.QueryODS()
 
    ods.write_query(row_types, row_source)
 
    ods.write_query(statement, row_types, row_source)
 
    actual = testutil.ODSCell.from_sheet(ods.document.spreadsheet.firstChild)
 
    assert next(actual)[0].text == 'Amount'
 
    assert next(actual)[0].text == '$12.00'
 
    assert next(actual)[0].text == '¥1,480'
 
    assert next(actual, None) is None
 

	
 
def test_ods_datetime_formatting():
 
def test_ods_datetime_formatting(qparser):
 
    statement = qparser.parse('SELECT date')
 
    row_types = [('date', datetime.date)]
 
    row_source = [(testutil.PAST_DATE,), (testutil.FUTURE_DATE,)]
 
    ods = qmod.QueryODS()
 
    ods.write_query(row_types, row_source)
 
    ods.write_query(statement, row_types, row_source)
 
    actual = testutil.ODSCell.from_sheet(ods.document.spreadsheet.firstChild)
 
    assert next(actual)[0].text == 'Date'
 
    assert next(actual)[0].text == testutil.PAST_DATE.isoformat()
 
    assert next(actual)[0].text == testutil.FUTURE_DATE.isoformat()
 
    assert next(actual, None) is None
 

	
 
@pytest.mark.parametrize('meta_key,header_text', [
 
    ('check', 'Check'),
 
    ('purchase-order', 'Purchase Order'),
 
    ('rt-id', 'Ticket'),
 
@pytest.mark.parametrize('meta_key,meta_func', [
 
    ('check', 'ANY_META'),
 
    ('purchase-order', 'META'),
 
    ('rt-id', 'META_DOCS'),
 
])
 
def test_ods_link_formatting(rt, meta_key, header_text):
 
    row_types = [(meta_key.replace('-', '_'), object)]
 
    row_source = [('rt:1/5',), ('rt:3 Checks/9.pdf',)]
 
def test_ods_link_formatting(qparser, rt, meta_key, meta_func):
 
    meta_func_returns_list = meta_func == 'META_DOCS'
 
    statement = qparser.parse(f'SELECT {meta_func}({meta_key!r}) AS docs')
 
    row_types = [('docs', list if meta_func_returns_list else str)]
 
    row_source = [
 
        (s.split() if meta_func_returns_list else s,)
 
        for s in ['rt:1/5', 'rt:3 Checks/9.pdf']
 
    ]
 
    ods = qmod.QueryODS(rt)
 
    ods.write_query(row_types, row_source)
 
    ods.write_query(statement, row_types, row_source)
 
    rows = iter(ods.document.spreadsheet.firstChild.getElementsByType(odf.table.TableRow))
 
    assert next(rows).text == header_text
 
    assert next(rows).text == 'Docs'
 
    actual = iter(
 
        [link.text for link in row.getElementsByType(odf.text.A)]
 
        for row in rows
 
    )
 
    assert next(actual) == ['photo.jpg']
 
    assert next(actual) == ['rt:3', '9.pdf']
 
    assert next(actual, None) is None
 

	
 
def test_ods_meta_formatting():
 
    row_types = [('metadata', object)]
 
def test_ods_meta_formatting(qparser):
 
    statement = qparser.parse('SELECT ANY_META("entity") AS entity')
 
    row_types = [('entity', object)]
 
    row_source = [(testutil.Amount(14),), (None,), ('foo bar',)]
 
    ods = qmod.QueryODS()
 
    ods.write_query(row_types, row_source)
 
    ods.write_query(statement, row_types, row_source)
 
    actual = testutil.ODSCell.from_sheet(ods.document.spreadsheet.firstChild)
 
    assert next(actual)[0].text == 'Metadata'
 
    assert next(actual)[0].text == 'Entity'
 
    assert next(actual)[0].text == '$14.00'
 
    assert next(actual)[0].text == ''
 
    assert next(actual)[0].text == 'foo bar'
 
    assert next(actual, None) is None
 

	
 
def test_ods_multicolumn_write(rt):
 
    row_types = [('date', datetime.date), ('rt-id', object), ('desc', str)]
 
def test_ods_multicolumn_write(qparser, rt):
 
    statement = qparser.parse(
 
        'SELECT MIN(date) AS date, SET(META_DOCS("rt-id")) AS tix, STR_META("entity") AS entity',
 
    )
 
    row_types = [('date', datetime.date), ('tix', set), ('entity', str)]
 
    row_source = [
 
        (testutil.PAST_DATE, 'rt:1', 'aaa'),
 
        (testutil.FY_START_DATE, 'rt:2', 'bbb'),
 
        (testutil.FUTURE_DATE, 'rt:3', 'ccc'),
 
        (testutil.PAST_DATE, {'rt:1'}, 'AA'),
 
        (testutil.FY_START_DATE, {'rt:2'}, 'BB'),
 
        (testutil.FUTURE_DATE, {'rt:3', 'rt:4'}, 'CC'),
 
    ]
 
    ods = qmod.QueryODS(rt)
 
    ods.write_query(row_types, row_source)
 
    ods.write_query(statement, list(row_types), list(row_source))
 
    actual = iter(
 
        cell.text
 
        for row in testutil.ODSCell.from_sheet(ods.document.spreadsheet.firstChild)
 
        for cell in row
 
    )
 
    assert next(actual) == 'Date'
 
    assert next(actual) == 'Ticket'
 
    assert next(actual) == 'Desc'
 
    for expected, _ in row_types:
 
        assert next(actual) == expected.title()
 
    assert next(actual) == testutil.PAST_DATE.isoformat()
 
    assert next(actual) == 'rt:1'
 
    assert next(actual) == 'aaa'
 
    assert next(actual) == 'AA'
 
    assert next(actual) == testutil.FY_START_DATE.isoformat()
 
    assert next(actual) == 'rt:2'
 
    assert next(actual) == 'bbb'
 
    assert next(actual) == 'BB'
 
    assert next(actual) == testutil.FUTURE_DATE.isoformat()
 
    assert next(actual) == 'rt:3'
 
    assert next(actual) == 'ccc'
 
    assert frozenset(next(actual).split('\0')) == row_source[-1][1]
 
    assert next(actual) == 'CC'
 
    assert next(actual, None) is None
 

	
 
def test_ods_is_empty():
 
def test_ods_is_empty(qparser):
 
    statement = qparser.parse('SELECT * WHERE date < 1900-01-01')
 
    ods = qmod.QueryODS()
 
    assert ods.is_empty()
 
    ods.write_query([], [])
 
    ods.write_query(statement, [], [])
 
    assert not ods.is_empty()
 

	
 
@pytest.mark.parametrize('fy,account,amt_prefix', [
 
    (2018, 'Assets', '($'),
 
    (2019, 'Income', '$'),
 
])
 
def test_ods_output(fy, account, amt_prefix):
 
    books_path = testutil.test_path(f'books/books/{fy}.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist = [
 
        '-O', '-',
 
        '-f', 'ods',
 
        f'SELECT date, narration, UNITS(position) WHERE account ~ "^{account}:"',
 
    ]
 
    returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO)
 
    assert returncode == 0
 
    stdout.seek(0)
 
    ods_doc = odf.opendocument.load(stdout)
 
    with stdout:
 
        stdout.seek(0)
 
        ods_doc = odf.opendocument.load(stdout)
 
    rows = iter(ods_doc.spreadsheet.firstChild.getElementsByType(odf.table.TableRow))
 
    next(rows)  # Skip header row
 
    amt_pattern = rf'^{re.escape(amt_prefix)}\d'
 
    for count, row in enumerate(rows, 1):
 
        date, narration, amount = row.childNodes
 
        assert re.fullmatch(rf'{fy}-\d{{2}}-\d{{2}}', date.text)
 
        assert narration.text.startswith(f'{fy} ')
 
        assert re.match(amt_pattern, amount.text)
 
    assert count
 

	
 
def test_ods_aggregate_output():
 
    books_path = testutil.test_path(f'books/books/2020.beancount')
 
    config = testutil.TestConfig(books_path=books_path)
 
    arglist = [
 
        '-O', '-',
 
        '-f', 'ods',
 
        'SELECT account, SET(narration), SUM(UNITS(position))',
 
        'WHERE date >= 2020-04-01 AND date <= 2020-04-02',
 
        'GROUP BY account ORDER BY account ASC',
 
    ]
 
    returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO)
 
    assert returncode == 0
 
    with stdout:
 
        stdout.seek(0)
 
        ods_doc = odf.opendocument.load(stdout)
 
    rows = iter(ods_doc.spreadsheet.firstChild.getElementsByType(odf.table.TableRow))
 
    next(rows)  # Skip header row
 
    actual = {}
 
    for row in rows:
 
        acct, descs, balance = row.childNodes
 
        actual[acct.text] = (frozenset(descs.text.split('\0')), balance.text)
 
    in_desc = {'2020 donation'}
 
    ex_desc = {'2020 bank maintenance fee'}
 
    assert actual['Income:Donations'] == (in_desc, '$20.20')
 
    assert actual['Expenses:BankingFees'] == (ex_desc, '$1.00')
 
    assert actual['Assets:Checking'] == (in_desc | ex_desc, '($21.20)')
0 comments (0 inline, 0 general)