diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index 67db8e82fbaf06bcd33433bda67e2ba494acf679..6f1e713459011d77e151ad374062c3513f3a3fdf 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -125,6 +125,7 @@ STANDARD_PATH = Path('-') CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance) PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings'] RTObject = Mapping[str, str] +T = TypeVar('T') logger = logging.getLogger('conservancy_beancount.reports.accrual') @@ -134,19 +135,14 @@ class Sentinel: class Account(NamedTuple): name: str - norm_func: Callable[[CompoundAmount], CompoundAmount] aging_thresholds: Sequence[int] class AccrualAccount(enum.Enum): # Note the aging report uses the same order accounts are defined here. # See AgingODS.start_spreadsheet(). - RECEIVABLE = Account( - 'Assets:Receivable', lambda bal: bal, [365, 120, 90, 60], - ) - PAYABLE = Account( - 'Liabilities:Payable', operator.neg, [365, 90, 60, 30], - ) + RECEIVABLE = Account('Assets:Receivable', [365, 120, 90, 60]) + PAYABLE = Account('Liabilities:Payable', [365, 90, 60, 30]) @classmethod def account_names(cls) -> Iterator[str]: @@ -167,6 +163,10 @@ class AccrualAccount(enum.Enum): return account raise ValueError("unrecognized account set in related postings") + @property + def normalize_amount(self) -> Callable[[T], T]: + return core.normalize_amount_func(self.value.name) + class AccrualPostings(core.RelatedPostings): def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc] @@ -221,8 +221,7 @@ class AccrualPostings(core.RelatedPostings): self.paid_entities = self.accrued_entities else: self.accrual_type = AccrualAccount.classify(self) - accrual_acct: Account = self.accrual_type.value - norm_func = accrual_acct.norm_func + norm_func = self.accrual_type.normalize_amount self.end_balance = norm_func(self.balance_at_cost()) self.accrued_entities = self._collect_entities( lambda post: norm_func(post.units).number > 0, @@ -453,7 +452,7 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]): return raw_balance = row.balance() if row.accrual_type is not None: - raw_balance = row.accrual_type.value.norm_func(raw_balance) + raw_balance = row.accrual_type.normalize_amount(raw_balance) if raw_balance == row.end_balance: amount_cell = odf.table.TableCell() else: diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index aab80c13a4f5fd06afc83b69f0f55570d3dc34c8..01e04cd6eea99fe3f3246ff8d394d08632a302b9 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -73,6 +73,7 @@ LinkType = Union[str, Tuple[str, Optional[str]]] RelatedType = TypeVar('RelatedType', bound='RelatedPostings') RT = TypeVar('RT', bound=Sequence) ST = TypeVar('ST') +T = TypeVar('T') class Balance(Mapping[str, data.Amount]): """A collection of amounts mapped by currency @@ -898,3 +899,21 @@ class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta): with path.open(f'{mode}b') as out_file: out_file = cast(BinaryIO, out_file) self.save_file(out_file) + + +def normalize_amount_func(account_name: str) -> Callable[[T], T]: + """Get a function to normalize amounts for reporting + + Given an account name, return a function that can be used on "amounts" + under that account (including numbers, Amount objects, and Balance objects) + to normalize them for reporting. Right now that means make flipping the + sign for accounts where "normal" postings are negative. + """ + if account_name.startswith(('Assets:', 'Expenses:')): + # We can't just return operator.pos because Beancount's Amount class + # doesn't implement __pos__. + return lambda amt: amt + elif account_name.startswith(('Equity:', 'Income:', 'Liabilities:')): + return operator.neg + else: + raise ValueError(f"unrecognized account name {account_name!r}") diff --git a/tests/test_reports_core.py b/tests/test_reports_core.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b8a03a3604aab4ec266fc9e2cecd4388603bda --- /dev/null +++ b/tests/test_reports_core.py @@ -0,0 +1,66 @@ +"""test_reports_core - Unit tests for basic reports functions""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import pytest + +from decimal import Decimal + +from . import testutil + +from conservancy_beancount.reports import core + +AMOUNTS = [ + 2, + Decimal('4.40'), + testutil.Amount('6.60', 'CHF'), + core.Balance([testutil.Amount('8.80')]), +] + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Checking', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Expenses:FilingFees', +]) +def test_normalize_amount_func_pos(acct_name): + actual = core.normalize_amount_func(acct_name) + for amount in AMOUNTS: + assert actual(amount) == amount + +@pytest.mark.parametrize('acct_name', [ + 'Equity:Funds:Restricted', + 'Equity:Realized:CurrencyConversion', + 'Income:Donations', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_normalize_amount_func_neg(acct_name): + actual = core.normalize_amount_func(acct_name) + for amount in AMOUNTS: + assert actual(amount) == -amount + +@pytest.mark.parametrize('acct_name', [ + '', + 'Assets', + 'Equity', + 'Expenses', + 'Income', + 'Liabilities', +]) +def test_normalize_amount_func_bad_acct_name(acct_name): + with pytest.raises(ValueError): + core.normalize_amount_func(acct_name)