diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py
index ceb3c4391dc110c79dc3a7b1b81f5b3662c42ffb..0c9bdfbb39bd227a733e93e78c05a4e85d6e3798 100644
--- a/conservancy_beancount/cliutil.py
+++ b/conservancy_beancount/cliutil.py
@@ -17,6 +17,7 @@ You should have received a copy of the GNU Affero General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>."""
 
 import argparse
+import datetime
 import enum
 import io
 import logging
@@ -231,6 +232,18 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
         help="Show program version and license information",
     )
 
+def date_arg(arg: str) -> datetime.date:
+    """Parse a ``YYYY-MM-DD`` string into a :class:`datetime.date`.
+
+    Suitable for use as an argparse ``type`` callable.  The year must have
+    four digits; the month and day may omit leading zeros (``2000-1-1`` is
+    accepted).
+
+    Raises ValueError if ``arg`` does not match that format or names a
+    date that does not exist (e.g., ``2019-02-29``).
+    """
+    return datetime.datetime.strptime(arg, '%Y-%m-%d').date()
+
 def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
     """Create an entry_point function for a tool
 
diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py
index 8623461e99dd233870451202adf2c26a1f8f4aec..3b7fec9d7d23d2b9d178e7904f885f4888e9b8fa 100644
--- a/tests/test_cliutil.py
+++ b/tests/test_cliutil.py
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
 import argparse
+import datetime
 import errno
 import io
 import inspect
@@ -74,6 +75,26 @@ def test_bytes_output_stream(path):
     actual = cliutil.bytes_output(path, stream)
     assert actual is stream
 
+@pytest.mark.parametrize('year,month,day', [
+    (2000, 1, 1),
+    (2016, 2, 29),
+    (2020, 12, 31),
+])
+def test_date_arg_valid(year, month, day):
+    arg = f'{year}-{month}-{day}'
+    expected = datetime.date(year, month, day)
+    assert cliutil.date_arg(arg) == expected
+
+@pytest.mark.parametrize('arg', [
+    '2000',
+    '20-02-12',
+    '2019-02-29',
+    'two thousand',
+])
+def test_date_arg_invalid(arg):
+    with pytest.raises(ValueError):
+        cliutil.date_arg(arg)
+
 @pytest.mark.parametrize('func_name', [
     'bytes_output',
     'text_output',