Refactor BQLShell so SELECT statements are rendered by per-format methods.

BQLShell.on_Select looks up self._render_<format> from the shell's 'format'
variable before doing any query work. It then compiles and executes the
statement, numberifies the results when the 'numberify' variable is set
(except for format 'ods'), and hands the rows to that renderer. Empty
results print "(empty)" unless the format is 'ods'. Any exception raised
while compiling, executing, or numberifying is logged with logger.error;
the traceback is attached only when DEBUG logging is enabled.
_render_text writes through the shell pager only when self.is_interactive,
otherwise directly to self.outfile.

NOTE(review): the hunks below define only _render_csv and _render_text, yet
on_Select special-cases format 'ods'. Unless _render_ods is defined
elsewhere, 'ods' will hit the "unknown output format" branch; confirm.

diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py
index c8708292e44b3d43b3f49534179af1c3d42f2b0c..fb33ba129a4ab0b15a9974479c58698a5a280989 100644
--- a/conservancy_beancount/reports/query.py
+++ b/conservancy_beancount/reports/query.py
@@ -6,6 +6,7 @@
 # LICENSE.txt in the repository.
 
 import argparse
+import contextlib
 import datetime
 import enum
 import itertools
@@ -21,10 +22,12 @@ from typing import (
     Iterable,
     Iterator,
     Mapping,
+    NamedTuple,
     Optional,
     Sequence,
     TextIO,
     Tuple,
+    Type,
     Union,
 )
 from ..beancount_types import (
@@ -36,9 +39,13 @@ from ..beancount_types import (
 from decimal import Decimal
 from pathlib import Path
 
+import beancount.query.numberify as bc_query_numberify
+import beancount.query.query_compile as bc_query_compile
 import beancount.query.query_env as bc_query_env
+import beancount.query.query_execute as bc_query_execute
 import beancount.query.query_parser as bc_query_parser
-import beancount.query.shell as bc_query
+import beancount.query.query_render as bc_query_render
+import beancount.query.shell as bc_query_shell
 
 from . import core
 from . import rewrite
@@ -55,6 +62,9 @@ PROGNAME = 'query-report'
 QUERY_PARSER = bc_query_parser.Parser()
 logger = logging.getLogger('conservancy_beancount.reports.query')
 
+RowTypes = Sequence[Tuple[str, Type]]  # annotation for execute_query() row_types
+Rows = Sequence[NamedTuple]  # annotation for execute_query() rows
+
 
 class BooksLoader:
     """Closure to load books with a zero-argument callable
@@ -88,8 +98,63 @@ class BooksLoader:
         return result
 
 
-class BQLShell(bc_query.BQLShell):
-    pass
+class BQLShell(bc_query_shell.BQLShell):
+    def on_Select(self, statement: str) -> None:  # renders via self._render_<format>()
+        output_format: str = self.vars['format']
+        try:  # validate the output format before running the query
+            render_func = getattr(self, f'_render_{output_format}')
+        except AttributeError:
+            logger.error("unknown output format %r", output_format)
+            return
+
+        try:
+            logger.debug("compiling query")
+            compiled_query = bc_query_compile.compile(
+                statement, self.env_targets, self.env_postings, self.env_entries,
+            )
+            logger.debug("executing query")
+            row_types, rows = bc_query_execute.execute_query(
+                compiled_query, self.entries, self.options_map,
+            )
+            if self.vars['numberify'] and output_format != 'ods':
+                logger.debug("numberifying query")
+                row_types, rows = bc_query_numberify.numberify_results(
+                    row_types, rows, self.options_map['dcontext'].build(),
+                )
+        except Exception as error:  # logged; the statement is abandoned
+            logger.error(str(error), exc_info=logger.isEnabledFor(logging.DEBUG))
+            return
+
+        if not rows and output_format != 'ods':
+            print("(empty)", file=self.outfile)
+        else:
+            logger.debug("rendering query as %s", output_format)
+            render_func(row_types, rows)
+
+    def _render_csv(self, row_types: RowTypes, rows: Rows) -> None:
+        bc_query_render.render_csv(
+            row_types,
+            rows,
+            self.options_map['dcontext'],
+            self.outfile,
+            self.vars['expand'],
+        )
+
+    def _render_text(self, row_types: RowTypes, rows: Rows) -> None:
+        with contextlib.ExitStack() as stack:  # pager is entered only when interactive
+            if self.is_interactive:
+                output = stack.enter_context(self.get_pager())
+            else:
+                output = self.outfile
+            bc_query_render.render_text(
+                row_types,
+                rows,
+                self.options_map['dcontext'],
+                output,
+                self.vars['expand'],
+                self.vars['boxed'],
+                self.vars['spaced'],
+            )
 
 
 class JoinOperator(enum.Enum):