diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py
index cc2f5deddd32f8827c7b9702f4324e5c9b62dca7..22167f977e0f513f49d0721b3ebcb1fbf6a09ea6 100644
--- a/conservancy_beancount/cliutil.py
+++ b/conservancy_beancount/cliutil.py
@@ -255,6 +255,22 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 def date_arg(arg: str) -> datetime.date:
     return datetime.datetime.strptime(arg, '%Y-%m-%d').date()
 
+def diff_year(date: datetime.date, diff: int) -> datetime.date:
+    """Return the date `diff` years after `date` (before it, if `diff` < 0)
+
+    If `date` is Feb 29 and the resulting year has no Feb 29, the result is
+    Feb 28 when moving backward, or Mar 1 when moving forward.
+    """
+    new_year = date.year + diff
+    try:
+        return date.replace(year=new_year)
+    except ValueError:
+        # The original date is Feb 29, which doesn't exist in the new year.
+        if diff < 0:
+            return datetime.date(new_year, 2, 28)
+        else:
+            return datetime.date(new_year, 3, 1)
+
 def year_or_date_arg(arg: str) -> Union[int, datetime.date]:
     """Get either a date or a year (int) from an argument string
 
diff --git a/conservancy_beancount/reports/fund.py b/conservancy_beancount/reports/fund.py
index c21e4358198e90db5bcdf9ed9c839d1d588e3b8a..90df1ffe74130047712a00e69803a17f245f897c 100644
--- a/conservancy_beancount/reports/fund.py
+++ b/conservancy_beancount/reports/fund.py
@@ -339,17 +339,6 @@ metadata to match. A single ticket number is a shortcut for
         args.report_type = ReportType.ODS
     return args
 
-def diff_year(date: datetime.date, diff: int) -> datetime.date:
-    new_year = date.year + diff
-    try:
-        return date.replace(year=new_year)
-    except ValueError:
-        # The original date is Feb 29, which doesn't exist in the new year.
-        if diff < 0:
-            return datetime.date(new_year, 2, 28)
-        else:
-            return datetime.date(new_year, 3, 1)
-
 def main(arglist: Optional[Sequence[str]]=None,
          stdout: TextIO=sys.stdout,
          stderr: TextIO=sys.stderr,
@@ -365,9 +354,9 @@ def main(arglist: Optional[Sequence[str]]=None,
         if args.start_date is None:
             args.stop_date = datetime.date.today()
         else:
-            args.stop_date = diff_year(args.start_date, 1)
+            args.stop_date = cliutil.diff_year(args.start_date, 1)
     if args.start_date is None:
-        args.start_date = diff_year(args.stop_date, -1)
+        args.start_date = cliutil.diff_year(args.stop_date, -1)
 
     returncode = 0
     books_loader = config.books_loader()
diff --git a/conservancy_beancount/reports/ledger.py b/conservancy_beancount/reports/ledger.py
index 6297e0efd0bc3ac3deaaf58f5db2c4ad4750b115..05df5507ec7227f1a7854a922291d511da0f926b 100644
--- a/conservancy_beancount/reports/ledger.py
+++ b/conservancy_beancount/reports/ledger.py
@@ -747,17 +747,6 @@ metadata to match. A single ticket number is a shortcut for
         args.accounts = list(LedgerODS.ACCOUNT_COLUMNS)
     return args
 
-def diff_year(date: datetime.date, diff: int) -> datetime.date:
-    new_year = date.year + diff
-    try:
-        return date.replace(year=new_year)
-    except ValueError:
-        # The original date is Feb 29, which doesn't exist in the new year.
-        if diff < 0:
-            return datetime.date(new_year, 2, 28)
-        else:
-            return datetime.date(new_year, 3, 1)
-
 def main(arglist: Optional[Sequence[str]]=None,
          stdout: TextIO=sys.stdout,
          stderr: TextIO=sys.stderr,
@@ -771,11 +760,11 @@ def main(arglist: Optional[Sequence[str]]=None,
 
     today = datetime.date.today()
     if args.start_date is None:
-        args.start_date = diff_year(today, -1)
+        args.start_date = cliutil.diff_year(today, -1)
         if args.stop_date is None:
             args.stop_date = today + datetime.timedelta(days=30)
     elif args.stop_date is None:
-        args.stop_date = diff_year(args.start_date, 1)
+        args.stop_date = cliutil.diff_year(args.start_date, 1)
 
     returncode = 0
     books_loader = config.books_loader()
diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py
index f26c4884b1b12c1be9fb228fa79c7a22e7bcb17c..5a558d88fb0745fe0815e64ca0df50252c63931e 100644
--- a/tests/test_cliutil.py
+++ b/tests/test_cliutil.py
@@ -218,3 +218,21 @@ def test_version_argument(argparser, capsys, arg):
     stdout, _ = capsys.readouterr()
     lines = iter(stdout.splitlines())
     assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, ""))
+
+@pytest.mark.parametrize('date,diff,expected', [
+    (datetime.date(2010, 2, 28), 0, datetime.date(2010, 2, 28)),
+    (datetime.date(2010, 2, 28), 1, datetime.date(2011, 2, 28)),
+    (datetime.date(2010, 2, 28), 2, datetime.date(2012, 2, 28)),
+    (datetime.date(2010, 2, 28), -1, datetime.date(2009, 2, 28)),
+    (datetime.date(2010, 2, 28), -2, datetime.date(2008, 2, 28)),
+    (datetime.date(2012, 2, 29), 2, datetime.date(2014, 3, 1)),
+    (datetime.date(2012, 2, 29), 4, datetime.date(2016, 2, 29)),
+    (datetime.date(2012, 2, 29), -2, datetime.date(2010, 2, 28)),
+    (datetime.date(2012, 2, 29), -4, datetime.date(2008, 2, 29)),
+    (datetime.date(2010, 3, 1), 1, datetime.date(2011, 3, 1)),
+    (datetime.date(2010, 3, 1), 2, datetime.date(2012, 3, 1)),
+    (datetime.date(2010, 3, 1), -1, datetime.date(2009, 3, 1)),
+    (datetime.date(2010, 3, 1), -2, datetime.date(2008, 3, 1)),
+])
+def test_diff_year(date, diff, expected):
+    assert cliutil.diff_year(date, diff) == expected