diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py
index 7817859ae4f90bea6cb219271c74a3babd5706c8..7b427d0cb846cac4faea997c3a8455140a39492f 100644
--- a/conservancy_beancount/plugin/core.py
+++ b/conservancy_beancount/plugin/core.py
@@ -21,6 +21,7 @@ import re
from .. import config as configmod
from .. import data
from .. import errors as errormod
+from .. import ranges
from typing import (
Any,
@@ -49,10 +50,10 @@ from ..beancount_types import (
# I expect these will become configurable in the future, which is why I'm
# keeping them outside of a class, but for now constants will do.
-DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
+DEFAULT_START_DATE = datetime.date(2020, 3, 1)
# The default stop date leaves a little room after so it's easy to test
# dates past the far end of the range.
-DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
+DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1)
### TYPE DEFINITIONS
@@ -74,38 +75,6 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
### HELPER CLASSES
-class LessComparable(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def __le__(self, other: Any) -> bool: ...
-
- @abc.abstractmethod
- def __lt__(self, other: Any) -> bool: ...
-
-
-CT = TypeVar('CT', bound=LessComparable)
-class _GenericRange(Generic[CT]):
- """Convenience class to check whether a value is within a range.
-
- `foo in generic_range` is equivalent to `start <= foo < stop`.
- Since we have multiple user-configurable ranges, having the check
- encapsulated in an object helps implement the check consistently, and
- makes it easier for subclasses to override.
- """
-
- def __init__(self, start: CT, stop: CT) -> None:
- self.start = start
- self.stop = stop
-
- def __repr__(self) -> str:
- return "{clsname}({self.start!r}, {self.stop!r})".format(
- clsname=type(self).__name__,
- self=self,
- )
-
- def __contains__(self, item: CT) -> bool:
- return self.start <= item < self.stop
-
-
class MetadataEnum:
"""Map acceptable metadata values to their normalized forms.
@@ -178,7 +147,7 @@ class MetadataEnum:
class TransactionHook(Hook[Transaction]):
DIRECTIVE = Transaction
SKIP_FLAGS: Container[str] = frozenset()
- TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
+ TXN_DATE_RANGE = ranges.DateRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
def _run_on_txn(self, txn: Transaction) -> bool:
"""Check whether we should run on a given transaction
diff --git a/conservancy_beancount/ranges.py b/conservancy_beancount/ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2348b7fe8fdf6af4e94accdaaf7b9af41ef271
--- /dev/null
+++ b/conservancy_beancount/ranges.py
@@ -0,0 +1,59 @@
+"""ranges.py - Higher-typed range classes"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import datetime
+
+from decimal import Decimal
+
+from typing import (
+ Generic,
+ TypeVar,
+ Union,
+)
+
+RangeT = TypeVar(
+ 'RangeT',
+ # This is a relatively arbitrary set of types. Feel free to add to it if
+ # you need; the types just need to support enough comparisons to implement
+ # _GenericRange.__contains__.
+ datetime.date,
+ datetime.datetime,
+ datetime.time,
+ Union[int, Decimal],
+)
+
+class _GenericRange(Generic[RangeT]):
+ """range for higher-level types
+
+ This class knows how to check membership for higher-level types just like
+ Python's built-in range. It does not know how to iterate or step.
+ """
+ def __init__(self, start: RangeT, stop: RangeT) -> None:
+ self.start: RangeT = start
+ self.stop: RangeT = stop
+
+ def __repr__(self) -> str:
+ return "{clsname}({self.start!r}, {self.stop!r})".format(
+ clsname=type(self).__name__,
+ self=self,
+ )
+
+ def __contains__(self, item: RangeT) -> bool:
+ return self.start <= item < self.stop
+
+
+DateRange = _GenericRange[datetime.date]
+DecimalCompatRange = _GenericRange[Union[int, Decimal]]
diff --git a/tests/test_ranges.py b/tests/test_ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..c76902ea0997a19e2d2d02f397fc1667932de5f8
--- /dev/null
+++ b/tests/test_ranges.py
@@ -0,0 +1,64 @@
+"""test_ranges.py - Unit tests for range classes"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import datetime
+
+import pytest
+
+from conservancy_beancount import ranges
+
+ONE_DAY = datetime.timedelta(days=1)
+
+@pytest.mark.parametrize('start,stop', [
+ # One month
+ (datetime.date(2018, 3, 1), datetime.date(2018, 4, 1)),
+ # Three months
+ (datetime.date(2018, 6, 1), datetime.date(2018, 9, 1)),
+ # Six months, spanning year
+ (datetime.date(2018, 9, 1), datetime.date(2019, 3, 1)),
+ # Nine months
+ (datetime.date(2018, 2, 1), datetime.date(2018, 12, 1)),
+ # Twelve months on Jan 1
+ (datetime.date(2018, 1, 1), datetime.date(2019, 1, 1)),
+ # Twelve months spanning year
+ (datetime.date(2018, 3, 1), datetime.date(2019, 3, 1)),
+ # Eighteen months spanning year
+ (datetime.date(2018, 3, 1), datetime.date(2019, 9, 1)),
+ # Wild
+ (datetime.date(2018, 1, 1), datetime.date(2020, 4, 15)),
+])
+def test_date_range(start, stop):
+ date_range = ranges.DateRange(start, stop)
+ assert (start - ONE_DAY) not in date_range
+ assert start in date_range
+ assert (start + ONE_DAY) in date_range
+ assert (stop - ONE_DAY) in date_range
+ assert stop not in date_range
+ assert (stop + ONE_DAY) not in date_range
+
+def test_date_range_one_day():
+ start = datetime.date(2018, 7, 1)
+ date_range = ranges.DateRange(start, start + ONE_DAY)
+ assert (start - ONE_DAY) not in date_range
+ assert start in date_range
+ assert (start + ONE_DAY) not in date_range
+
+def test_date_range_empty():
+ date = datetime.date(2018, 8, 10)
+ date_range = ranges.DateRange(date, date)
+ assert (date - ONE_DAY) not in date_range
+ assert date not in date_range
+ assert (date + ONE_DAY) not in date_range