Changeset - b880115774ac
[Not reviewed]
0 1 0
Brett Smith - 3 years ago 2021-03-15 17:40:09
brettcsmith@brettcsmith.org
query: Refactor DBColumn.

Avoid an issubclass check on every call, and make it easier for subclasses
to override part of the call implementation.
1 file changed with 15 insertions and 5 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/reports/query.py
Show inline comments
...
 
@@ -124,121 +124,131 @@ EnvironmentColumns = Dict[
 
    Sequence[object],
 
    Type[bc_query_compile.EvalColumn],
 
]
 
# Maps a sequence of objects to an EvalFunction subclass.
# NOTE(review): the key is presumably the lookup signature Beancount's query
# environment uses for functions - confirm against bc_query_env.
EnvironmentFunctions = Dict[Sequence[object], Type[bc_query_compile.EvalFunction]]
# A sequence of (str, type) pairs.
# NOTE(review): presumably (column name, column type) for a result set - confirm.
RowTypes = Sequence[Tuple[str, Type]]
# A sequence of NamedTuple rows.
Rows = Sequence[NamedTuple]
# An optional mapping with arbitrary keys and values.
# NOTE(review): the name suggests a Request Tracker lookup result - confirm.
RTResult = Optional[Mapping[Any, Any]]
# A plain list of arbitrary objects; PostingContext.store is annotated with it.
Store = List[Any]
# Any of the parser node types listed below.
QueryExpression = Union[
    bc_query_parser.Column,
    bc_query_parser.Constant,
    bc_query_parser.Function,
    bc_query_parser.UnaryOp,
]
# Any of the parser statement types listed below.
QueryStatement = Union[
    bc_query_parser.Balances,
    bc_query_parser.Journal,
    bc_query_parser.Select,
]
 

	
 
# This class annotates the types that Beancount's RowContexts have when they're
# passed to EvalFunction.__call__(). These types get set across
# create_row_context and execute_query.
class PostingContext:
    """Annotation-only description of a query row context.

    The class declares attribute types but assigns no values and defines no
    methods. In this module it serves as the type of the ``context`` argument
    to ``ContextMeta`` and to the ``__call__`` methods of columns/functions.
    """
    # The posting being evaluated and the transaction passed with it.
    # NOTE(review): presumably ``entry`` is the posting's parent transaction -
    # confirm against Beancount's create_row_context.
    posting: Posting
    entry: Transaction
    balance: Inventory
    options_map: OptionsMap
    # The following are only known to be mappings from this file; their key and
    # value types are not visible here.
    account_types: Mapping
    open_close_map: Mapping
    commodity_map: Mapping
    price_map: Mapping
    # Dynamically set by execute_query
    store: Store
 

	
 

	
 
def ContextMeta(context: PostingContext) -> data.PostingMeta:
    """Return a detached PostingMeta built from the query context.

    The PostingMeta is constructed from the context's entry and posting and
    then detached, so the result is meant to be used read-only.
    """
    # sys.maxsize stands in for the posting index: a constant is cheap to
    # supply, and it helps keep the object read-only, because any attempt to
    # manipulate the transaction through that index raises IndexError.
    attached_meta = data.PostingMeta(
        context.entry,
        sys.maxsize,
        context.posting,
    )
    return attached_meta.detached()
 

	
 

	
 
class DBColumn(bc_query_compile.EvalColumn):
    """Query column whose value is looked up in an entity database.

    Subclasses set ``_db_query`` to a SQL statement with one ``?`` placeholder
    and a single result column, and may override ``_dtype`` to pick the result
    type. When ``_dtype`` is a ``set`` subclass (the default), every returned
    row is collected; otherwise only the first row is converted.

    The result-building strategy is chosen once per subclass, in
    ``__init_subclass__``, and stored in ``_return``. Subclasses can override
    ``_entity``, ``_return_set`` or ``_return_scalar`` to customize part of the
    lookup without reimplementing ``__call__``.
    """
    # Cursor used to run the query; only present on classes made by with_db().
    _db_cursor: ClassVar[sqlite3.Cursor]
    # SQL text executed with the entity as its only bound parameter.
    _db_query: ClassVar[str]
    # Type of the column's value, also passed to EvalColumn.__init__.
    _dtype: ClassVar[Type] = set
    # Result builder selected in __init_subclass__ from _dtype.
    _return: ClassVar[Callable[['DBColumn'], object]]
    __intypes__ = [Posting]

    @classmethod
    def with_db(cls, connection: sqlite3.Connection) -> Type['DBColumn']:
        """Return a subclass of ``cls`` bound to a new cursor on ``connection``.

        The new class has the same name as ``cls`` and carries the cursor as
        its ``_db_cursor`` class attribute.
        """
        return type(cls.__name__, (cls,), {'_db_cursor': connection.cursor()})

    def __init_subclass__(cls) -> None:
        """Pick the result builder for ``cls`` based on its ``_dtype``."""
        # Cooperate with any other __init_subclass__ in the MRO.
        super().__init_subclass__()
        # Decide once, at class creation, so __call__ needs no type check.
        if issubclass(cls._dtype, set):
            cls._return = cls._return_set
        else:
            cls._return = cls._return_scalar

    def __init__(self, colname: Optional[str]=None) -> None:
        """Initialize the column with ``_dtype`` as its data type.

        ``colname`` is only used in the error message. When omitted, it is
        derived from the class name by lowercasing it and replacing the first
        ``db`` with ``db_``.

        Raises RuntimeError if the class has no ``_db_cursor``, i.e. it was not
        created through ``with_db``.
        """
        if not hasattr(self, '_db_cursor'):
            if colname is None:
                colname = type(self).__name__.lower().replace('db', 'db_', 1)
            raise RuntimeError(f"no entity database loaded - {colname} not available")
        super().__init__(self._dtype)

    def _entity(self, meta: data.PostingMeta) -> str:
        """Return the ``entity`` metadata value used as the query parameter.

        Falls back to ``'\\0'`` when the value is missing or not a string.
        NOTE(review): presumably no real ledger entity equals that placeholder,
        so the query then matches nothing - confirm.
        """
        entity = meta.get('entity')
        return entity if isinstance(entity, str) else '\0'

    def _return_scalar(self) -> object:
        """Convert the first column of the first result row with ``_dtype``.

        Returns ``_dtype()`` (called with no arguments) when there is no row.
        """
        row = self._db_cursor.fetchone()
        return self._dtype() if row is None else self._dtype(row[0])

    def _return_set(self) -> object:
        """Collect the single column of every result row into ``_dtype``."""
        # ``value,`` unpacks each one-column row.
        return self._dtype(value for value, in self._db_cursor)

    def __call__(self, context: PostingContext) -> object:
        """Run ``_db_query`` for the posting's entity and return the result."""
        entity = self._entity(ContextMeta(context))
        self._db_cursor.execute(self._db_query, (entity,))
        # Delegate to the builder chosen in __init_subclass__. The previous
        # inline issubclass() branch returned on both arms, which made this
        # call unreachable.
        return self._return()
 

	
 

	
 
class DBEmail(DBColumn):
    """Look up an entity's email addresses from the database

    ``_dtype`` is not overridden here, so the default inherited from DBColumn
    (``set``) applies and every row the query returns is collected.
    """
    # Single-column query; DBColumn.__call__ binds the ``?`` placeholder to the
    # posting's entity.
    _db_query = """
 
SELECT email.email_address
 
FROM donor
 
JOIN donor_email_address_mapping map ON donor.id = map.donor_id
 
JOIN email_address email ON map.email_address_id = email.id
 
WHERE donor.ledger_entity_id = ?
 
"""
 

	
 

	
 
class DBId(DBColumn):
    """Look up an entity's numeric id from the database

    Because ``_dtype`` is ``int`` rather than a set type, DBColumn converts
    only the first result row; with no matching row the value is ``int()``.
    """
    # DBColumn.__call__ binds the ``?`` placeholder to the posting's entity.
    _db_query = "SELECT id FROM donor WHERE ledger_entity_id = ?"
    _dtype = int
 

	
 

	
 
class DBPostal(DBColumn):
    """Look up an entity's postal addresses from the database

    ``_dtype`` is not overridden here, so the default inherited from DBColumn
    (``set``) applies and every row the query returns is collected.
    """
    # Single-column query; DBColumn.__call__ binds the ``?`` placeholder to the
    # posting's entity.
    _db_query = """
 
SELECT postal.formatted_address
 
FROM donor
 
JOIN donor_postal_address_mapping map ON donor.id = map.donor_id
 
JOIN postal_address postal ON map.postal_address_id = postal.id
 
WHERE donor.ledger_entity_id = ?
 
"""
 

	
 

	
 
class MetaDocs(bc_query_env.AnyMeta):
    """Return the set of document links from metadata."""

    def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None:
        # The super() lookup starts *after* AnyMeta in the MRO, so AnyMeta's
        # own __init__ is skipped. The second argument is this function's
        # return type and should match the annotated return type of __call__.
        super(bc_query_env.AnyMeta, self).__init__(operands, set)

    def __call__(self, context: PostingContext) -> Set[str]:
        """Split the metadata value on whitespace and return the unique parts.

        Returns an empty set when the inherited lookup yields anything other
        than a string.
        """
        found = super().__call__(context)
        if not isinstance(found, str):
            return set()
        return set(found.split())
 

	
 

	
 
class RTField(NamedTuple):
 
    key: str
 
    parse: Optional[Callable[[str], object]]
 
    unset_value: Optional[str] = None
 

	
0 comments (0 inline, 0 general)