diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py
index 7817859ae4f90bea6cb219271c74a3babd5706c8..7b427d0cb846cac4faea997c3a8455140a39492f 100644
--- a/conservancy_beancount/plugin/core.py
+++ b/conservancy_beancount/plugin/core.py
@@ -21,6 +21,7 @@ import re
 from .. import config as configmod
 from .. import data
 from .. import errors as errormod
+from .. import ranges
 
 from typing import (
     Any,
@@ -49,10 +50,10 @@ from ..beancount_types import (
 
 # I expect these will become configurable in the future, which is why I'm
 # keeping them outside of a class, but for now constants will do.
-DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
+DEFAULT_START_DATE = datetime.date(2020, 3, 1)
 # The default stop date leaves a little room after so it's easy to test
 # dates past the far end of the range.
-DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
+DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1)
 
 
 ### TYPE DEFINITIONS
@@ -74,38 +75,6 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
 
 ### HELPER CLASSES
 
-class LessComparable(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def __le__(self, other: Any) -> bool: ...
-
-    @abc.abstractmethod
-    def __lt__(self, other: Any) -> bool: ...
-
-
-CT = TypeVar('CT', bound=LessComparable)
-class _GenericRange(Generic[CT]):
-    """Convenience class to check whether a value is within a range.
-
-    `foo in generic_range` is equivalent to `start <= foo < stop`.
-    Since we have multiple user-configurable ranges, having the check
-    encapsulated in an object helps implement the check consistently, and
-    makes it easier for subclasses to override.
-    """
-
-    def __init__(self, start: CT, stop: CT) -> None:
-        self.start = start
-        self.stop = stop
-
-    def __repr__(self) -> str:
-        return "{clsname}({self.start!r}, {self.stop!r})".format(
-            clsname=type(self).__name__,
-            self=self,
-        )
-
-    def __contains__(self, item: CT) -> bool:
-        return self.start <= item < self.stop
-
-
 class MetadataEnum:
     """Map acceptable metadata values to their normalized forms.
 
@@ -178,7 +147,7 @@ class MetadataEnum:
 class TransactionHook(Hook[Transaction]):
     DIRECTIVE = Transaction
     SKIP_FLAGS: Container[str] = frozenset()
-    TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
+    TXN_DATE_RANGE = ranges.DateRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
 
     def _run_on_txn(self, txn: Transaction) -> bool:
         """Check whether we should run on a given transaction
diff --git a/conservancy_beancount/ranges.py b/conservancy_beancount/ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2348b7fe8fdf6af4e94accdaaf7b9af41ef271
--- /dev/null
+++ b/conservancy_beancount/ranges.py
@@ -0,0 +1,77 @@
+"""ranges.py - Higher-typed range classes"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import datetime
+
+from decimal import Decimal
+
+from typing import (
+    Generic,
+    TypeVar,
+    Union,
+)
+
+RangeT = TypeVar(
+    'RangeT',
+    # This is a relatively arbitrary set of types.  Feel free to add to it if
+    # you need; the types just need to support enough comparisons to implement
+    # _GenericRange.__contains__.
+    datetime.date,
+    datetime.datetime,
+    datetime.time,
+    Union[int, Decimal],
+)
+
+
+class _GenericRange(Generic[RangeT]):
+    """range for higher-level types
+
+    This class knows how to check membership for higher-level types just like
+    Python's built-in range.  It does not know how to iterate or step.
+
+    Like range, start is inclusive and stop is exclusive: `item in r` is
+    equivalent to `r.start <= item < r.stop`.  A range whose stop is not
+    greater than its start contains nothing.
+
+    Instantiate the named subclasses in this module rather than subscripting
+    this class directly, so repr() and isinstance() use the public names.
+    """
+    def __init__(self, start: RangeT, stop: RangeT) -> None:
+        self.start: RangeT = start
+        self.stop: RangeT = stop
+
+    def __repr__(self) -> str:
+        # type(self) rather than a hard-coded name so subclasses repr as
+        # themselves, e.g. DateRange(datetime.date(...), datetime.date(...)).
+        return "{clsname}({self.start!r}, {self.stop!r})".format(
+            clsname=type(self).__name__,
+            self=self,
+        )
+
+    def __contains__(self, item: RangeT) -> bool:
+        return self.start <= item < self.stop
+
+
+# These are real subclasses, not aliases like
+# `DateRange = _GenericRange[datetime.date]`: calling such an alias builds a
+# plain _GenericRange instance, so repr() would show the private class name,
+# and subscripted generics cannot be used with isinstance().
+class DateRange(_GenericRange[datetime.date]):
+    """Half-open range of dates: start is included, stop is not."""
+
+
+class DecimalCompatRange(_GenericRange[Union[int, Decimal]]):
+    """Half-open range of int/Decimal values: start included, stop not."""
diff --git a/tests/test_ranges.py b/tests/test_ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..c76902ea0997a19e2d2d02f397fc1667932de5f8
--- /dev/null
+++ b/tests/test_ranges.py
@@ -0,0 +1,90 @@
+"""test_ranges.py - Unit tests for range classes"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import datetime
+
+from decimal import Decimal
+
+import pytest
+
+from conservancy_beancount import ranges
+
+ONE_DAY = datetime.timedelta(days=1)
+
+@pytest.mark.parametrize('start,stop', [
+    # One month
+    (datetime.date(2018, 3, 1), datetime.date(2018, 4, 1)),
+    # Three months
+    (datetime.date(2018, 6, 1), datetime.date(2018, 9, 1)),
+    # Six months, spanning year
+    (datetime.date(2018, 9, 1), datetime.date(2019, 3, 1)),
+    # Nine months
+    (datetime.date(2018, 2, 1), datetime.date(2018, 12, 1)),
+    # Twelve months on Jan 1
+    (datetime.date(2018, 1, 1), datetime.date(2019, 1, 1)),
+    # Twelve months spanning year
+    (datetime.date(2018, 3, 1), datetime.date(2019, 3, 1)),
+    # Eighteen months spanning year
+    (datetime.date(2018, 3, 1), datetime.date(2019, 9, 1)),
+    # Wild
+    (datetime.date(2018, 1, 1), datetime.date(2020, 4, 15)),
+])
+def test_date_range(start, stop):
+    date_range = ranges.DateRange(start, stop)
+    assert (start - ONE_DAY) not in date_range
+    assert start in date_range
+    assert (start + ONE_DAY) in date_range
+    assert (stop - ONE_DAY) in date_range
+    assert stop not in date_range
+    assert (stop + ONE_DAY) not in date_range
+
+def test_date_range_one_day():
+    start = datetime.date(2018, 7, 1)
+    date_range = ranges.DateRange(start, start + ONE_DAY)
+    assert (start - ONE_DAY) not in date_range
+    assert start in date_range
+    assert (start + ONE_DAY) not in date_range
+
+def test_date_range_empty():
+    date = datetime.date(2018, 8, 10)
+    date_range = ranges.DateRange(date, date)
+    assert (date - ONE_DAY) not in date_range
+    assert date not in date_range
+    assert (date + ONE_DAY) not in date_range
+
+def test_date_range_repr():
+    start = datetime.date(2018, 7, 1)
+    stop = datetime.date(2018, 8, 1)
+    date_range = ranges.DateRange(start, stop)
+    assert repr(date_range) == 'DateRange({!r}, {!r})'.format(start, stop)
+
+def test_date_range_isinstance():
+    start = datetime.date(2018, 7, 1)
+    date_range = ranges.DateRange(start, start + ONE_DAY)
+    assert isinstance(date_range, ranges.DateRange)
+
+@pytest.mark.parametrize('start,stop', [
+    (0, 10),
+    (Decimal('-1.50'), Decimal('2.25')),
+    (0, Decimal('0.01')),
+])
+def test_decimal_compat_range(start, stop):
+    decimal_range = ranges.DecimalCompatRange(start, stop)
+    assert (start - 1) not in decimal_range
+    assert start in decimal_range
+    assert (Decimal(start + stop) / 2) in decimal_range
+    assert stop not in decimal_range
+    assert (stop + 1) not in decimal_range