Files @ 5c6066661900
Branch filter:

Location: NPO-Accounting/conservancy_beancount/tests/test_config.py

Brett Smith
books: Start FiscalYear class.
"""Test Config class"""
# Copyright © 2020  Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import contextlib
import decimal
import operator
import os
import re

from pathlib import Path

import pytest

from . import testutil

from conservancy_beancount import config as config_mod

# Authentication methods exercised by the parametrized rt_client tests.
RT_AUTH_METHODS = frozenset({'basic', 'gssapi', 'rt'})

# RT environment variable names, listed in the same field order as
# config_mod.RTCredentials: server, user, passwd, auth.
RT_ENV_KEYS = ('RTSERVER', 'RTUSER', 'RTPASSWD', 'RTAUTH')

# Credentials supplied through the environment (see the rt_environ fixture),
# in RTCredentials field order.
RT_ENV_CREDS = (
    'https://example.org/envrt',  # server
    'envuser',                    # user
    'env  password',              # passwd; the doubled space is kept verbatim
    'gssapi',                     # auth
)

# Credentials the default test configuration is expected to yield
# (presumably read from a fixture .rtrc file), in RTCredentials field order.
RT_FILE_CREDS = (
    'https://example.org/filert',  # server
    'fileuser',                    # user
    'file  password',              # passwd; the doubled space is kept verbatim
    'basic',                       # auth
)

# Credentials for tests that pass them explicitly. auth is left as None;
# individual tests override fields with _replace().
RT_GENERIC_CREDS = config_mod.RTCredentials(
    server='https://example.org/genericrt',
    user='genericuser',
    passwd='generic password',
    auth=None,
)

@pytest.fixture
def rt_environ():
    """Map each RT environment variable name to its RT_ENV_CREDS value.

    The fixture is function-scoped, so each test gets its own dict and may
    mutate it freely.
    """
    return {name: value for name, value in zip(RT_ENV_KEYS, RT_ENV_CREDS)}

def _update_environ(updates):
    """Apply *updates* to os.environ in place.

    Keys mapped to None are removed from the environment (a missing key is
    ignored); every other value is stored as str(value).
    """
    for key, value in updates.items():
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = str(value)

@contextlib.contextmanager
def update_environ(**kwargs):
    """Temporarily apply environment changes given as keyword arguments.

    Each keyword names an environment variable. None unsets it; any other
    value is stored as its str() form. On exit, every named variable is
    restored to its previous value, or unset again if it was unset before.
    """
    revert = {key: os.environ.get(key) for key in kwargs}
    try:
        # Apply inside the try: if os.environ rejects a value part-way
        # through (e.g. a string with an embedded NUL), the variables that
        # were already changed are still restored instead of leaking into
        # later tests.
        _update_environ(kwargs)
        yield
    finally:
        _update_environ(revert)

@contextlib.contextmanager
def update_umask(mask):
    """Temporarily set the process umask to *mask*.

    Yields the umask that was in effect before the change. That value is
    restored on exit, even if the body raises.
    """
    previous = os.umask(mask)
    with contextlib.ExitStack() as restore:
        restore.callback(os.umask, previous)
        yield previous

def test_repository_from_environment():
    """Config reports the repository configured for the test run.

    Presumably this comes from CONSERVANCY_REPOSITORY as set by the test
    harness.
    """
    repo_path = config_mod.Config().repository_path()
    assert repo_path == testutil.test_path('repository')

def test_no_repository():
    """With CONSERVANCY_REPOSITORY unset, repository_path() is None."""
    with update_environ(CONSERVANCY_REPOSITORY=None):
        repo_path = config_mod.Config().repository_path()
        assert repo_path is None

def test_no_rt_credentials():
    """With HOME pointed at the tests directory, server/user/passwd are all
    None and auth falls back to 'rt'."""
    with update_environ(HOME=testutil.TESTS_DIR):
        creds = config_mod.Config().rt_credentials()
    for field_value in (creds.server, creds.user, creds.passwd):
        assert field_value is None
    assert creds.auth == 'rt'

def test_rt_credentials_from_file():
    """The default test configuration yields exactly RT_FILE_CREDS."""
    actual = config_mod.Config().rt_credentials()
    assert actual == RT_FILE_CREDS

def test_rt_credentials_from_environment(rt_environ):
    """When all four RT variables are set, they win over the file values."""
    with update_environ(**rt_environ):
        actual = config_mod.Config().rt_credentials()
    assert actual == RT_ENV_CREDS

@pytest.mark.parametrize('index,drop_key', enumerate(RT_ENV_KEYS))
def test_rt_credentials_from_file_and_environment_mixed(rt_environ, index, drop_key):
    """Dropping one RT variable makes just that field fall back to the file."""
    rt_environ.pop(drop_key)
    with update_environ(**rt_environ):
        actual = config_mod.Config().rt_credentials()
    # Every field comes from the environment except the dropped one, which
    # should be taken from the file credentials instead.
    expected = tuple(
        RT_FILE_CREDS[pos] if pos == index else env_value
        for pos, env_value in enumerate(RT_ENV_CREDS)
    )
    assert actual == expected

def test_rt_credentials_from_all_sources_mixed(tmp_path):
    """Credentials can come from several sources at once.

    Server and user come from the environment (overriding the .rtrc user),
    passwd comes from $HOME/.rtrc, and auth falls back to 'rt'.
    """
    server = 'https://example.org/mixedrt'
    (tmp_path / '.rtrc').write_text('user basemix\npasswd mixed up\n')
    with update_environ(HOME=tmp_path, RTSERVER=server, RTUSER='mixedup'):
        actual = config_mod.Config().rt_credentials()
    assert actual == (server, 'mixedup', 'mixed up', 'rt')

def check_rt_client_url(credentials, client):
    """Assert that *client* points at the server's REST 1.0 endpoint.

    *credentials* is indexed like an RTCredentials tuple, with the server URL
    at index 0; trailing slashes on it are ignored. client.url may carry one
    trailing slash.
    """
    rest_url = credentials[0].rstrip('/') + '/REST/1.0'
    assert re.match('^' + re.escape(rest_url) + '/?$', client.url)

# Parametrize over a sorted copy. Iterating the frozenset directly gives an
# order that depends on string-hash randomization, so test collection order
# varies between runs and between pytest-xdist workers.
@pytest.mark.parametrize('authmethod', sorted(RT_AUTH_METHODS))
def test_rt_client(authmethod):
    """rt_client() builds a client for each auth method and logs it in.

    'basic' is expected to use HTTPBasicAuth; every other method is expected
    to report 'login'.
    """
    rt_credentials = RT_GENERIC_CREDS._replace(auth=authmethod)
    config = config_mod.Config()
    rt_client = config.rt_client(rt_credentials, testutil.RTClient)
    check_rt_client_url(RT_GENERIC_CREDS, rt_client)
    assert rt_client.auth_method == ('HTTPBasicAuth' if authmethod == 'basic' else 'login')
    assert rt_client.last_login == (
        RT_GENERIC_CREDS.user,
        RT_GENERIC_CREDS.passwd,
        True,
    )

def test_default_rt_client(rt_environ):
    """Without explicit credentials, rt_client() uses the environment's and
    records a successful login with the environment user and password."""
    with update_environ(**rt_environ):
        client = config_mod.Config().rt_client(client=testutil.RTClient)
    check_rt_client_url(RT_ENV_CREDS, client)
    login = client.last_login
    assert login[:-1] == RT_ENV_CREDS[1:3]  # (user, passwd)
    assert login[-1]                         # login reported success

# Parametrize over a sorted copy so collection order does not depend on
# string-hash randomization (a frozenset's iteration order can differ
# between runs and between pytest-xdist workers).
@pytest.mark.parametrize('authmethod', sorted(RT_AUTH_METHODS))
def test_rt_client_login_failure(authmethod):
    """rt_client() returns None when the password is rejected."""
    rt_credentials = RT_GENERIC_CREDS._replace(
        auth=authmethod,
        passwd='bad{}'.format(authmethod),
    )
    config = config_mod.Config()
    assert config.rt_client(rt_credentials, testutil.RTClient) is None

def test_no_rt_client_without_server():
    """rt_client() returns None when the credentials have no server."""
    serverless = RT_GENERIC_CREDS._replace(server=None, auth='rt')
    client = config_mod.Config().rt_client(serverless, testutil.RTClient)
    assert client is None

def test_rt_wrapper():
    """A wrapper built from the generic 'rt' credentials can see ticket 1."""
    config = config_mod.Config()
    creds = RT_GENERIC_CREDS._replace(auth='rt')
    wrapper = config.rt_wrapper(creds, testutil.RTClient)
    assert wrapper.exists(1)

def test_rt_wrapper_default_creds():
    """Passing None for credentials falls back to the file credentials."""
    config = config_mod.Config()
    wrapper = config.rt_wrapper(None, testutil.RTClient)
    file_server = RT_FILE_CREDS[0]
    assert wrapper.rt.url.startswith(file_server)

def test_rt_wrapper_default_creds_from_environ(rt_environ):
    """Passing None for credentials picks up the RT environment variables."""
    with update_environ(**rt_environ):
        config = config_mod.Config()
        wrapper = config.rt_wrapper(None, testutil.RTClient)
    client_url = wrapper.rt.url
    assert client_url.startswith(RT_ENV_CREDS[0])

def test_rt_wrapper_no_creds():
    """With HOME pointed at the tests directory, no wrapper is built."""
    with update_environ(HOME=testutil.TESTS_DIR):
        config = config_mod.Config()
        wrapper = config.rt_wrapper(None, testutil.RTClient)
        assert wrapper is None

def test_rt_wrapper_bad_creds():
    """A rejected password yields no wrapper."""
    bad_creds = RT_GENERIC_CREDS._replace(passwd='badpass', auth='rt')
    config = config_mod.Config()
    wrapper = config.rt_wrapper(bad_creds, testutil.RTClient)
    assert wrapper is None

def test_rt_wrapper_caches():
    """Repeated calls with the same credentials return the same wrapper."""
    creds = RT_GENERIC_CREDS._replace(auth='rt')
    config = config_mod.Config()
    first, second = (config.rt_wrapper(creds, testutil.RTClient) for _ in range(2))
    assert first is second

def test_rt_wrapper_caches_by_creds():
    """Different credentials get different cached wrappers."""
    config = config_mod.Config()
    explicit = config.rt_wrapper(RT_GENERIC_CREDS._replace(auth='rt'), testutil.RTClient)
    fallback = config.rt_wrapper(None, testutil.RTClient)
    assert explicit is not fallback

def test_rt_wrapper_cache_responds_to_external_credential_changes(rt_environ):
    """Changing the RT environment between calls yields a new wrapper."""
    config = config_mod.Config()
    before = config.rt_wrapper(None, testutil.RTClient)
    with update_environ(**rt_environ):
        after = config.rt_wrapper(None, testutil.RTClient)
    assert before is not after

def test_rt_wrapper_has_cache(tmp_path):
    """Using the wrapper creates a private sqlite cache under XDG_CACHE_HOME."""
    with update_environ(XDG_CACHE_HOME=tmp_path), update_umask(0o002):
        config = config_mod.Config()
        wrapper = config.rt_wrapper(None, testutil.RTClient)
        wrapper.exists(1)
    cache_glob = 'conservancy_beancount/{}@*.sqlite3'.format(RT_FILE_CREDS[1])
    found_any = False
    for cache_file in tmp_path.glob(cache_glob):
        found_any = True
        # Even under a permissive umask, the file must have no group/other
        # permission bits and no owner execute bit.
        assert not cache_file.stat().st_mode & 0o177
    assert found_any, "did not find any generated cache file"

def test_rt_wrapper_without_cache(tmp_path):
    """rt_wrapper() still returns a wrapper when the cache directory is
    inaccessible, and creates nothing inside that directory."""
    tmp_path.chmod(0)
    try:
        with update_environ(XDG_CACHE_HOME=tmp_path):
            config = config_mod.Config()
            rt = config.rt_wrapper(None, testutil.RTClient)
    finally:
        # Restore usable permissions even if rt_wrapper() raised, so the
        # directory can be listed below and removed by pytest afterwards.
        # A directory needs the search (x) bit, hence 0o700 rather than 0o600.
        tmp_path.chmod(0o700)
    assert rt is not None
    assert not any(tmp_path.iterdir())

def test_cache_mkdir(tmp_path):
    """cache_dir_path() creates the named directory under XDG_CACHE_HOME."""
    expected = tmp_path / 'TESTcache'
    with update_environ(XDG_CACHE_HOME=tmp_path):
        config = config_mod.Config()
        cache_path = config.cache_dir_path(expected.name)
    # Compare against `expected` (not a re-spelled path) so the assertion
    # cannot drift from the name passed in above.
    assert cache_path == expected
    assert cache_path.is_dir()

def test_cache_mkdir_parent(tmp_path):
    """A missing XDG_CACHE_HOME directory is created along with the cache dir."""
    xdg_cache_dir = tmp_path / 'xdgcache'
    expected = xdg_cache_dir / 'conservancy_beancount'
    with update_environ(XDG_CACHE_HOME=xdg_cache_dir):
        created = config_mod.Config().cache_dir_path(expected.name)
    assert created == expected
    assert created.is_dir()

def test_cache_mkdir_from_home(tmp_path):
    """With XDG_CACHE_HOME unset, the cache lives under $HOME/.cache."""
    home_cache = tmp_path / '.cache'
    expected = home_cache / 'TESTcache'
    with update_environ(HOME=tmp_path, XDG_CACHE_HOME=None):
        created = config_mod.Config().cache_dir_path(expected.name)
    assert created == expected
    assert created.is_dir()

def test_cache_mkdir_exists_ok(tmp_path):
    """An already-existing cache directory is returned without error."""
    existing = tmp_path / 'TESTcache'
    existing.mkdir()
    with update_environ(XDG_CACHE_HOME=tmp_path):
        result = config_mod.Config().cache_dir_path(existing.name)
    assert result == existing

def test_cache_path_conflict(tmp_path):
    """A regular file where the cache directory should go yields None, and
    the file is left in place."""
    occupied = tmp_path / 'TESTcache'
    occupied.touch()
    with update_environ(XDG_CACHE_HOME=tmp_path):
        result = config_mod.Config().cache_dir_path(occupied.name)
    assert result is None
    assert occupied.is_file()

def test_cache_path_parent_conflict(tmp_path):
    """A regular file at $HOME/.cache prevents creating the cache dir."""
    blocker = tmp_path / '.cache'
    blocker.touch()
    with update_environ(HOME=tmp_path, XDG_CACHE_HOME=None):
        result = config_mod.Config().cache_dir_path('TESTcache')
        assert result is None

def test_relative_xdg_cache_home_ignored(tmp_path):
    """A relative XDG_CACHE_HOME is ignored in favor of $HOME/.cache."""
    relative_dir = 'nonexistent/test/cache/directory/tree'
    with update_environ(HOME=tmp_path, XDG_CACHE_HOME=relative_dir):
        result = config_mod.Config().cache_dir_path('TESTcache')
    assert result == tmp_path / '.cache/TESTcache'

def test_payment_threshold():
    """The default payment threshold is zero, as an int or Decimal."""
    config = config_mod.Config()
    threshold = config.payment_threshold()
    assert threshold == 0
    assert isinstance(threshold, (int, decimal.Decimal))

@pytest.mark.parametrize('config_path', [None, '', 'nonexistent/relative/path'])
def test_config_file_path(config_path):
    """An unset, empty, or relative XDG_CONFIG_HOME falls back to ~/.config."""
    home_config = Path('~/.config/conservancy_beancount/config.ini').expanduser()
    with update_environ(XDG_CONFIG_HOME=config_path):
        actual = config_mod.Config().config_file_path()
        assert actual == home_config

def test_config_file_path_respects_xdg_config_home():
    """An absolute XDG_CONFIG_HOME is used as the config base directory."""
    expected = Path('/etc/conservancy_beancount/config.ini')
    with update_environ(XDG_CONFIG_HOME='/etc'):
        actual = config_mod.Config().config_file_path()
        assert actual == expected

def test_config_file_path_with_subdir():
    """config_file_path(name) uses *name* in place of conservancy_beancount."""
    expected = testutil.test_path('userconfig/conftest/config.ini')
    assert config_mod.Config().config_file_path('conftest') == expected

@pytest.mark.parametrize('path', [
    None,  # presumably means "use the default config file"
    testutil.test_path('userconfig/conservancy_beancount/config.ini'),
])
def test_load_file(path):
    """After load_file(), books_path() reflects the file's setting."""
    config = config_mod.Config()
    config.load_file(path)
    books = config.books_path()
    assert books == Path('/test/conservancy_beancount')

@pytest.mark.parametrize('path_func', [
    lambda path: None,                       # leave the file missing
    operator.methodcaller('touch', 0o200),   # create it write-only
])
def test_load_file_error(tmp_path, path_func):
    """load_file() raises OSError for a missing or unreadable file."""
    bad_path = tmp_path / 'nonexistent.ini'
    path_func(bad_path)
    config = config_mod.Config()
    with pytest.raises(OSError):
        config.load_file(bad_path)

def test_no_books_path():
    """A fresh Config with no file loaded has no books path."""
    books = config_mod.Config().books_path()
    assert books is None

@pytest.mark.parametrize('value,month,day', [
    # Month only, with assorted surrounding whitespace: the day is 1.
    ('2', 2, 1),
    ('3 ', 3, 1),
    ('  4', 4, 1),
    (' 5 ', 5, 1),
    # Month and day separated by whitespace.
    ('6 1', 6, 1),
    ('  06  03  ', 6, 3),
    # Month and day separated by '-', '/', or '.', optionally padded.
    ('6-05', 6, 5),
    ('06 - 10', 6, 10),
    ('6/15', 6, 15),
    ('06  /  20', 6, 20),
    ('10.25', 10, 25),
    (' 10 . 30 ', 10, 30),
])
def test_fiscal_year_begin(value, month, day):
    """fiscal_year_begin() parses the configured value into (month, day)."""
    config = config_mod.Config()
    ini_text = f'[Beancount]\nfiscal year begin = {value}\n'
    config.load_string(ini_text)
    assert config.fiscal_year_begin() == (month, day)

@pytest.mark.parametrize('value', ['text', '1900', '13', '010', '2 30', '4-31'])
def test_bad_fiscal_year_begin(value):
    """Unparseable or impossible month/day values raise ValueError."""
    config = config_mod.Config()
    ini_text = f'[Beancount]\nfiscal year begin = {value}\n'
    config.load_string(ini_text)
    with pytest.raises(ValueError):
        config.fiscal_year_begin()

def test_default_fiscal_year_begin():
    """Without a configured value, the fiscal year begins on March 1."""
    default_begin = config_mod.Config().fiscal_year_begin()
    assert default_begin == (3, 1)