Changeset - 8d3d7e7ce4e2
[Not reviewed]
0 3 0
Brett Smith - 4 years ago 2020-06-09 13:04:27
brettcsmith@brettcsmith.org
data: Add part slicing methods to Account.
3 files changed with 150 insertions and 1 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/data.py
Show inline comments
...
 
@@ -22,24 +22,25 @@ throughout Conservancy tools.
 
import collections
 
import datetime
 
import decimal
 
import functools
 

	
 
from beancount.core import account as bc_account
 
from beancount.core import amount as bc_amount
 
from beancount.core import convert as bc_convert
 
from beancount.core import position as bc_position
 

	
 
from typing import (
 
    cast,
 
    overload,
 
    Callable,
 
    Hashable,
 
    Iterable,
 
    Iterator,
 
    MutableMapping,
 
    Optional,
 
    Sequence,
 
    TypeVar,
 
    Union,
 
)
 

	
 
from .beancount_types import (
...
 
@@ -112,24 +113,82 @@ class Account(str):
 

	
 
          Account('Expenses:Tax').is_under('Exp') # returns None
 
        """
 
        for prefix in acct_seq:
 
            if self.startswith(prefix) and (
 
                prefix.endswith(self.SEP)
 
                or self == prefix
 
                or self[len(prefix)] == self.SEP
 
            ):
 
                return prefix
 
        return None
 

	
 
    def _find_part_slice(self, index: int) -> slice:
 
        if index < 0:
 
            raise ValueError(f"bad part index {index!r}")
 
        start = 0
 
        for _ in range(index):
 
            try:
 
                start = self.index(self.SEP, start) + 1
 
            except ValueError:
 
                raise IndexError("part index {index!r} out of range") from None
 
        try:
 
            stop = self.index(self.SEP, start + 1)
 
        except ValueError:
 
            stop = len(self)
 
        return slice(start, stop)
 

	
 
    def count_parts(self) -> int:
 
        return self.count(self.SEP) + 1
 

	
 
    @overload
 
    def slice_parts(self, start: None=None, stop: None=None) -> Sequence[str]: ...
 

	
 
    @overload
 
    def slice_parts(self, start: slice, stop: None=None) -> Sequence[str]: ...
 

	
 
    @overload
 
    def slice_parts(self, start: int, stop: int) -> Sequence[str]: ...
 

	
 
    @overload
 
    def slice_parts(self, start: int, stop: None=None) -> str: ...
 

	
 
    def slice_parts(self,
 
                    start: Optional[Union[int, slice]]=None,
 
                    stop: Optional[int]=None,
 
    ) -> Sequence[str]:
 
        """Slice the account parts like they were a list
 

	
 
        Given a single index, return that part of the account name as a string.
 
        Otherwise, return a list of part names sliced according to the arguments.
 
        """
 
        if start is None:
 
            part_slice = slice(None)
 
        elif isinstance(start, slice):
 
            part_slice = start
 
        elif stop is None:
 
            return self[self._find_part_slice(start)]
 
        else:
 
            part_slice = slice(start, stop)
 
        return self.split(self.SEP)[part_slice]
 

	
 
    def root_part(self, count: int=1) -> str:
 
        """Return the first part(s) of the account name as a string"""
 
        try:
 
            stop = self._find_part_slice(count - 1).stop
 
        except IndexError:
 
            return self
 
        else:
 
            return self[:stop]
 

	
 

	
 
class Amount(bc_amount.Amount):
 
    """Beancount amount after processing
 

	
 
    Beancount's native Amount class declares number to be Optional[Decimal],
 
    because the number is None when Beancount first parses a posting that does
 
    not have an amount, because the user wants it to be automatically balanced.
 

	
 
    As part of the loading process, Beancount replaces those None numbers
 
    with the calculated amount, so it will always be a Decimal. This class
 
    overrides the type declaration accordingly, so the type checker knows
 
    that our code doesn't have to consider the possibility that number is
conservancy_beancount/reports/accrual.py
Show inline comments
...
 
@@ -376,25 +376,25 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
 
            self.add_row(*(
 
                self.string_cell(name, stylename=self.style_bold)
 
                for name in self.COLUMNS
 
            ))
 
            self.lock_first_row()
 

	
 
    def start_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
 
            return
 
        self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds)
 
        self.age_balances = [core.MutableBalance() for _ in self.age_thresholds]
 
        accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1])
 
        acct_parts = key.split(':')
 
        acct_parts = key.slice_parts()
 
        self.use_sheet(acct_parts[1])
 
        self.add_row()
 
        self.add_row(self.string_cell(
 
            f"{' '.join(acct_parts[2:])} {acct_parts[1]} Aging Report"
 
            f" Accrued by {accrual_date.isoformat()} Unpaid by {self.date.isoformat()}",
 
            stylename=self.merge_styles(self.style_bold, self.style_centertext),
 
            numbercolumnsspanned=self.COL_COUNT,
 
        ))
 
        self.add_row()
 

	
 
    def end_section(self, key: Optional[data.Account]) -> None:
 
        if key is None:
tests/test_data_account.py
Show inline comments
...
 
@@ -106,12 +106,102 @@ def test_is_credit_card(acct_name, expected):
 
    ('Expenses:Other', False),
 
    ('Equity:Funds:Restricted', True),
 
    ('Equity:Funds:Unrestricted', True),
 
    ('Equity:OpeningBalance', True),
 
    ('Equity:Retained:Costs', False),
 
    ('Income:Other', False),
 
    ('Liabilities:CreditCard', False),
 
    ('Liabilities:Payable:Accounts', False),
 
    ('Liabilities:UnearnedIncome:Donations', False),
 
])
 
def test_is_opening_equity(acct_name, expected):
 
    assert data.Account(acct_name).is_opening_equity() == expected
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_slice_parts_no_args(acct_name):
 
    account = data.Account(acct_name)
 
    assert account.slice_parts() == acct_name.split(':')
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_slice_parts_index(acct_name):
 
    account = data.Account(acct_name)
 
    parts = acct_name.split(':')
 
    for index, expected in enumerate(parts):
 
        assert account.slice_parts(index) == expected
 
    with pytest.raises(IndexError):
 
        account.slice_parts(index + 1)
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_slice_parts_range(acct_name):
 
    account = data.Account(acct_name)
 
    parts = acct_name.split(':')
 
    for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
 
        assert account.slice_parts(start, stop) == parts[start:stop]
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_slice_parts_slice(acct_name):
 
    account = data.Account(acct_name)
 
    parts = acct_name.split(':')
 
    for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
 
        sl = slice(start, stop)
 
        assert account.slice_parts(sl) == parts[start:stop]
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_count_parts(acct_name):
 
    account = data.Account(acct_name)
 
    assert account.count_parts() == acct_name.count(':') + 1
 

	
 
@pytest.mark.parametrize('acct_name', [
 
    'Assets:Cash',
 
    'Assets:Receivable:Accounts',
 
    'Expenses:Other',
 
    'Equity:Funds:Restricted',
 
    'Income:Other',
 
    'Liabilities:CreditCard',
 
    'Liabilities:Payable:Accounts',
 
])
 
def test_root_part(acct_name):
 
    account = data.Account(acct_name)
 
    parts = acct_name.split(':')
 
    assert account.root_part() == parts[0]
 
    assert account.root_part(1) == parts[0]
 
    assert account.root_part(2) == ':'.join(parts[:2])
0 comments (0 inline, 0 general)