diff --git a/conservancy_beancount/config.py b/conservancy_beancount/config.py index 93531d61a263918d1cd0ba5e6ec02a70324bda3b..9509b9faa66693bf16033a034887a10fa4f67ac8 100644 --- a/conservancy_beancount/config.py +++ b/conservancy_beancount/config.py @@ -15,11 +15,16 @@ # along with this program. If not, see . import os +import urllib.parse as urlparse + +import requests.auth +import rt from pathlib import Path from typing import ( NamedTuple, Optional, + Type, ) class RTCredentials(NamedTuple): @@ -73,3 +78,24 @@ class Config: RTCredentials(auth='rt'), ) return RTCredentials._make(v0 or v1 or v2 for v0, v1, v2 in all_creds) + + def rt_client(self, + credentials: RTCredentials=None, + client: Type[rt.Rt]=rt.Rt, + ) -> Optional[rt.Rt]: + if credentials is None: + credentials = self.rt_credentials() + if credentials.server is None: + return None + urlparts = urlparse.urlparse(credentials.server) + rest_path = urlparts.path.rstrip('/') + '/REST/1.0/' + url = urlparse.urlunparse(urlparts._replace(path=rest_path)) + if credentials.auth == 'basic': + auth = requests.auth.HTTPBasicAuth(credentials.user, credentials.passwd) + retval = client(url, http_auth=auth) + else: + retval = client(url, credentials.user, credentials.passwd) + if retval.login(): + return retval + else: + return None diff --git a/setup.py b/setup.py index 2eb80c5325772482ba8a582a8055e8d48aa7c358..c473325b952b3c274357359d02854c4d6cb8ffd2 100755 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ setup( install_requires=[ 'beancount>=2.2', + 'rt>=2.0', ], setup_requires=[ 'pytest-mypy', diff --git a/tests/test_config.py b/tests/test_config.py index 465468b5403dea9025287f6916fe8c606e242681..4618ddcfd5bad4b3c7f4aa7d20e5d10dc778436e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,6 +16,7 @@ import contextlib import os +import re import pytest @@ -23,6 +24,8 @@ from . import testutil from conservancy_beancount import config as config_mod +RT_AUTH_METHODS = frozenset(['basic', 'gssapi', 'rt']) + RT_ENV_KEYS = ( 'RTSERVER', 'RTUSER', @@ -44,6 +47,13 @@ RT_FILE_CREDS = ( 'basic', ) +RT_GENERIC_CREDS = config_mod.RTCredentials( + 'https://example.org/genericrt', + 'genericuser', + 'generic password', + None, +) + @pytest.fixture def rt_environ(): return dict(zip(RT_ENV_KEYS, RT_ENV_CREDS)) @@ -109,3 +119,42 @@ def test_rt_credentials_from_all_sources_mixed(tmp_path): config = config_mod.Config() rt_credentials = config.rt_credentials() assert rt_credentials == (server, 'mixedup', 'mixed up', 'rt') + +def check_rt_client_url(credentials, client): + pattern = '^{}/?$'.format(re.escape(credentials[0].rstrip('/') + '/REST/1.0')) + assert re.match(pattern, client.url) + +@pytest.mark.parametrize('authmethod', RT_AUTH_METHODS) +def test_rt_client(authmethod): + rt_credentials = RT_GENERIC_CREDS._replace(auth=authmethod) + config = config_mod.Config() + rt_client = config.rt_client(rt_credentials, testutil.RTClient) + check_rt_client_url(RT_GENERIC_CREDS, rt_client) + assert rt_client.auth_method == ('HTTPBasicAuth' if authmethod == 'basic' else 'login') + assert rt_client.last_login == ( + RT_GENERIC_CREDS.user, + RT_GENERIC_CREDS.passwd, + True, + ) + +def test_default_rt_client(rt_environ): + with update_environ(**rt_environ): + config = config_mod.Config() + rt_client = config.rt_client(client=testutil.RTClient) + check_rt_client_url(RT_ENV_CREDS, rt_client) + assert rt_client.last_login[:-1] == RT_ENV_CREDS[1:3] + assert rt_client.last_login[-1] + +@pytest.mark.parametrize('authmethod', RT_AUTH_METHODS) +def test_rt_client_login_failure(authmethod): + rt_credentials = RT_GENERIC_CREDS._replace( + auth=authmethod, + passwd='bad{}'.format(authmethod), + ) + config = config_mod.Config() + assert config.rt_client(rt_credentials, testutil.RTClient) is None + +def test_no_rt_client_without_server(): + rt_credentials = RT_GENERIC_CREDS._replace(server=None, auth='rt') + config = config_mod.Config() + assert config.rt_client(rt_credentials, testutil.RTClient) is None diff --git a/tests/testutil.py b/tests/testutil.py index c6cec5c7eb30e2db244c6f15bcf6df6bb4e5c3b4..bb30ffd69af77c206da44de6c718fbb251eb244f 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -112,3 +112,36 @@ class TestConfig: def repository_path(self): return self.repo_path + + +class RTClient: + def __init__(self, + url, + default_login=None, + default_password=None, + proxy=None, + default_queue='General', + skip_login=False, + verify_cert=True, + http_auth=None, + ): + self.url = url + if http_auth is None: + self.user = default_login + self.password = default_password + self.auth_method = 'login' + self.login_result = skip_login or None + else: + self.user = http_auth.username + self.password = http_auth.password + self.auth_method = type(http_auth).__name__ + self.login_result = True + self.last_login = None + + def login(self, login=None, password=None): + if login is None and password is None: + login = self.user + password = self.password + self.login_result = bool(login and password and not password.startswith('bad')) + self.last_login = (login, password, self.login_result) + return self.login_result