diff --git a/conservancy_beancount/data.py b/conservancy_beancount/data.py index b93a2c7e854b4beba00ff94d7062980142f599bb..7c016e00617578516f39be45c4ad4a5e0475cccc 100644 --- a/conservancy_beancount/data.py +++ b/conservancy_beancount/data.py @@ -25,6 +25,7 @@ import operator from beancount.core import account as bc_account from beancount.core import amount as bc_amount +from beancount.core import convert as bc_convert from typing import ( cast, @@ -257,18 +258,28 @@ class Posting(BasePosting): def balance_of(txn: Transaction, *preds: Callable[[Account], Optional[bool]], -) -> decimal.Decimal: +) -> Amount: """Return the balance of specified postings in a transaction. Given a transaction and a series of account predicates, balance_of returns the balance of the amounts of all postings with accounts that match any of the predicates. + + balance_of uses the "weight" of each posting, so the return value will + use the currency of the postings' cost when available. """ - return sum( - (post.units.number for post in iter_postings(txn) - if any(pred(post.account) for pred in preds)), - decimal.Decimal(0), - ) + match_posts = [post for post in iter_postings(txn) + if any(pred(post.account) for pred in preds)] + number = decimal.Decimal(0) + if not match_posts: + currency = '' + else: + weights: Sequence[Amount] = [ + bc_convert.get_weight(post) for post in match_posts # type:ignore[no-untyped-call] + ] + number = sum((wt.number for wt in weights), number) + currency = weights[0].currency + return Amount._make((number, currency)) def iter_postings(txn: Transaction) -> Iterator[Posting]: """Yield an enhanced Posting object for every posting in the transaction"""