Changeset - 04be991e1958
[Not reviewed]
0 1 0
Brett Smith - 4 years ago 2020-07-29 20:58:57
brettcsmith@brettcsmith.org
reports: Add BaseODS.set_custom_property() method.
1 file changed with 31 insertions and 0 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/core.py
Show inline comments
...
 
@@ -671,631 +671,662 @@ class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta):
 
            except NameError:
 
                raise ValueError(
 
                    f"need config_type for {type(value).__name__} value",
 
                ) from None
 
        item = self.replace_child(
 
            root, odf.config.ConfigItem, name=name, type=config_type,
 
        )
 
        item.addText(value_s)
 

	
 
    ### Styles
 

	
 
    def border_style(self,
 
                     edges: int,
 
                     width: str='1px',
 
                     style: str='solid',
 
                     color: str='#000000',
 
    ) -> odf.style.Style:
 
        flags = [edge for edge in Border if edges & edge]
 
        if not flags:
 
            raise ValueError(f"no valid edges in {edges!r}")
 
        border_attr = f'{width} {style} {color}'
 
        key = f'{",".join(f.name for f in flags)} {border_attr}'
 
        try:
 
            retval = self._style_cache[key]
 
        except KeyError:
 
            props = odf.style.TableCellProperties()
 
            for flag in flags:
 
                props.setAttribute(f'border{flag.name.lower()}', border_attr)
 
            retval = odf.style.Style(
 
                name=f'Border{next(self._name_counter)}',
 
                family='table-cell',
 
            )
 
            retval.addElement(props)
 
            self.document.styles.addElement(retval)
 
            self._style_cache[key] = retval
 
        return retval
 

	
 
    def column_style(self, width: Union[float, str], **attrs: Any) -> odf.style.Style:
 
        if not isinstance(width, str) or (width and not width[-1].isalpha()):
 
            width = f'{width}in'
 
        match = self.MEASUREMENT_RE.fullmatch(width)
 
        if match is None:
 
            raise ValueError(f"invalid width {width!r}")
 
        width_float = float(match.group(1))
 
        if width_float <= 0:
 
            # Per the OpenDocument spec, column-width is a positiveLength.
 
            raise ValueError(f"width {width!r} must be positive")
 
        width = '{:.3g}{}'.format(width_float, match.group(2))
 
        retval = self.ensure_child(
 
            self.document.automaticstyles,
 
            odf.style.Style,
 
            name=f'col_{width.replace(".", "_")}'
 
        )
 
        retval.setAttribute('family', 'table-column')
 
        if retval.firstChild is None:
 
            retval.addElement(odf.style.TableColumnProperties(
 
                columnwidth=width, **attrs
 
            ))
 
        return retval
 

	
 
    def _build_currency_style(
 
            self,
 
            root: odf.element.Element,
 
            locale: babel.core.Locale,
 
            code: str,
 
            fmt_index: int,
 
            properties: Optional[odf.style.TextProperties]=None,
 
            *,
 
            fmt_key: Optional[str]=None,
 
            volatile: bool=False,
 
            minintegerdigits: int=1,
 
    ) -> odf.element.Element:
 
        if fmt_key is None:
 
            fmt_key = self.currency_fmt_key
 
        pattern = locale.currency_formats[fmt_key]
 
        fmts = pattern.pattern.split(';')
 
        try:
 
            fmt = fmts[fmt_index]
 
        except IndexError:
 
            fmt = fmts[0]
 
            grouping = pattern.grouping[0]
 
        else:
 
            grouping = pattern.grouping[fmt_index]
 
        zero_s = babel.numbers.format_currency(0, code, '##0.0', locale)
 
        try:
 
            decimal_index = zero_s.rindex('.') + 1
 
        except ValueError:
 
            decimalplaces = 0
 
        else:
 
            decimalplaces = len(zero_s) - decimal_index
 
        style = self.replace_child(
 
            root,
 
            odf.number.CurrencyStyle,
 
            name=f'{code}{next(self._name_counter)}',
 
        )
 
        style.setAttribute('volatile', 'true' if volatile else 'false')
 
        if properties is not None:
 
            style.addElement(properties)
 
        for part in re.split(r"(¤+|[#0,.]+|'[^']+')", fmt):
 
            if not part:
 
                pass
 
            elif not part.strip('#0,.'):
 
                style.addElement(odf.number.Number(
 
                    decimalplaces=str(decimalplaces),
 
                    grouping='true' if grouping else 'false',
 
                    minintegerdigits=str(minintegerdigits),
 
                ))
 
            elif part == '¤':
 
                style.addElement(odf.number.CurrencySymbol(
 
                    country=locale.territory,
 
                    language=locale.language,
 
                    text=babel.numbers.get_currency_symbol(code, locale),
 
                ))
 
            elif part == '¤¤':
 
                style.addElement(odf.number.Text(text=code))
 
            else:
 
                style.addElement(odf.number.Text(text=part.strip("'")))
 
        return style
 

	
 
    def currency_style(
 
            self,
 
            code: str,
 
            locale: Optional[babel.core.Locale]=None,
 
            negative_properties: Optional[odf.style.TextProperties]=None,
 
            positive_properties: Optional[odf.style.TextProperties]=None,
 
            root: odf.element.Element=None,
 
    ) -> odf.style.Style:
 
        """Create and return a spreadsheet style to format currency data
 

	
 
        Given a currency code and a locale, this method will create all the
 
        styles necessary to format the currency according to the locale's
 
        rules, including rendering of decimal points and negative values.
 

	
 
        You may optionally pass in TextProperties to use for negative and
 
        positive amounts, respectively. If you don't, negative values will
 
        automatically be rendered in red (text color #f00).
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
 
        """
 
        if locale is None:
 
            locale = self.locale
 
        if negative_properties is None:
 
            negative_properties = odf.style.TextProperties(color='#ff0000')
 
        if root is None:
 
            root = self.document.styles
 
        cache_parts = [str(id(root)), code, str(locale)]
 
        for key, value in self.iter_attributes(negative_properties):
 
            cache_parts.append(f'{key}={value}')
 
        if positive_properties is not None:
 
            cache_parts.append('')
 
            for key, value in self.iter_attributes(positive_properties):
 
                cache_parts.append(f'{key}={value}')
 
        cache_key = '\0'.join(cache_parts)
 
        try:
 
            style = self._style_cache[cache_key]
 
        except KeyError:
 
            pos_style = self._build_currency_style(
 
                root, locale, code, 0, positive_properties, volatile=True,
 
            )
 
            curr_style = self._build_currency_style(
 
                root, locale, code, 1, negative_properties,
 
            )
 
            curr_style.addElement(odf.style.Map(
 
                condition='value()>=0', applystylename=pos_style,
 
            ))
 
            style = self.ensure_child(
 
                self.document.styles,
 
                odf.style.Style,
 
                name=f'{curr_style.getAttribute("name")}Cell',
 
                family='table-cell',
 
                datastylename=curr_style,
 
            )
 
            self._style_cache[cache_key] = style
 
        return style
 

	
 
    def _merge_style_iter_names(
 
            self,
 
            styles: Sequence[Union[str, odf.style.Style, None]],
 
    ) -> Iterator[str]:
 
        for source in styles:
 
            if source is None:
 
                continue
 
            elif not isinstance(source, str):
 
                source = source.getAttribute('name')
 
            if source.startswith('Merge_'):
 
                orig_names = iter(source.split('_'))
 
                next(orig_names)
 
                yield from orig_names
 
            else:
 
                yield source
 

	
 
    def _merge_styles(self,
 
                      new_style: odf.style.Style,
 
                      sources: Iterable[odf.style.Style],
 
    ) -> None:
 
        for elem in sources:
 
            for key, new_value in self.iter_attributes(elem):
 
                old_value = new_style.getAttribute(key)
 
                if (key == 'name'
 
                    or key == 'displayname'
 
                    or old_value == new_value):
 
                    pass
 
                elif old_value is None:
 
                    new_style.setAttribute(key, new_value)
 
                else:
 
                    raise ValueError(f"cannot merge styles with conflicting {key}")
 
            for child in elem.childNodes:
 
                new_style.addElement(self.copy_element(child))
 

	
 
    def merge_styles(self,
 
                     *styles: Union[str, odf.style.Style, None],
 
    ) -> Optional[odf.style.Style]:
 
        """Create a new style from multiple existing styles
 

	
 
        Given any number of existing styles, create a new style that combines
 
        all of those styles' attributes and properties, add it to the document
 
        styles, and return it.
 

	
 
        Styles can be specified by name, or by passing in their Style element.
 
        For convenience, you can also pass in None as an argument; None will
 
        simply be skipped.
 

	
 
        Results are cached. If you repeatedly call this method with the same
 
        arguments, you'll keep getting the same style returned, which will
 
        only be added to the document once.
 

	
 
        If you pass in zero real style arguments, returns None.
 
        If you pass in one style argument, returns that style unchanged.
 
        If you pass in a style that doesn't already exist in the document,
 
        or if you pass in styles that can't be merged (because they have
 
        conflicting attributes), raises ValueError.
 
        """
 
        name_map: Dict[str, odf.style.Style] = {}
 
        for name in self._merge_style_iter_names(styles):
 
            source = odf.style.Style(name=name)
 
            found = self.find_child(self.document.styles, source)
 
            if found is None:
 
                raise ValueError(f"no style named {name!r}")
 
            name_map[name] = found
 
        if not name_map:
 
            retval = None
 
        elif len(name_map) == 1:
 
            _, retval = name_map.popitem()
 
        else:
 
            new_name = f'Merge_{"_".join(sorted(name_map))}'
 
            retval = self.ensure_child(
 
                self.document.styles, odf.style.Style, name=new_name,
 
            )
 
            if retval.firstChild is None:
 
                self._merge_styles(retval, name_map.values())
 
        return retval
 

	
 
    ### Sheets
 

	
 
    def lock_first_column(self, sheet: Optional[odf.table.Table]=None) -> None:
 
        """Lock the first column of cells under the given sheet
 

	
 
        This method sets all the appropriate settings to "lock" the first column
 
        of cells in a sheet, so it stays in view even as the viewer scrolls
 
        across the sheet. If a sheet is not given, works on ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        config_map = self.ensure_config_map_entry(
 
            self.view, 'Tables', sheet.getAttribute('name'),
 
        )
 
        self.set_config(config_map, 'PositionRight', 1, 'int')
 
        self.set_config(config_map, 'HorizontalSplitMode', 2, 'short')
 
        self.set_config(config_map, 'HorizontalSplitPosition', 1, 'short')
 

	
 
    def lock_first_row(self, sheet: Optional[odf.table.Table]=None) -> None:
 
        """Lock the first row of cells under the given sheet
 

	
 
        This method sets all the appropriate settings to "lock" the first row
 
        of cells in a sheet, so it stays in view even as the viewer scrolls
 
        through rows. If a sheet is not given, works on ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        config_map = self.ensure_config_map_entry(
 
            self.view, 'Tables', sheet.getAttribute('name'),
 
        )
 
        self.set_config(config_map, 'PositionBottom', 1, 'int')
 
        self.set_config(config_map, 'VerticalSplitMode', 2, 'short')
 
        self.set_config(config_map, 'VerticalSplitPosition', 1, 'short')
 

	
 
    def set_open_sheet(self, sheet: Union[str, odf.table.Table, None]=None) -> None:
 
        """Set which sheet is open in the document
 

	
 
        When the user first opens the spreadsheet, their view will be on this
 
        sheet. You can provide a sheet name string or sheet object. With no
 
        argument, defaults to ``self.sheet``.
 
        """
 
        if sheet is None:
 
            sheet = self.sheet
 
        if not isinstance(sheet, str):
 
            sheet = sheet.getAttribute('name')
 
            if not isinstance(sheet, str):
 
                raise ValueError("sheet argument has no name for setting")
 
        self.set_config(self.view, 'ActiveTable', sheet, 'string')
 

	
 
    def use_sheet(self, name: str) -> odf.table.Table:
 
        """Switch the active sheet ``self.sheet`` to the one with the given name
 

	
 
        If there is no sheet with the given name, create it and append it to
 
        the spreadsheet first.
 

	
 
        If the current active sheet is empty when this method is called, it
 
        will be removed from the spreadsheet.
 
        """
 
        try:
 
            empty_sheet = not self.sheet.hasChildNodes()
 
        except AttributeError:
 
            empty_sheet = False
 
        if empty_sheet:
 
            self.document.spreadsheet.removeChild(self.sheet)
 
        self.sheet = self.ensure_child(
 
            self.document.spreadsheet, odf.table.Table, name=name,
 
        )
 
        return self.sheet
 

	
 
    ### Initialization hooks
 

	
 
    def init_settings(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.settings``. This implementation creates the barest
 
        skeleton structure necessary to support other methods, in particular
 
        ``lock_first_row``.
 
        """
 
        view_settings = self.ensure_child(
 
            self.document.settings, odf.config.ConfigItemSet, name='ooo:view-settings',
 
        )
 
        views = self.ensure_child(
 
            view_settings, odf.config.ConfigItemMapIndexed, name='Views',
 
        )
 
        self.view = self.ensure_child(views, odf.config.ConfigItemMapEntry)
 
        self.set_config(self.view, 'ViewId', 'view1')
 

	
 
    def init_styles(self) -> None:
 
        """Hook called to initialize settings
 

	
 
        This method is called by __init__ to populate
 
        ``self.document.styles``. This implementation creates basic building
 
        block cell styles often used in financial reports.
 
        """
 
        styles = self.document.styles
 
        self.style_bold = self.ensure_child(
 
            styles, odf.style.Style, name='Bold', family='table-cell',
 
        )
 
        self.ensure_child(
 
            self.style_bold, odf.style.TextProperties, fontweight='bold',
 
        )
 

	
 
        date_style = self.replace_child(styles, odf.number.DateStyle, name='ISODate')
 
        date_style.addElement(odf.number.Year(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Month(style='long'))
 
        date_style.addElement(odf.number.Text(text='-'))
 
        date_style.addElement(odf.number.Day(style='long'))
 
        self.style_date = self.ensure_child(
 
            styles,
 
            odf.style.Style,
 
            name=f'{date_style.getAttribute("name")}Cell',
 
            family='table-cell',
 
            datastylename=date_style,
 
        )
 

	
 
        self.style_starttext: odf.style.Style
 
        self.style_centertext: odf.style.Style
 
        self.style_endtext: odf.style.Style
 
        for textalign in ['start', 'center', 'end']:
 
            aligned_style = self.replace_child(
 
                styles, odf.style.Style, name=f'{textalign.title()}Text',
 
            )
 
            aligned_style.setAttribute('family', 'table-cell')
 
            aligned_style.addElement(odf.style.ParagraphProperties(textalign=textalign))
 
            setattr(self, f'style_{textalign}text', aligned_style)
 

	
 
    ### Properties
 

	
 
    def set_custom_property(self,
 
                            name: str,
 
                            value: Any,
 
                            valuetype: Optional[str]=None,
 
    ) -> odf.meta.UserDefined:
 
        if valuetype is None:
 
            if isinstance(value, bool):
 
                valuetype = 'boolean'
 
            elif isinstance(value, (datetime.date, datetime.datetime)):
 
                valuetype = 'date'
 
            elif isinstance(value, (int, float, Decimal)):
 
                valuetype = 'float'
 
        if not isinstance(value, str):
 
            if valuetype == 'boolean':
 
                value = 'true' if value else 'false'
 
            elif valuetype == 'date':
 
                value = value.isoformat()
 
            else:
 
                value = str(value)
 
        retval = self.ensure_child(self.document.meta, odf.meta.UserDefined, name=name)
 
        if valuetype is None:
 
            try:
 
                retval.removeAttribute('valuetype')
 
            except KeyError:
 
                pass
 
        else:
 
            retval.setAttribute('valuetype', valuetype)
 
        retval.childNodes.clear()
 
        retval.addText(value)
 
        return retval
 

	
 
    def set_properties(self, *,
 
                       created: Optional[datetime.datetime]=None,
 
                       generator: str='conservancy_beancount',
 
    ) -> None:
 
        if created is None:
 
            created = datetime.datetime.now()
 
        created_elem = self.ensure_child(self.document.meta, odf.meta.CreationDate)
 
        created_elem.childNodes.clear()
 
        created_elem.addText(created.isoformat())
 
        generator_elem = self.ensure_child(self.document.meta, odf.meta.Generator)
 
        generator_elem.childNodes.clear()
 
        generator_elem.addText(f'{generator}/{VERSION} {TOOLSVERSION}')
 

	
 
    ### Rows and cells
 

	
 
    def add_row(self, *cells: odf.table.TableCell, **attrs: Any) -> odf.table.TableRow:
 
        row = odf.table.TableRow(**attrs)
 
        for cell in cells:
 
            row.addElement(cell)
 
        self.sheet.addElement(row)
 
        return row
 

	
 
    def balance_cell(self, balance: Balance, **attrs: Any) -> odf.table.TableCell:
 
        balance = balance.clean_copy() or balance
 
        balance_currency_count = len(balance)
 
        if balance_currency_count == 0:
 
            return self.float_cell(0, **attrs)
 
        elif balance_currency_count == 1:
 
            amount = next(iter(balance.values()))
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.currency_style(amount.currency),
 
            )
 
            return self.currency_cell(amount, **attrs)
 
        else:
 
            lines = [babel.numbers.format_currency(
 
                number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
            ) for number, currency in balance.values()]
 
            attrs['stylename'] = self.merge_styles(
 
                attrs.get('stylename'), self.style_endtext,
 
            )
 
            return self.multiline_cell(lines, **attrs)
 

	
 
    def currency_cell(self, amount: data.Amount, **attrs: Any) -> odf.table.TableCell:
 
        if 'stylename' not in attrs:
 
            attrs['stylename'] = self.currency_style(amount.currency)
 
        number, currency = amount
 
        cell = odf.table.TableCell(valuetype='currency', value=number, **attrs)
 
        cell.addElement(odf.text.P(text=babel.numbers.format_currency(
 
            number, currency, locale=self.locale, format_type=self.currency_fmt_key,
 
        )))
 
        return cell
 

	
 
    def date_cell(self, date: datetime.date, **attrs: Any) -> odf.table.TableCell:
 
        attrs.setdefault('stylename', self.style_date)
 
        cell = odf.table.TableCell(valuetype='date', datevalue=date, **attrs)
 
        cell.addElement(odf.text.P(text=date.isoformat()))
 
        return cell
 

	
 
    def float_cell(self, value: Union[int, float, Decimal], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='float', value=value, **attrs)
 
        cell.addElement(odf.text.P(text=str(value)))
 
        return cell
 

	
 
    def _meta_link_pairs(self, links: Iterable[Optional[str]]) -> Iterator[Tuple[str, str]]:
 
        for href in links:
 
            if href is None:
 
                continue
 
            elif self.rt_wrapper is not None:
 
                rt_ids = self.rt_wrapper.parse(href)
 
                rt_href = rt_ids and self.rt_wrapper.url(*rt_ids)
 
            else:
 
                rt_ids = None
 
                rt_href = None
 
            if rt_ids is None or rt_href is None:
 
                # '..' pops the ODS filename off the link path. In other words,
 
                # make the link relative to the directory the ODS is in.
 
                href_path = Path('..', href)
 
                href = str(href_path)
 
                text = href_path.name
 
            else:
 
                rt_path = urlparse.urlparse(rt_href).path
 
                if rt_path.endswith('/Ticket/Display.html'):
 
                    text = rtutil.RT.unparse(*rt_ids)
 
                else:
 
                    text = urlparse.unquote(Path(rt_path).name)
 
                href = rt_href
 
            yield (href, text)
 

	
 
    def meta_links_cell(self, links: Iterable[Optional[str]], **attrs: Any) -> odf.table.TableCell:
 
        return self.multilink_cell(self._meta_link_pairs(links), **attrs)
 

	
 
    def multiline_cell(self, lines: Iterable[Any], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for line in lines:
 
            cell.addElement(odf.text.P(text=str(line)))
 
        return cell
 

	
 
    def multilink_cell(self, links: Iterable[LinkType], **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        for link in links:
 
            if isinstance(link, tuple):
 
                href, text = link
 
            else:
 
                href = link
 
                text = None
 
            cell.addElement(odf.text.P())
 
            cell.lastChild.addElement(odf.text.A(
 
                type='simple', href=href, text=text or href,
 
            ))
 
        return cell
 

	
 
    def string_cell(self, text: str, **attrs: Any) -> odf.table.TableCell:
 
        cell = odf.table.TableCell(valuetype='string', **attrs)
 
        cell.addElement(odf.text.P(text=text))
 
        return cell
 

	
 
    def write_row(self, row: RT) -> None:
 
        """Write a single row of input data to the spreadsheet
 

	
 
        This default implementation adds a single row to the spreadsheet,
 
        with one cell per element of the row. The type of each element
 
        determines what kind of cell is created.
 

	
 
        This implementation will help get you started, but you'll probably
 
        want to override it to specify styles.
 
        """
 
        out_row = odf.table.TableRow()
 
        for cell_source in row:
 
            if isinstance(cell_source, (int, float, Decimal)):
 
                cell = self.float_cell(cell_source)
 
            else:
 
                cell = self.string_cell(cell_source)
 
            out_row.addElement(cell)
 
        self.sheet.addElement(out_row)
 

	
 
    def save_file(self, out_file: BinaryIO) -> None:
 
        self.document.write(out_file)
 

	
 
    def save_path(self, path: Path, mode: str='w') -> None:
 
        with path.open(f'{mode}b') as out_file:
 
            out_file = cast(BinaryIO, out_file)
 
            self.save_file(out_file)
 

	
 

	
 
def account_balances(
 
        groups: Mapping[data.Account, PeriodPostings],
 
        order: Optional[Sequence[str]]=None,
 
) -> Iterator[Tuple[str, Balance]]:
 
    """Iterate account balances over a date range
 

	
 
    1. ``subclass = PeriodPostings.with_start_date(start_date)``
 
    2. ``groups = dict(subclass.group_by_account(postings))``
 
    3. ``for acct, bal in account_balances(groups, [optional ordering]): ...``
 

	
 
    This function returns an iterator of 2-tuples ``(account, balance)``
 
    that you can use to generate a report in the style of ``ledger balance``.
 
    The accounts are accounts in ``groups`` that appeared under one of the
 
    account name strings in ``order``. ``balance`` is the corresponding
 
    balance over the time period (``groups[key].period_bal``). Accounts are
 
    iterated in the order provided by ``sort_and_filter_accounts()``.
 

	
 
    The first 2-tuple is ``(OPENING_BALANCE_NAME, balance)`` with the balance of
 
    all these accounts as of ``start_date``.
 
    The final 2-tuple is ``(ENDING_BALANCE_NAME, balance)`` with the final
 
    balance of all these accounts as of ``start_date``.
 
    The iterator will always yield these special 2-tuples, even when there are
 
    no accounts in the input or to report.
 
    """
 
    if order is None:
 
        order = ['Equity', 'Income', 'Expenses']
 
    acct_seq = [account for _, account in sort_and_filter_accounts(groups, order)]
 
    yield (OPENING_BALANCE_NAME, sum(
 
        (groups[key].start_bal for key in acct_seq),
 
        MutableBalance(),
 
    ))
 
    for key in acct_seq:
 
        postings = groups[key]
 
        try:
 
            in_date_range = postings[-1].meta.date >= postings.START_DATE
 
        except IndexError:
 
            in_date_range = False
 
        if in_date_range:
 
            yield (key, groups[key].period_bal)
 
    yield (ENDING_BALANCE_NAME, sum(
 
        (groups[key].stop_bal for key in acct_seq),
 
        MutableBalance(),
 
    ))
 

	
 
def normalize_amount_func(account_name: str) -> Callable[[T], T]:
 
    """Get a function to normalize amounts for reporting
 

	
 
    Given an account name, return a function that can be used on "amounts"
 
    under that account (including numbers, Amount objects, and Balance objects)
 
    to normalize them for reporting. Right now that means make flipping the
 
    sign for accounts where "normal" postings are negative.
 
    """
 
    if account_name.startswith(('Assets:', 'Expenses:')):
 
        # We can't just return operator.pos because Beancount's Amount class
 
        # doesn't implement __pos__.
 
        return lambda amt: amt
 
    elif account_name.startswith(('Equity:', 'Income:', 'Liabilities:')):
 
        return operator.neg
 
    else:
 
        raise ValueError(f"unrecognized account name {account_name!r}")
 

	
 
def sort_and_filter_accounts(
 
        accounts: Iterable[data.Account],
 
        order: Sequence[str],
 
) -> Iterator[Tuple[int, data.Account]]:
 
    """Reorganize accounts based on an ordered set of names
 

	
 
    This function takes a iterable of Account objects, and a sequence of
 
    account names. Usually the account names are higher parts of the account
 
    hierarchy like Income, Equity, or Assets:Receivable.
 

	
 
    It returns an iterator of 2-tuples, ``(index, account)`` where ``index`` is
 
    an index into the ordering sequence, and ``account`` is one of the input
 
    Account objects that's under the account name ``order[index]``. Tuples are
 
    sorted, so ``index`` increases monotonically, and Account objects using the
 
    same index are yielded sorted by name.
 

	
 
    For example, if your order is
 
    ``['Liabilities:Payable', 'Assets:Receivable']``, the return value will
 
    first yield zero or more results with index 0 and an account under
 
    Liabilities:Payable, then zero or more results with index 1 and an account
 
    under Accounts:Receivable.
 

	
 
    Input Accounts that are not under any of the account names in ``order`` do
 
    not appear in the output iterator. That's the filtering part.
 

	
 
    Note that if none of the input Accounts are under one of the ordering
 
    sequence accounts, its index will never appear in the results. This is why
 
    the 2-tuples include an index rather than the original account name string,
 
    to make it easier for callers to know when this happens and do something
 
    with unused ordering accounts.
 
    """
 
    index_map = {s: ii for ii, s in enumerate(order)}
 
    retval: Mapping[int, List[data.Account]] = collections.defaultdict(list)
 
    for account in accounts:
 
        acct_key = account.is_under(*order)
 
        if acct_key is not None:
 
            retval[index_map[acct_key]].append(account)
 
    return (
 
        (key, account)
 
        for key in sorted(retval)
 
        for account in sorted(retval[key])
 
    )
0 comments (0 inline, 0 general)