Changeset - aa488effb0f5
[Not reviewed]
0 7 1
Brett Smith - 4 years ago 2020-05-16 14:29:06
brettcsmith@brettcsmith.org
books.Loader: New loading strategy based on load_file. RT#11034.

Building a string and loading it means Beancount can never cache any
load. It only caches top-level file loads because options in the
top-level file can change the semantics of included entries.

Instead use load_file as much as possible, and filter entries as
needed.
8 files changed with 60 insertions and 79 deletions:
0 comments (0 inline, 0 general)
conservancy_beancount/books.py
Show inline comments
 
"""books - Tools for loading the books"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import contextlib
 
import datetime
 
import os
 

	
 
from pathlib import Path
 

	
 
from beancount import loader as bc_loader
 

	
 
from typing import (
 
    Any,
 
    Iterable,
 
    Iterator,
 
    Mapping,
 
    NamedTuple,
 
    Optional,
 
    Union,
 
)
 
from .beancount_types import (
 
    LoadResult,
 
)
 

	
 
PathLike = Union[str, Path]
 
Year = Union[int, datetime.date]
 

	
 
@contextlib.contextmanager
 
def workdir(path: PathLike) -> Iterator[Path]:
 
    old_dir = os.getcwd()
 
    os.chdir(path)
 
    try:
 
        yield Path(old_dir)
 
    finally:
 
        os.chdir(old_dir)
 

	
 
class FiscalYear(NamedTuple):
 
    month: int = 3
 
    day: int = 1
 

	
 
    def for_date(self, date: Optional[datetime.date]=None) -> int:
 
        if date is None:
 
            date = datetime.date.today()
 
        if (date.month, date.day) < self:
 
            return date.year - 1
 
        else:
 
            return date.year
 

	
 
    def range(self, from_fy: Year, to_fy: Optional[Year]=None) -> Iterable[int]:
 
        """Return a range of fiscal years
 

	
 
        Both arguments can be either a year (represented as an integer) or a
 
        date. Dates will be converted into a year by calling for_date() on
 
        them.
 

	
 
        If the first argument is negative or below 1000, it will be treated as
 
        an offset. You'll get a range of fiscal years between the second
 
        argument offset by this amount.
 

	
 
        If the second argument is omitted, it defaults to the current fiscal
 
        year.
 

	
 
        Note that unlike normal Python ranges, these ranges include the final
 
        fiscal year.
 

	
 
        Examples:
 

	
 
          range(2015)  # Iterate all fiscal years from 2015 to today, inclusive
 

	
 
          range(-1)  # Iterate the previous fiscal year and current fiscal year
 
        """
 
        if not isinstance(from_fy, int):
 
            from_fy = self.for_date(from_fy)
 
        if to_fy is None:
 
            to_fy = self.for_date()
 
        elif not isinstance(to_fy, int):
 
            to_fy = self.for_date(to_fy - datetime.timedelta(days=1))
 
        if from_fy < 1:
 
            from_fy += to_fy
 
        elif from_fy < 1000:
 
            from_fy, to_fy = to_fy, from_fy + to_fy
 
        return range(from_fy, to_fy + 1)
 

	
 

	
 
class Loader:
 
    """Load Beancount books organized by fiscal year"""
 

	
 
    def __init__(self,
 
                 books_root: Path,
 
                 fiscal_year: FiscalYear,
 
    ) -> None:
 
        """Set up a books loader
 

	
 
        Arguments:
 
        * books_root: A Path to a Beancount books checkout.
 
        * fiscal_year: A FiscalYear object, used to determine what books to
 
          load for a given date range.
 
        """
 
        self.books_root = books_root
 
        self.opening_root = books_root / 'books'
 
        self.fiscal_year = fiscal_year
 

	
 
    def _iter_fy_books(self, fy_range: Iterable[int]) -> Iterator[Path]:
 
        dir_path = self.opening_root
 
        for year in fy_range:
 
            path = dir_path / f'{year}.beancount'
 
            path = Path(self.books_root, 'books', f'{year}.beancount')
 
            if path.exists():
 
                yield path
 
                dir_path = self.books_root
 

	
 
    def fy_range_string(self,
 
    def load_fy_range(self,
 
                      from_fy: Year,
 
                      to_fy: Optional[Year]=None,
 
    ) -> str:
 
        """Return a string to load books for a range of fiscal years
 
    ) -> LoadResult:
 
        """Load books for a range of fiscal years
 

	
 
        This method generates a range of fiscal years by calling
 
        FiscalYear.range() with its first two arguments. It returns a string of
 
        Beancount directives to load the books from the first available fiscal
 
        year through the end of the range.
 

	
 
        Pass the string to Loader.load_string() to actually load data from it.
 
        """
 
        paths = self._iter_fy_books(self.fiscal_year.range(from_fy, to_fy))
 
        fy_range = self.fiscal_year.range(from_fy, to_fy)
 
        fy_paths = self._iter_fy_books(fy_range)
 
        try:
 
            with next(paths).open() as opening_books:
 
                lines = [opening_books.read()]
 
            entries, errors, options_map = bc_loader.load_file(next(fy_paths))
 
        except StopIteration:
 
            return ''
 
        for path in paths:
 
            lines.append(f'include "../{path.name}"')
 
        return '\n'.join(lines)
 

	
 
    def load_string(self, source: str) -> LoadResult:
 
        """Load a generated string of Beancount directives
 

	
 
        This method takes a string generated by another Loader method, like
 
        fy_range_string, and loads it through Beancount, setting up the
 
        environment as necessary to do that.
 
        """
 
        with workdir(self.opening_root):
 
            retval: LoadResult = bc_loader.load_string(source)
 
        return retval
 

	
 
    def load_fy_range(self,
 
                      from_fy: Year,
 
                      to_fy: Optional[Year]=None,
 
    ) -> LoadResult:
 
        """Load books for a range of fiscal years"""
 
        return self.load_string(self.fy_range_string(from_fy, to_fy))
 
            entries, errors, options_map = [], [], {}
 
        for load_path in fy_paths:
 
            new_entries, new_errors, new_options = bc_loader.load_file(load_path)
 
            # We only want transactions from the new fiscal year.
 
            # We don't want the opening balance, duplicate definitions, etc.
 
            fy_filename = str(load_path.parent.parent / load_path.name)
 
            entries.extend(
 
                entry for entry in new_entries
 
                if entry.meta.get('filename') == fy_filename
 
            )
 
            errors.extend(new_errors)
 
        return entries, errors, options_map
tests/books/books/2018.beancount
Show inline comments
 
option "title" "Books from 2018"
 
plugin "beancount.plugins.auto"
 
include "../definitions.beancount"
 
include "../2018.beancount"
tests/books/books/2019.beancount
Show inline comments
 
option "title" "Books from 2019"
 
plugin "beancount.plugins.auto"
 
include "../definitions.beancount"
 
include "../2019.beancount"
tests/books/books/2020.beancount
Show inline comments
 
option "title" "Books from 2020"
 
plugin "beancount.plugins.auto"
 
include "../definitions.beancount"
 
include "../2020.beancount"
tests/books/definitions.beancount
Show inline comments
 
new file 100644
 
2018-03-01 open Assets:Checking
 
2018-03-01 open Income:Donations
tests/test_books_loader.py
Show inline comments
 
"""test_books_loader - Unit tests for books Loader class"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import hashlib
 
import re
 

	
 
from datetime import date
 
from pathlib import Path
 

	
 
import pytest
 

	
 
from . import testutil
 

	
 
from conservancy_beancount import books
 

	
 
books_path = testutil.test_path('books')
 

	
 
@pytest.fixture(scope='module')
 
def conservancy_loader():
 
    return books.Loader(books_path, books.FiscalYear(3))
 

	
 
def include_patterns(years, subdir='..'):
 
    for year in years:
 
        path = Path(subdir, f'{year}.beancount')
 
        yield rf'^include "{re.escape(str(path))}"$'
 

	
 
@pytest.mark.parametrize('range_start,range_stop,expect_years', [
 
    (2019, 2020, [2019, 2020]),
 
    (-1, 2020, [2019, 2020]),
 
    (10, 2019, [2019, 2020]),
 
    (-10, 2019, [2018, 2019]),
 
    (date(2019, 1, 1), date(2020, 6, 1), [2018, 2019, 2020]),
 
    (-1, date(2020, 2, 1), [2018, 2019]),
 
])
 
def test_fy_range_string(conservancy_loader, range_start, range_stop, expect_years):
 
    actual = conservancy_loader.fy_range_string(range_start, range_stop)
 
    testutil.check_lines_match(actual.splitlines(), [
 
        rf'^option "title" "Books from {expect_years[0]}"$',
 
        rf'^plugin "beancount\.plugins\.auto"$',
 
        *include_patterns(expect_years),
 
@pytest.mark.parametrize('from_fy,to_fy,expect_years', [
 
    (2019, 2019, range(2019, 2020)),
 
    (0, 2019, range(2019, 2020)),
 
    (2018, 2019, range(2018, 2020)),
 
    (1, 2018, range(2018, 2020)),
 
    (-1, 2019, range(2018, 2020)),
 
    (2019, 2020, range(2019, 2021)),
 
    (1, 2019, range(2019, 2021)),
 
    (-1, 2020, range(2019, 2021)),
 
    (2010, 2030, range(2018, 2021)),
 
    (20, 2010, range(2018, 2021)),
 
    (-20, 2030, range(2018, 2021)),
 
])
 

	
 
@pytest.mark.parametrize('year_offset', range(-3, 1))
 
def test_fy_range_string_with_offset(conservancy_loader, year_offset):
 
    base_year = 2020
 
    start_year = max(2018, base_year + year_offset)
 
    expect_years = range(start_year, base_year + 1)
 
    actual = conservancy_loader.fy_range_string(year_offset, base_year)
 
    testutil.check_lines_match(actual.splitlines(), include_patterns(expect_years))
 

	
 
def test_fy_range_string_empty_range(conservancy_loader):
 
    assert conservancy_loader.fy_range_string(2020, 2019) == ''
 

	
 
def test_load_fy_range(conservancy_loader):
 
    entries, errors, options_map = conservancy_loader.load_fy_range(2018, 2019)
 
def test_load_fy_range(conservancy_loader, from_fy, to_fy, expect_years):
 
    entries, errors, options_map = conservancy_loader.load_fy_range(from_fy, to_fy)
 
    assert not errors
 
    narrations = {getattr(entry, 'narration', None) for entry in entries}
 
    assert '2018 donation' in narrations
 
    assert '2019 donation' in narrations
 
    assert '2020 donation' not in narrations
 
    assert ('2018 donation' in narrations) == (2018 in expect_years)
 
    assert ('2019 donation' in narrations) == (2019 in expect_years)
 
    assert ('2020 donation' in narrations) == (2020 in expect_years)
 

	
 
def test_load_fy_range_does_not_duplicate_openings(conservancy_loader):
 
    entries, errors, options_map = conservancy_loader.load_fy_range(2010, 2030)
 
    openings = []
 
    open_accounts = set()
 
    for entry in entries:
 
        try:
 
            open_accounts.add(entry.account)
 
        except AttributeError:
 
            pass
 
        else:
 
            openings.append(entry)
 
    assert len(openings) == len(open_accounts)
 

	
 
def test_load_fy_range_empty(conservancy_loader):
 
    entries, errors, options_map = conservancy_loader.load_fy_range(2020, 2019)
 
    assert not errors
 
    assert not entries
 
    assert options_map.get('input_hash') == hashlib.md5().hexdigest()
 
    assert not options_map
tests/test_config.py
Show inline comments
...
 
@@ -17,388 +17,390 @@
 
import contextlib
 
import decimal
 
import operator
 
import os
 
import re
 

	
 
from pathlib import Path
 

	
 
import pytest
 

	
 
from . import testutil
 

	
 
from conservancy_beancount import config as config_mod
 

	
 
RT_AUTH_METHODS = frozenset(['basic', 'gssapi', 'rt'])
 

	
 
RT_ENV_KEYS = (
 
    'RTSERVER',
 
    'RTUSER',
 
    'RTPASSWD',
 
    'RTAUTH',
 
)
 

	
 
RT_ENV_CREDS = (
 
    'https://example.org/envrt',
 
    'envuser',
 
    'env  password',
 
    'gssapi',
 
)
 

	
 
RT_FILE_CREDS = (
 
    'https://example.org/filert',
 
    'fileuser',
 
    'file  password',
 
    'basic',
 
)
 

	
 
RT_GENERIC_CREDS = config_mod.RTCredentials(
 
    'https://example.org/genericrt',
 
    'genericuser',
 
    'generic password',
 
    None,
 
)
 

	
 
@pytest.fixture
 
def rt_environ():
 
    return dict(zip(RT_ENV_KEYS, RT_ENV_CREDS))
 

	
 
def _update_environ(updates):
 
    for key, value in updates.items():
 
        if value is None:
 
            os.environ.pop(key, None)
 
        else:
 
            os.environ[key] = str(value)
 

	
 
@contextlib.contextmanager
 
def update_environ(**kwargs):
 
    revert = {key: os.environ.get(key) for key in kwargs}
 
    _update_environ(kwargs)
 
    try:
 
        yield
 
    finally:
 
        _update_environ(revert)
 

	
 
@contextlib.contextmanager
 
def update_umask(mask):
 
    old_mask = os.umask(mask)
 
    try:
 
        yield old_mask
 
    finally:
 
        os.umask(old_mask)
 

	
 
def test_repository_from_environment():
 
    config = config_mod.Config()
 
    assert config.repository_path() == testutil.test_path('repository')
 

	
 
def test_no_repository():
 
    with update_environ(CONSERVANCY_REPOSITORY=None):
 
        config = config_mod.Config()
 
        assert config.repository_path() is None
 

	
 
def test_no_rt_credentials():
 
    with update_environ(HOME=testutil.TESTS_DIR):
 
        config = config_mod.Config()
 
        rt_credentials = config.rt_credentials()
 
    assert rt_credentials.server is None
 
    assert rt_credentials.user is None
 
    assert rt_credentials.passwd is None
 
    assert rt_credentials.auth == 'rt'
 

	
 
def test_rt_credentials_from_file():
 
    config = config_mod.Config()
 
    rt_credentials = config.rt_credentials()
 
    assert rt_credentials == RT_FILE_CREDS
 

	
 
def test_rt_credentials_from_environment(rt_environ):
 
    with update_environ(**rt_environ):
 
        config = config_mod.Config()
 
        rt_credentials = config.rt_credentials()
 
    assert rt_credentials == RT_ENV_CREDS
 

	
 
@pytest.mark.parametrize('index,drop_key', enumerate(RT_ENV_KEYS))
 
def test_rt_credentials_from_file_and_environment_mixed(rt_environ, index, drop_key):
 
    del rt_environ[drop_key]
 
    with update_environ(**rt_environ):
 
        config = config_mod.Config()
 
        rt_credentials = config.rt_credentials()
 
    expected = list(RT_ENV_CREDS)
 
    expected[index] = RT_FILE_CREDS[index]
 
    assert rt_credentials == tuple(expected)
 

	
 
def test_rt_credentials_from_all_sources_mixed(tmp_path):
 
    server = 'https://example.org/mixedrt'
 
    with (tmp_path / '.rtrc').open('w') as rtrc_file:
 
        print('user basemix', 'passwd mixed up', file=rtrc_file, sep='\n')
 
    with update_environ(HOME=tmp_path, RTSERVER=server, RTUSER='mixedup'):
 
        config = config_mod.Config()
 
        rt_credentials = config.rt_credentials()
 
    assert rt_credentials == (server, 'mixedup', 'mixed up', 'rt')
 

	
 
def check_rt_client_url(credentials, client):
 
    pattern = '^{}/?$'.format(re.escape(credentials[0].rstrip('/') + '/REST/1.0'))
 
    assert re.match(pattern, client.url)
 

	
 
@pytest.mark.parametrize('authmethod', RT_AUTH_METHODS)
 
def test_rt_client(authmethod):
 
    rt_credentials = RT_GENERIC_CREDS._replace(auth=authmethod)
 
    config = config_mod.Config()
 
    rt_client = config.rt_client(rt_credentials, testutil.RTClient)
 
    check_rt_client_url(RT_GENERIC_CREDS, rt_client)
 
    assert rt_client.auth_method == ('HTTPBasicAuth' if authmethod == 'basic' else 'login')
 
    assert rt_client.last_login == (
 
        RT_GENERIC_CREDS.user,
 
        RT_GENERIC_CREDS.passwd,
 
        True,
 
    )
 

	
 
def test_default_rt_client(rt_environ):
 
    with update_environ(**rt_environ):
 
        config = config_mod.Config()
 
        rt_client = config.rt_client(client=testutil.RTClient)
 
    check_rt_client_url(RT_ENV_CREDS, rt_client)
 
    assert rt_client.last_login[:-1] == RT_ENV_CREDS[1:3]
 
    assert rt_client.last_login[-1]
 

	
 
@pytest.mark.parametrize('authmethod', RT_AUTH_METHODS)
 
def test_rt_client_login_failure(authmethod):
 
    rt_credentials = RT_GENERIC_CREDS._replace(
 
        auth=authmethod,
 
        passwd='bad{}'.format(authmethod),
 
    )
 
    config = config_mod.Config()
 
    assert config.rt_client(rt_credentials, testutil.RTClient) is None
 

	
 
def test_no_rt_client_without_server():
 
    rt_credentials = RT_GENERIC_CREDS._replace(server=None, auth='rt')
 
    config = config_mod.Config()
 
    assert config.rt_client(rt_credentials, testutil.RTClient) is None
 

	
 
def test_rt_wrapper():
 
    config = config_mod.Config()
 
    rt = config.rt_wrapper(RT_GENERIC_CREDS._replace(auth='rt'), testutil.RTClient)
 
    assert rt.exists(1)
 

	
 
def test_rt_wrapper_default_creds():
 
    config = config_mod.Config()
 
    rt = config.rt_wrapper(None, testutil.RTClient)
 
    assert rt.rt.url.startswith(RT_FILE_CREDS[0])
 

	
 
def test_rt_wrapper_default_creds_from_environ(rt_environ):
 
    with update_environ(**rt_environ):
 
        config = config_mod.Config()
 
        rt = config.rt_wrapper(None, testutil.RTClient)
 
    assert rt.rt.url.startswith(RT_ENV_CREDS[0])
 

	
 
def test_rt_wrapper_no_creds():
 
    with update_environ(HOME=testutil.TESTS_DIR):
 
        config = config_mod.Config()
 
        assert config.rt_wrapper(None, testutil.RTClient) is None
 

	
 
def test_rt_wrapper_bad_creds():
 
    rt_credentials = RT_GENERIC_CREDS._replace(passwd='badpass', auth='rt')
 
    config = config_mod.Config()
 
    assert config.rt_wrapper(rt_credentials, testutil.RTClient) is None
 

	
 
def test_rt_wrapper_caches():
 
    rt_credentials = RT_GENERIC_CREDS._replace(auth='rt')
 
    config = config_mod.Config()
 
    rt1 = config.rt_wrapper(rt_credentials, testutil.RTClient)
 
    rt2 = config.rt_wrapper(rt_credentials, testutil.RTClient)
 
    assert rt1 is rt2
 

	
 
def test_rt_wrapper_caches_by_creds():
 
    config = config_mod.Config()
 
    rt1 = config.rt_wrapper(RT_GENERIC_CREDS._replace(auth='rt'), testutil.RTClient)
 
    rt2 = config.rt_wrapper(None, testutil.RTClient)
 
    assert rt1 is not rt2
 

	
 
def test_rt_wrapper_cache_responds_to_external_credential_changes(rt_environ):
 
    config = config_mod.Config()
 
    rt1 = config.rt_wrapper(None, testutil.RTClient)
 
    with update_environ(**rt_environ):
 
        rt2 = config.rt_wrapper(None, testutil.RTClient)
 
    assert rt1 is not rt2
 

	
 
def test_rt_wrapper_has_cache(tmp_path):
 
    with update_environ(XDG_CACHE_HOME=tmp_path), update_umask(0o002):
 
        config = config_mod.Config()
 
        rt = config.rt_wrapper(None, testutil.RTClient)
 
        rt.exists(1)
 
    expected = 'conservancy_beancount/{}@*.sqlite3'.format(RT_FILE_CREDS[1])
 
    actual = None
 
    for actual in tmp_path.glob(expected):
 
        assert not actual.stat().st_mode & 0o177
 
    assert actual is not None, "did not find any generated cache file"
 

	
 
def test_rt_wrapper_without_cache(tmp_path):
 
    tmp_path.chmod(0)
 
    with update_environ(XDG_CACHE_HOME=tmp_path):
 
        config = config_mod.Config()
 
        rt = config.rt_wrapper(None, testutil.RTClient)
 
    tmp_path.chmod(0o600)
 
    assert not any(tmp_path.iterdir())
 

	
 
def test_cache_mkdir(tmp_path):
 
    expected = tmp_path / 'TESTcache'
 
    with update_environ(XDG_CACHE_HOME=tmp_path):
 
        config = config_mod.Config()
 
        cache_path = config.cache_dir_path(expected.name)
 
    assert cache_path == tmp_path / 'TESTcache'
 
    assert cache_path.is_dir()
 

	
 
def test_cache_mkdir_parent(tmp_path):
 
    xdg_cache_dir = tmp_path / 'xdgcache'
 
    expected = xdg_cache_dir / 'conservancy_beancount'
 
    with update_environ(XDG_CACHE_HOME=xdg_cache_dir):
 
        config = config_mod.Config()
 
        cache_path = config.cache_dir_path(expected.name)
 
    assert cache_path == expected
 
    assert cache_path.is_dir()
 

	
 
def test_cache_mkdir_from_home(tmp_path):
 
    expected = tmp_path / '.cache' / 'TESTcache'
 
    with update_environ(HOME=tmp_path, XDG_CACHE_HOME=None):
 
        config = config_mod.Config()
 
        cache_path = config.cache_dir_path(expected.name)
 
    assert cache_path == expected
 
    assert cache_path.is_dir()
 

	
 
def test_cache_mkdir_exists_ok(tmp_path):
 
    expected = tmp_path / 'TESTcache'
 
    expected.mkdir()
 
    with update_environ(XDG_CACHE_HOME=tmp_path):
 
        config = config_mod.Config()
 
        cache_path = config.cache_dir_path(expected.name)
 
    assert cache_path == expected
 

	
 
def test_cache_path_conflict(tmp_path):
 
    extant_path = tmp_path / 'TESTcache'
 
    extant_path.touch()
 
    with update_environ(XDG_CACHE_HOME=tmp_path):
 
        config = config_mod.Config()
 
        cache_path = config.cache_dir_path(extant_path.name)
 
    assert cache_path is None
 
    assert extant_path.is_file()
 

	
 
def test_cache_path_parent_conflict(tmp_path):
 
    (tmp_path / '.cache').touch()
 
    with update_environ(HOME=tmp_path, XDG_CACHE_HOME=None):
 
        config = config_mod.Config()
 
        assert config.cache_dir_path('TESTcache') is None
 

	
 
def test_relative_xdg_cache_home_ignored(tmp_path):
 
    with update_environ(HOME=tmp_path,
 
                        XDG_CACHE_HOME='nonexistent/test/cache/directory/tree'):
 
        config = config_mod.Config()
 
        cache_dir_path = config.cache_dir_path('TESTcache')
 
    assert cache_dir_path == tmp_path / '.cache/TESTcache'
 

	
 
def test_default_payment_threshold():
 
    threshold = config_mod.Config().payment_threshold()
 
    assert isinstance(threshold, (int, decimal.Decimal))
 

	
 
@pytest.mark.parametrize('config_threshold', [
 
    '15',
 
    ' +15',
 
    '15. ',
 
    '15.0',
 
    '15.00',
 
])
 
def test_payment_threshold(config_threshold):
 
    config = config_mod.Config()
 
    config.load_string(f'[Beancount]\npayment threshold = {config_threshold}\n')
 
    assert config.payment_threshold() == decimal.Decimal(15)
 

	
 
@pytest.mark.parametrize('config_path', [
 
    None,
 
    '',
 
    'nonexistent/relative/path',
 
])
 
def test_config_file_path(config_path):
 
    expected = Path('~/.config/conservancy_beancount/config.ini').expanduser()
 
    with update_environ(XDG_CONFIG_HOME=config_path):
 
        config = config_mod.Config()
 
        assert config.config_file_path() == expected
 

	
 
def test_config_file_path_respects_xdg_config_home():
 
    with update_environ(XDG_CONFIG_HOME='/etc'):
 
        config = config_mod.Config()
 
        assert config.config_file_path() == Path('/etc/conservancy_beancount/config.ini')
 

	
 
def test_config_file_path_with_subdir():
 
    expected = testutil.test_path('userconfig/conftest/config.ini')
 
    config = config_mod.Config()
 
    assert config.config_file_path('conftest') == expected
 

	
 
@pytest.mark.parametrize('path', [
 
    None,
 
    testutil.test_path('userconfig/conservancy_beancount/config.ini'),
 
])
 
def test_load_file(path):
 
    config = config_mod.Config()
 
    config.load_file(path)
 
    assert config.books_path() == Path('/test/conservancy_beancount')
 

	
 
@pytest.mark.parametrize('path_func', [
 
    lambda path: None,
 
    operator.methodcaller('touch', 0o200),
 
])
 
def test_load_file_error(tmp_path, path_func):
 
    config_path = tmp_path / 'nonexistent.ini'
 
    path_func(config_path)
 
    config = config_mod.Config()
 
    with pytest.raises(OSError):
 
        config.load_file(config_path)
 

	
 
def test_no_books_path():
 
    config = config_mod.Config()
 
    assert config.books_path() is None
 

	
 
@pytest.mark.parametrize('value,month,day', [
 
    ('2', 2, 1),
 
    ('3 ', 3, 1),
 
    ('  4', 4, 1),
 
    (' 5 ', 5, 1),
 
    ('6 1', 6, 1),
 
    ('  06  03  ', 6, 3),
 
    ('6-05', 6, 5),
 
    ('06 - 10', 6, 10),
 
    ('6/15', 6, 15),
 
    ('06  /  20', 6, 20),
 
    ('10.25', 10, 25),
 
    (' 10 . 30 ', 10, 30),
 
])
 
def test_fiscal_year_begin(value, month, day):
 
    config = config_mod.Config()
 
    config.load_string(f'[Beancount]\nfiscal year begin = {value}\n')
 
    assert config.fiscal_year_begin() == (month, day)
 

	
 
@pytest.mark.parametrize('value', [
 
    'text',
 
    '1900',
 
    '13',
 
    '010',
 
    '2 30',
 
    '4-31',
 
])
 
def test_bad_fiscal_year_begin(value):
 
    config = config_mod.Config()
 
    config.load_string(f'[Beancount]\nfiscal year begin = {value}\n')
 
    with pytest.raises(ValueError):
 
        config.fiscal_year_begin()
 

	
 
def test_default_fiscal_year_begin():
 
    config = config_mod.Config()
 
    actual = config.fiscal_year_begin()
 
    assert actual.month == 3
 
    assert actual.day == 1
 

	
 
def test_books_loader():
 
    books_path = testutil.test_path('books')
 
    config = config_mod.Config()
 
    config.load_string(f'[Beancount]\nbooks dir = {books_path}\n')
 
    loader = config.books_loader()
 
    assert loader.fy_range_string(2020, 2020)
 
    entries, errors, _ = loader.load_fy_range(2020, 2020)
 
    assert entries
 
    assert not errors
 

	
 
def test_books_loader_without_books():
 
    assert config_mod.Config().books_loader() is None
tests/testutil.py
Show inline comments
 
"""Mock Beancount objects for testing"""
 
# Copyright © 2020  Brett Smith
 
#
 
# This program is free software: you can redistribute it and/or modify
 
# it under the terms of the GNU Affero General Public License as published by
 
# the Free Software Foundation, either version 3 of the License, or
 
# (at your option) any later version.
 
#
 
# This program is distributed in the hope that it will be useful,
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
# GNU Affero General Public License for more details.
 
#
 
# You should have received a copy of the GNU Affero General Public License
 
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 

	
 
import datetime
 
import itertools
 
import re
 

	
 
import beancount.core.amount as bc_amount
 
import beancount.core.data as bc_data
 
import beancount.loader as bc_loader
 

	
 
from decimal import Decimal
 
from pathlib import Path
 

	
 
from conservancy_beancount import books, rtutil
 

	
 
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
 
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
 
FY_START_DATE = datetime.date(2020, 3, 1)
 
FY_MID_DATE = datetime.date(2020, 9, 1)
 
PAST_DATE = datetime.date(2000, 1, 1)
 
TESTS_DIR = Path(__file__).parent
 

	
 
def check_lines_match(lines, expect_patterns, source='output'):
 
    for pattern in expect_patterns:
 
        assert any(re.search(pattern, line) for line in lines), \
 
            f"{pattern!r} not found in {source}"
 

	
 
def check_post_meta(txn, *expected_meta, default=None):
 
    assert len(txn.postings) == len(expected_meta)
 
    for post, expected in zip(txn.postings, expected_meta):
 
        if not expected:
 
            assert not post.meta
 
        else:
 
            actual = None if post.meta is None else {
 
                key: post.meta.get(key, default) for key in expected
 
            }
 
            assert actual == expected
 

	
 
def combine_values(*value_seqs):
 
    stop = 0
 
    for seq in value_seqs:
 
        try:
 
            stop = max(stop, len(seq))
 
        except TypeError:
 
            pass
 
    return itertools.islice(
 
        zip(*(itertools.cycle(seq) for seq in value_seqs)),
 
        stop,
 
    )
 

	
 
def date_seq(date=FY_MID_DATE, step=1):
 
    while True:
 
        yield date
 
        date += datetime.timedelta(days=step)
 

	
 
def parse_date(s, fmt='%Y-%m-%d'):
 
    return datetime.datetime.strptime(s, fmt).date()
 

	
 
def test_path(s):
 
    if s is None:
 
        return s
 
    s = Path(s)
 
    if not s.is_absolute():
 
        s = TESTS_DIR / s
 
    return s
 

	
 
def Amount(number, currency='USD'):
 
    return bc_amount.Amount(Decimal(number), currency)
 

	
 
def Cost(number, currency='USD', date=FY_MID_DATE, label=None):
 
    return bc_data.Cost(Decimal(number), currency, date, label)
 

	
 
def Posting(account, number,
 
            currency='USD', cost=None, price=None, flag=None,
 
            type_=bc_data.Posting, **meta):
 
    if cost is not None:
 
        cost = Cost(*cost)
 
    if not meta:
 
        meta = None
 
    return type_(
 
        account,
 
        Amount(number, currency),
 
        cost,
 
        price,
 
        flag,
 
        meta,
 
    )
 

	
 
def Transaction(date=FY_MID_DATE, flag='*', payee=None,
 
                narration='', tags=None, links=None, postings=(),
 
                **meta):
 
    if isinstance(date, str):
 
        date = parse_date(date)
 
    meta.setdefault('filename', '<test>')
 
    meta.setdefault('lineno', 0)
 
    real_postings = []
 
    for post in postings:
 
        try:
 
            post.account
 
        except AttributeError:
 
            if isinstance(post[-1], dict):
 
                args = post[:-1]
 
                kwargs = post[-1]
 
            else:
 
                args = post
 
                kwargs = {}
 
            post = Posting(*args, **kwargs)
 
        real_postings.append(post)
 
    return bc_data.Transaction(
 
        meta,
 
        date,
 
        flag,
 
        payee,
 
        narration,
 
        set(tags or ''),
 
        set(links or ''),
 
        real_postings,
 
    )
 

	
 
LINK_METADATA_STRINGS = {
 
    'Invoices/304321.pdf',
 
    'rt:123/456',
 
    'rt://ticket/234',
 
}
 

	
 
NON_LINK_METADATA_STRINGS = {
 
    '',
 
    ' ',
 
    '     ',
 
}
 

	
 
NON_STRING_METADATA_VALUES = [
 
    Decimal(5),
 
    FY_MID_DATE,
 
    Amount(50),
 
    Amount(500, None),
 
]
 

	
 
OPENING_EQUITY_ACCOUNTS = itertools.cycle([
 
    'Equity:Funds:Unrestricted',
 
    'Equity:Funds:Restricted',
 
    'Equity:OpeningBalance',
 
])
 

	
 
def balance_map(source=None, **kwargs):
 
    # The source and/or kwargs should map currency name strings to
 
    # things you can pass to Decimal (a decimal string, an int, etc.)
 
    # This returns a dict that maps currency name strings to Amount instances.
 
    retval = {}
 
    if source is not None:
 
        retval.update((currency, Amount(number, currency))
 
                      for currency, number in source)
 
    if kwargs:
 
        retval.update(balance_map(kwargs.items()))
 
    return retval
 

	
 
def OpeningBalance(acct=None, **txn_meta):
 
    if acct is None:
 
        acct = next(OPENING_EQUITY_ACCOUNTS)
 
    return Transaction(**txn_meta, postings=[
 
        ('Assets:Receivable:Accounts', 100),
 
        ('Assets:Receivable:Loans', 200),
 
        ('Liabilities:Payable:Accounts', -15),
 
        ('Liabilities:Payable:Vacation', -25),
 
        (acct, -260),
 
    ])
 

	
 
class TestBooksLoader(books.Loader):
 
    def __init__(self, source):
 
        self.source = source
 

	
 
    def fy_range_string(self, from_fy=None, to_fy=None, plugins=None):
 
        return f'include "{self.source}"'
 

	
 
    load_string = staticmethod(bc_loader.load_string)
 
    def load_fy_range(self, from_fy, to_fy=None):
 
        return bc_loader.load_file(self.source)
 

	
 

	
 
class TestConfig:
 
    def __init__(self, *,
 
                 books_path=None,
 
                 payment_threshold=0,
 
                 repo_path=None,
 
                 rt_client=None,
 
    ):
 
        if books_path is None:
 
            self._books_loader = None
 
        else:
 
            self._books_loader = TestBooksLoader(books_path)
 
        self._payment_threshold = Decimal(payment_threshold)
 
        self.repo_path = test_path(repo_path)
 
        self._rt_client = rt_client
 
        if rt_client is None:
 
            self._rt_wrapper = None
 
        else:
 
            self._rt_wrapper = rtutil.RT(rt_client)
 

	
 
    def books_loader(self):
 
        return self._books_loader
 

	
 
    def config_file_path(self):
 
        return test_path('userconfig/conservancy_beancount/config.ini')
 

	
 
    def payment_threshold(self):
 
        return self._payment_threshold
 

	
 
    def repository_path(self):
 
        return self.repo_path
 

	
 
    def rt_client(self):
 
        return self._rt_client
 

	
 
    def rt_wrapper(self):
 
        return self._rt_wrapper
 

	
 

	
 
class _TicketBuilder:
 
    MESSAGE_ATTACHMENTS = [
 
        ('(Unnamed)', 'multipart/alternative', '0b'),
 
        ('(Unnamed)', 'text/plain', '1.2k'),
 
        ('(Unnamed)', 'text/html', '1.4k'),
 
    ]
 
    MISC_ATTACHMENTS = [
 
        ('Forwarded Message.eml', 'message/rfc822', '3.1k'),
 
        ('photo.jpg', 'image/jpeg', '65.2k'),
 
        ('ConservancyInvoice-301.pdf', 'application/pdf', '326k'),
 
        ('Company_invoice-2020030405_as-sent.pdf', 'application/pdf', '50k'),
 
        ('statement.txt', 'text/plain', '652b'),
 
        ('screenshot.png', 'image/png', '1.9m'),
 
    ]
 

	
 
    def __init__(self):
 
        self.id_seq = itertools.count(1)
 
        self.misc_attchs = itertools.cycle(self.MISC_ATTACHMENTS)
 

	
 
    def new_attch(self, attch):
 
        return (str(next(self.id_seq)), *attch)
 

	
 
    def new_msg_with_attachments(self, attachments_count=1):
 
        for attch in self.MESSAGE_ATTACHMENTS:
 
            yield self.new_attch(attch)
 
        for _ in range(attachments_count):
 
            yield self.new_attch(next(self.misc_attchs))
 

	
 
    def new_messages(self, messages_count, attachments_count=None):
 
        for n in range(messages_count):
 
            if attachments_count is None:
 
                att_count = messages_count - n
 
            else:
 
                att_count = attachments_count
 
            yield from self.new_msg_with_attachments(att_count)
 

	
 

	
 
class RTClient:
 
    _builder = _TicketBuilder()
 
    DEFAULT_URL = 'https://example.org/defaultrt/REST/1.0/'
 
    TICKET_DATA = {
 
        '1': list(_builder.new_messages(1, 3)),
 
        '2': list(_builder.new_messages(2, 1)),
 
        '3': list(_builder.new_messages(3, 0)),
 
    }
 
    del _builder
 

	
 
    def __init__(self,
 
                 url=DEFAULT_URL,
 
                 default_login=None,
 
                 default_password=None,
 
                 proxy=None,
 
                 default_queue='General',
 
                 skip_login=False,
 
                 verify_cert=True,
 
                 http_auth=None,
 
    ):
 
        self.url = url
 
        if http_auth is None:
 
            self.user = default_login
 
            self.password = default_password
 
            self.auth_method = 'login'
 
            self.login_result = skip_login or None
 
        else:
 
            self.user = http_auth.username
 
            self.password = http_auth.password
 
            self.auth_method = type(http_auth).__name__
 
            self.login_result = True
 
        self.last_login = None
 

	
 
    def login(self, login=None, password=None):
 
        if login is None and password is None:
 
            login = self.user
 
            password = self.password
 
        self.login_result = bool(login and password and not password.startswith('bad'))
 
        self.last_login = (login, password, self.login_result)
 
        return self.login_result
 

	
 
    def get_attachments(self, ticket_id):
 
        try:
 
            return list(self.TICKET_DATA[str(ticket_id)])
 
        except KeyError:
 
            return None
 

	
 
    def get_attachment(self, ticket_id, attachment_id):
 
        try:
 
            att_seq = iter(self.TICKET_DATA[str(ticket_id)])
 
        except KeyError:
 
            return None
 
        att_id = str(attachment_id)
 
        multipart_id = None
 
        for attch in att_seq:
 
            if attch[0] == att_id:
 
                break
 
            elif attch[2].startswith('multipart/'):
 
                multipart_id = attch[0]
 
        else:
 
            return None
 
        tx_id = multipart_id or att_id
 
        if attch[1] == '(Unnamed)':
 
            filename = ''
 
        else:
 
            filename = attch[1]
 
        return {
 
            'id': att_id,
 
            'ContentType': attch[2],
 
            'Filename': filename,
 
            'Transaction': tx_id,
 
        }
 

	
 
    def get_ticket(self, ticket_id):
 
        ticket_id_s = str(ticket_id)
 
        if ticket_id_s not in self.TICKET_DATA:
 
            return None
 
        return {
 
            'id': 'ticket/{}'.format(ticket_id_s),
 
            'numerical_id': ticket_id_s,
 
            'CF.{payment-method}': f'payment method {ticket_id_s}',
 
            'Requestors': [
 
                f'mx{ticket_id_s}@example.org',
 
                'requestor2@example.org',
 
            ],
 
        }
 

	
 
    def get_user(self, user_id):
 
        user_id_s = str(user_id)
 
        match = re.search(r'(\d+)@', user_id_s)
 
        if match is None:
 
            email = f'mx{user_id_s}@example.org'
 
            user_id_num = int(user_id_s)
 
        else:
 
            email = user_id_s
 
            user_id_num = int(match.group(1))
 
        return {
 
            'id': f'user/{user_id_num}',
 
            'EmailAddress': email,
 
            'Name': email,
 
            'RealName': f'Mx. {user_id_num}',
 
        }
0 comments (0 inline, 0 general)