diracx-testing 0.0.1a23__py3-none-any.whl → 0.0.1a25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/testing/__init__.py +42 -746
- diracx/testing/entrypoints.py +67 -0
- diracx/testing/mock_osdb.py +19 -14
- diracx/testing/utils.py +704 -0
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a25.dist-info}/METADATA +2 -2
- diracx_testing-0.0.1a25.dist-info/RECORD +11 -0
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a25.dist-info}/WHEEL +1 -1
- diracx_testing-0.0.1a23.dist-info/RECORD +0 -9
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a25.dist-info}/top_level.txt +0 -0
diracx/testing/utils.py
ADDED
@@ -0,0 +1,704 @@
|
|
1
|
+
"""Utilities for testing DiracX."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
# TODO: this needs a lot of documentation, in particular what will matter for users
|
6
|
+
# are the enabled_dependencies markers
|
7
|
+
import asyncio
|
8
|
+
import contextlib
|
9
|
+
import os
|
10
|
+
import re
|
11
|
+
import ssl
|
12
|
+
import subprocess
|
13
|
+
from datetime import datetime, timedelta, timezone
|
14
|
+
from functools import partial
|
15
|
+
from html.parser import HTMLParser
|
16
|
+
from pathlib import Path
|
17
|
+
from typing import TYPE_CHECKING, Generator
|
18
|
+
from urllib.parse import parse_qs, urljoin, urlparse
|
19
|
+
from uuid import uuid4
|
20
|
+
|
21
|
+
import pytest
|
22
|
+
import requests
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from diracx.core.settings import DevelopmentSettings
|
26
|
+
from diracx.routers.jobs.sandboxes import SandboxStoreSettings
|
27
|
+
from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
|
28
|
+
|
29
|
+
|
30
|
+
# Constants describing the JWTs minted by the test helpers below.
# ISSUER is used as the "iss" claim of the tokens built by ClientFactory.
# NOTE(review): ALGORITHM, AUDIENCE and ACCESS_TOKEN_EXPIRE_MINUTES are not
# referenced in this module (test_auth_settings signs with "EdDSA", not
# ALGORITHM); they are kept because other test suites may import them —
# confirm before removing.
ALGORITHM = "HS256"
ISSUER = "http://lhcbdirac.cern.ch/"
AUDIENCE = "dirac"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
36
|
+
|
37
|
+
|
38
|
+
def pytest_addoption(parser):
    """Register the command-line options used by the DiracX test helpers.

    ``--regenerate-client`` opts in to running ``test_regenerate_client``
    (see ``pytest_collection_modifyitems``); ``--demo-dir`` points at a
    diracx-charts checkout with the demo running (read by ``demo_dir``).
    """
    regenerate_client_kwargs = dict(
        action="store_true",
        default=False,
        help="Regenerate the AutoREST client",
    )
    demo_dir_kwargs = dict(
        type=Path,
        default=None,
        help="Path to a diracx-charts directory with the demo running",
    )
    for option_name, option_kwargs in (
        ("--regenerate-client", regenerate_client_kwargs),
        ("--demo-dir", demo_dir_kwargs),
    ):
        parser.addoption(option_name, **option_kwargs)
|
51
|
+
|
52
|
+
|
53
|
+
def pytest_collection_modifyitems(config, items):
    """Skip ``test_regenerate_client`` unless ``--regenerate-client`` was passed."""
    # Client regeneration is opt-in: with the flag given, nothing is skipped.
    if not config.getoption("--regenerate-client"):
        skip_marker = pytest.mark.skip(reason="need --regenerate-client option to run")
        for regen_item in (i for i in items if i.name == "test_regenerate_client"):
            regen_item.add_marker(skip_marker)
|
62
|
+
|
63
|
+
|
64
|
+
@pytest.fixture(scope="session")
def private_key_pem() -> str:
    """Generate a throw-away Ed25519 private key as unencrypted PKCS8 PEM text."""
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

    pem_bytes = Ed25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return pem_bytes.decode()
|
75
|
+
|
76
|
+
|
77
|
+
@pytest.fixture(scope="session")
def fernet_key() -> str:
    """Generate a fresh Fernet key, as text, shared by the whole test session."""
    from cryptography.fernet import Fernet

    key_bytes = Fernet.generate_key()
    return key_bytes.decode()
|
82
|
+
|
83
|
+
|
84
|
+
@pytest.fixture(scope="session")
def test_dev_settings() -> Generator[DevelopmentSettings, None, None]:
    """Provide a default-constructed ``DevelopmentSettings`` for the session."""
    from diracx.core.settings import DevelopmentSettings

    settings = DevelopmentSettings()
    yield settings
|
89
|
+
|
90
|
+
|
91
|
+
@pytest.fixture(scope="session")
def test_auth_settings(
    private_key_pem, fernet_key
) -> Generator[AuthSettings, None, None]:
    """Provide ``AuthSettings`` wired to the session's generated keys.

    Tokens are signed with EdDSA using ``private_key_pem``, ``fernet_key`` is
    the state key, and only the API docs OAuth2 redirect on the
    ``diracx.test.invalid`` test host is an allowed redirect.
    """
    from diracx.routers.utils.users import AuthSettings

    redirect_whitelist = ["http://diracx.test.invalid:8000/api/docs/oauth2-redirect"]
    settings = AuthSettings(
        token_algorithm="EdDSA",
        token_key=private_key_pem,
        state_key=fernet_key,
        allowed_redirects=redirect_whitelist,
    )
    yield settings
|
105
|
+
|
106
|
+
|
107
|
+
@pytest.fixture(scope="session")
def aio_moto(worker_id):
    """Start a moto server in a separate thread and yield S3 client kwargs.

    The mocking provided by moto doesn't play nicely with aiobotocore so we use
    the server directly. See https://github.com/aio-libs/aiobotocore/issues/755

    Each worker gets its own port so parallel workers don't clash: the
    ``"master"`` worker uses 27132 and worker ``gwN`` uses 27132 + N + 1.

    Yields:
        A dict with ``endpoint_url``, ``aws_access_key_id`` and
        ``aws_secret_access_key`` suitable for an S3 client.
    """
    from moto.server import ThreadedMotoServer

    port = 27132
    if worker_id != "master":
        port += int(worker_id.replace("gw", "")) + 1
    server = ThreadedMotoServer(port=port)
    server.start()
    try:
        yield {
            "endpoint_url": f"http://localhost:{port}",
            "aws_access_key_id": "testing",
            "aws_secret_access_key": "testing",
        }
    finally:
        # Stop the server thread even if the generator is closed abnormally
        # (e.g. GeneratorExit) rather than resumed by normal fixture teardown.
        server.stop()
|
127
|
+
|
128
|
+
|
129
|
+
@pytest.fixture(scope="session")
def test_sandbox_settings(aio_moto) -> Generator[SandboxStoreSettings, None, None]:
    """Provide ``SandboxStoreSettings`` backed by the moto S3 server.

    Uses the ``sandboxes`` bucket with ``auto_create_bucket=True`` and the S3
    client kwargs yielded by ``aio_moto``.
    """
    from diracx.routers.jobs.sandboxes import SandboxStoreSettings

    yield SandboxStoreSettings(
        bucket_name="sandboxes",
        s3_client_kwargs=aio_moto,
        auto_create_bucket=True,
    )
|
138
|
+
|
139
|
+
|
140
|
+
class UnavailableDependency:
    """Placeholder override for a dependency the current test has not enabled.

    Calling an instance raises ``NotImplementedError`` naming the dependency,
    so accidental use of a disabled dependency fails loudly.
    """

    def __init__(self, key):
        # Name reported in the error message (ClientFactory passes the class name).
        self.key = key

    def __call__(self):
        message = f"{self.key} has not been made available to this test!"
        raise NotImplementedError(message)
|
148
|
+
|
149
|
+
|
150
|
+
class ClientFactory:
    """Build a DiracX app once and hand out test clients configured per test.

    The app is created with every installed system, SQL DBs pointing at
    in-memory SQLite and OpenSearch DBs mocked via SQLite. All dependency
    overrides and lifetime functions are stashed at construction time;
    ``configure`` re-enables only the ones a given test asks for.
    """

    def __init__(
        self,
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    ):
        """Create the app and stash its dependency overrides and lifetime functions.

        Args:
            tmp_path_factory: pytest factory used to create the directory caching
                freshly-created empty DB files.
            with_config_repo: Path to a git repository used as the config source.
            test_auth_settings, test_sandbox_settings, test_dev_settings:
                Service settings passed to the app.
        """
        from diracx.core.config import ConfigSource
        from diracx.core.extensions import select_from_extension
        from diracx.core.settings import ServiceSettingsBase
        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.routers import create_app_inner
        from diracx.routers.access_policies import BaseAccessPolicy

        from .mock_osdb import fake_available_osdb_implementations

        class AlwaysAllowAccessPolicy(BaseAccessPolicy):
            """No-op access policy for tests: ``policy`` never raises."""

            @staticmethod
            async def policy(
                policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
            ):
                pass

            @staticmethod
            def enrich_tokens(access_payload: dict, refresh_payload: dict):
                # Marker claim added to access tokens; nothing for refresh tokens.
                return {"PolicySpecific": "OpenAccessForTest"}, {}

        enabled_systems = {
            e.name for e in select_from_extension(group="diracx.services")
        }
        database_urls = {
            e.name: "sqlite+aiosqlite:///:memory:"
            for e in select_from_extension(group="diracx.db.sql")
        }
        # TODO: Monkeypatch this in a less stupid way
        # TODO: Only use this if opensearch isn't available
        os_database_conn_kwargs = {
            e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
            for e in select_from_extension(group="diracx.db.os")
        }
        # Globally replace the OS DB implementation lookup with the SQLite-backed
        # mocks, keeping a handle on the real lookup.
        BaseOSDB.available_implementations = partial(
            fake_available_osdb_implementations,
            real_available_implementations=BaseOSDB.available_implementations,
        )

        self._cache_dir = tmp_path_factory.mktemp("empty-dbs")

        self.test_auth_settings = test_auth_settings
        self.test_dev_settings = test_dev_settings

        # Put the permissive test policy first for every access policy entry point.
        all_access_policies = {
            e.name: [AlwaysAllowAccessPolicy]
            + BaseAccessPolicy.available_implementations(e.name)
            for e in select_from_extension(group="diracx.access_policies")
        }

        self.app = create_app_inner(
            enabled_systems=enabled_systems,
            all_service_settings=[
                test_auth_settings,
                test_sandbox_settings,
                test_dev_settings,
            ],
            database_urls=database_urls,
            os_database_conn_kwargs=os_database_conn_kwargs,
            config_source=ConfigSource.create_from_url(
                backend_url=f"git+file://{with_config_repo}"
            ),
            all_access_policies=all_access_policies,
        )

        # Stash every dependency override; configure() re-enables a subset.
        # Override keys are methods bound to classes (hence issubclass below).
        self.all_dependency_overrides = self.app.dependency_overrides.copy()
        self.app.dependency_overrides = {}
        for obj in self.all_dependency_overrides:
            assert issubclass(
                obj.__self__,
                (
                    ServiceSettingsBase,
                    BaseSQLDB,
                    BaseOSDB,
                    ConfigSource,
                    BaseAccessPolicy,
                ),
            ), obj

        # Same for lifetime functions, which are bound to instances.
        self.all_lifetime_functions = self.app.lifetime_functions[:]
        self.app.lifetime_functions = []
        for obj in self.all_lifetime_functions:
            assert isinstance(
                obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
            ), obj

    @contextlib.contextmanager
    def configure(self, enabled_dependencies):
        """Enable only ``enabled_dependencies`` on the app for the duration of the block.

        Args:
            enabled_dependencies: Class names (e.g. ``"JobDB"``) whose dependency
                overrides and lifetime functions should be active. Every other
                stashed dependency is overridden by an ``UnavailableDependency``.

        Raises:
            AssertionError: if another ``configure`` block is already active.
        """
        assert (
            self.app.dependency_overrides == {} and self.app.lifetime_functions == []
        ), "configure cannot be nested"

        for k, v in self.all_dependency_overrides.items():

            class_name = k.__self__.__name__

            if class_name in enabled_dependencies:
                self.app.dependency_overrides[k] = v
            else:
                self.app.dependency_overrides[k] = UnavailableDependency(class_name)

        for obj in self.all_lifetime_functions:
            # TODO: We should use the name of the entry point instead of the class name
            if obj.__self__.__class__.__name__ in enabled_dependencies:
                self.app.lifetime_functions.append(obj)

        # Add create_db_schemas to the end of the lifetime_functions so that the
        # other lifetime_functions (i.e. those which run db.engine_context) have
        # already run
        self.app.lifetime_functions.append(self.create_db_schemas)

        try:
            yield
        finally:
            self.app.dependency_overrides = {}
            self.app.lifetime_functions = []

    @contextlib.asynccontextmanager
    async def create_db_schemas(self):
        """Create DB schemas based on the DBs available in app.dependency_overrides."""
        import aiosqlite
        import sqlalchemy
        from sqlalchemy.util.concurrency import greenlet_spawn

        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.testing.mock_osdb import MockOSDBMixin

        for k, v in self.app.dependency_overrides.items():
            # Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
            if isinstance(v, UnavailableDependency) or k.__func__ not in (
                BaseSQLDB.transaction.__func__,
                BaseOSDB.session.__func__,
            ):

                continue

            # The first argument of the overridden BaseSQLDB.transaction is the DB object
            db = v.args[0]
            # We expect the OS DB to be mocked with sqlite, so use the
            # internal DB
            if isinstance(db, MockOSDBMixin):
                db = db._sql_db

            assert isinstance(db, BaseSQLDB), (k, db)

            # set PRAGMA foreign_keys=ON if sqlite
            if db.engine.url.drivername.startswith("sqlite"):

                def set_sqlite_pragma(dbapi_connection, connection_record):
                    cursor = dbapi_connection.cursor()
                    cursor.execute("PRAGMA foreign_keys=ON")
                    cursor.close()

                sqlalchemy.event.listen(
                    db.engine.sync_engine, "connect", set_sqlite_pragma
                )

            # We maintain a cache of the populated DBs in self._cache_dir so
            # that we don't have to recreate them for every test. This speeds
            # up the tests by a considerable amount.
            ref_db = self._cache_dir / f"{k.__self__.__name__}.db"
            if ref_db.exists():
                # Cached: copy the reference DB into the test's DB.
                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await ref_conn.backup(conn.driver_connection)
                    await greenlet_spawn(conn.close)
            else:
                # First use: create the schema, then save a copy as the reference.
                async with db.engine.begin() as conn:
                    await conn.run_sync(db.metadata.create_all)

                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await conn.driver_connection.backup(ref_conn)
                    await greenlet_spawn(conn.close)

        yield

    @contextlib.contextmanager
    def unauthenticated(self):
        """Yield a ``TestClient`` for the app without any credentials."""
        from fastapi.testclient import TestClient

        with TestClient(self.app) as client:
            yield client

    @contextlib.contextmanager
    def _authenticated_client(self, dirac_properties):
        """Yield an app client carrying a bearer token with ``dirac_properties``.

        The token is signed with ``self.test_auth_settings`` and expires after
        ``test_auth_settings.access_token_expire_minutes`` minutes. The claims
        dict used to build it is exposed as ``client.dirac_token_payload``.
        """
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            payload = {
                "sub": "testingVO:yellow-sub",
                # timedelta's first positional argument is *days*; the setting
                # is in minutes, so pass it explicitly by keyword.
                "exp": datetime.now(tz=timezone.utc)
                + timedelta(
                    minutes=self.test_auth_settings.access_token_expire_minutes
                ),
                "iss": ISSUER,
                "dirac_properties": dirac_properties,
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)

            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client

    @contextlib.contextmanager
    def normal_user(self):
        """Yield a client authenticated with the ``NORMAL_USER`` property."""
        from diracx.core.properties import NORMAL_USER

        with self._authenticated_client([NORMAL_USER]) as client:
            yield client

    @contextlib.contextmanager
    def admin_user(self):
        """Yield a client authenticated with the ``JOB_ADMINISTRATOR`` property."""
        from diracx.core.properties import JOB_ADMINISTRATOR

        with self._authenticated_client([JOB_ADMINISTRATOR]) as client:
            yield client
|
391
|
+
|
392
|
+
|
393
|
+
@pytest.fixture(scope="session")
def session_client_factory(
    test_auth_settings,
    test_sandbox_settings,
    with_config_repo,
    tmp_path_factory,
    test_dev_settings,
):
    """Build one ``ClientFactory`` (and hence one app) for the whole session.

    The function-scoped ``client_factory`` fixture wraps this shared factory
    and configures it per test from the ``enabled_dependencies`` marker.
    """
    factory = ClientFactory(
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    )
    yield factory
|
412
|
+
|
413
|
+
|
414
|
+
@pytest.fixture
def client_factory(session_client_factory, request):
    """Yield the session ``ClientFactory`` configured for the current test.

    The test must carry an ``enabled_dependencies`` marker whose single
    positional argument lists the dependency class names to enable.

    Raises:
        RuntimeError: if the test has no ``enabled_dependencies`` marker.
    """
    enabled_marker = request.node.get_closest_marker("enabled_dependencies")
    if enabled_marker is not None:
        (enabled_dependencies,) = enabled_marker.args
        with session_client_factory.configure(
            enabled_dependencies=enabled_dependencies
        ):
            yield session_client_factory
    else:
        raise RuntimeError("This test requires the enabled_dependencies marker")
|
422
|
+
|
423
|
+
|
424
|
+
@pytest.fixture(scope="session")
def with_config_repo(tmp_path_factory):
    """Create a git repository holding a minimal DiracX CS and yield its path.

    The repository (branch ``master``) has a single commit adding
    ``default.yml``, which contains the JSON dump of a validated ``Config``
    with an ``lhcb`` VO (two users, three groups) and placeholder
    WorkloadManagement database entries.
    """
    from git import Repo

    from diracx.core.config import Config

    repo_dir = tmp_path_factory.mktemp("cs-repo")
    repo = Repo.init(repo_dir, initial_branch="master")

    chaen_id = "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"
    albdr_id = "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"
    lhcb_vo = {
        "DefaultGroup": "lhcb_user",
        "DefaultProxyLifeTime": 432000,
        "DefaultStorageQuota": 2000,
        "IdP": {
            "URL": "https://idp-server.invalid",
            "ClientID": "test-idp",
        },
        "Users": {
            chaen_id: {"PreferedUsername": "chaen", "Email": None},
            albdr_id: {"PreferedUsername": "albdr", "Email": None},
        },
        "Groups": {
            "lhcb_user": {
                "Properties": ["NormalUser", "PrivateLimitedDelegation"],
                "Users": [chaen_id, albdr_id],
            },
            "lhcb_prmgr": {
                "Properties": ["NormalUser", "ProductionManagement"],
                "Users": [chaen_id],
            },
            "lhcb_tokenmgr": {
                "Properties": ["NormalUser", "ProxyManagement"],
                "Users": [albdr_id],
            },
        },
    }

    # Every DB shares the same dummy connection details (a fresh dict each);
    # JobDB additionally sets MaxRescheduling.
    db_names = (
        "JobDB",
        "JobLoggingDB",
        "PilotAgentsDB",
        "SandboxMetadataDB",
        "TaskQueueDB",
        "ElasticJobParametersDB",
        "VirtualMachineDB",
    )
    databases = {
        name: {"DBName": "xyz", "Host": "xyz", "Port": 9999} for name in db_names
    }
    databases["JobDB"]["MaxRescheduling"] = 3

    example_cs = Config.model_validate(
        {
            "DIRAC": {},
            "Registry": {"lhcb": lhcb_vo},
            "Operations": {"Defaults": {}},
            "Systems": {
                "WorkloadManagement": {"Production": {"Databases": databases}},
            },
        }
    )

    cs_file = repo_dir / "default.yml"
    cs_file.write_text(example_cs.model_dump_json())
    repo.index.add([cs_file])  # stage the CS file
    repo.index.commit("Added a new file")
    yield repo_dir
|
526
|
+
|
527
|
+
|
528
|
+
@pytest.fixture(scope="session")
def demo_dir(request) -> Generator[Path, None, None]:
    """Yield the resolved ``.demo`` directory of a running diracx-charts demo.

    The diracx-charts checkout comes from ``--demo-dir``; tests requesting
    this fixture are skipped when the option is not given.
    """
    demo_dir = request.config.getoption("--demo-dir")
    if demo_dir is None:
        pytest.skip("Requires a running instance of the DiracX demo")
    demo_dir = (demo_dir / ".demo").resolve()
    yield demo_dir
|
535
|
+
|
536
|
+
|
537
|
+
@pytest.fixture(scope="session")
def demo_urls(demo_dir):
    """Yield the ``developer.urls`` mapping from the demo's ``values.yaml``."""
    import yaml

    values_text = (demo_dir / "values.yaml").read_text()
    developer_section = yaml.safe_load(values_text)["developer"]
    yield developer_section["urls"]
|
543
|
+
|
544
|
+
|
545
|
+
@pytest.fixture(scope="session")
def demo_kubectl_env(demo_dir):
    """Get the dictionary of environment variables for kubectl to control the demo.

    The result is a copy of ``os.environ`` with ``KUBECONFIG`` pointing at the
    demo's ``kube.conf`` and the demo directory prepended to ``PATH``
    (presumably so a ``kubectl`` shipped with the demo is found first —
    confirm). ``kubectl get pods`` is run once as a sanity check.

    Raises:
        RuntimeError: if ``kube.conf`` is missing (the demo isn't running).
    """
    kube_conf = demo_dir / "kube.conf"
    if not kube_conf.exists():
        raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")

    # Prepend the demo directory using the platform separator; don't crash when
    # PATH is unset, and don't append an empty entry (which would mean the cwd).
    search_path = [str(demo_dir)]
    existing_path = os.environ.get("PATH")
    if existing_path:
        search_path.append(existing_path)

    env = {
        **os.environ,
        "KUBECONFIG": str(kube_conf),
        "PATH": os.pathsep.join(search_path),
    }

    # Check that we can run kubectl
    pods_result = subprocess.check_output(
        ["kubectl", "get", "pods"], env=env, text=True
    )
    assert "diracx" in pods_result

    yield env
|
565
|
+
|
566
|
+
|
567
|
+
@pytest.fixture
def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
    """Set up the environment for the CLI.

    Checks that the demo API answers (its OpenAPI title must be ``"Dirac"``),
    then sets ``DIRACX_URL``, ``DIRACX_CA_PATH`` and ``HOME`` (``tmp_path``)
    via ``monkeypatch`` and yields that mapping.

    Raises:
        RuntimeError: if the demo CA certificate is missing.
    """
    import httpx

    from diracx.core.preferences import get_diracx_preferences

    diracx_url = demo_urls["diracx"]
    ca_path = demo_dir / "demo-ca.pem"
    if not ca_path.exists():
        raise RuntimeError(f"Could not find {ca_path}, is the demo running?")

    # Ensure the demo is working
    r = httpx.get(
        f"{diracx_url}/api/openapi.json",
        verify=ssl.create_default_context(cafile=ca_path),
    )
    r.raise_for_status()
    assert r.json()["info"]["title"] == "Dirac"

    env = {
        "DIRACX_URL": diracx_url,
        "DIRACX_CA_PATH": str(ca_path),
        "HOME": str(tmp_path),
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    # The DiracX preferences are cached, so values computed from the previous
    # environment must be discarded before the test runs...
    get_diracx_preferences.cache_clear()
    try:
        yield env
    finally:
        # ...and the test's values discarded so they don't leak into later tests.
        get_diracx_preferences.cache_clear()
|
599
|
+
|
600
|
+
|
601
|
+
@pytest.fixture
async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
    """Log in through the demo and point the CLI at the resulting credentials.

    Reuses ``test_login``; if it fails for any reason, the requesting test is
    skipped rather than errored. The credentials are written to
    ``tmp_path / "credentials.json"`` and ``DIRACX_CREDENTIALS_PATH`` is set.
    """
    try:
        credentials = await test_login(monkeypatch, capfd, cli_env)
    except Exception as login_error:
        pytest.skip(
            f"Login failed, fix test_login to re-enable this test: {login_error!r}"
        )

    credentials_file = tmp_path / "credentials.json"
    credentials_file.write_text(credentials)
    monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_file))
    yield
|
612
|
+
|
613
|
+
|
614
|
+
async def test_login(monkeypatch, capfd, cli_env):
    """Log in with the CLI against the demo and return the credentials file content.

    ``asyncio.sleep`` is patched so that, on the fifth poll, the login URL the
    CLI printed is scraped from stdout and the dex device flow is completed
    programmatically; every sleep is shortened to zero.
    """
    from diracx import cli

    poll_attempts = 0

    # Escape the base URL: it is matched literally, but contains regex
    # metacharacters (e.g. ".") that would otherwise match any character.
    login_url_pattern = re.compile(rf"{re.escape(cli_env['DIRACX_URL'])}[^\n]+")

    def fake_sleep(*args, **kwargs):
        nonlocal poll_attempts

        # Keep track of the number of times this is called
        poll_attempts += 1

        # After polling 5 times, do the actual login
        if poll_attempts == 5:
            # The login URL should have been printed to stdout
            captured = capfd.readouterr()
            match = login_url_pattern.search(captured.out)
            assert match, captured

            do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"])

        # Ensure we don't poll forever
        assert poll_attempts <= 100

        # Reduce the sleep duration to zero to speed up the test
        return unpatched_sleep(0)

    # We monkeypatch asyncio.sleep to provide a hook to run the actions that
    # would normally be done by a user. This includes capturing the login URL
    # and doing the actual device flow with dex.
    unpatched_sleep = asyncio.sleep

    expected_credentials_path = Path(
        cli_env["HOME"], ".cache", "diracx", "credentials.json"
    )
    # Ensure the credentials file does not exist before logging in
    assert not expected_credentials_path.exists()

    # Run the login command
    with monkeypatch.context() as m:
        m.setattr("asyncio.sleep", fake_sleep)
        await cli.auth.login(vo="diracAdmin", group=None, property=None)
    captured = capfd.readouterr()
    assert "Login successful!" in captured.out
    assert captured.err == ""

    # Ensure the credentials file exists after logging in
    assert expected_credentials_path.exists()

    # Return the credentials so this test can also be used by the
    # "with_cli_login" fixture
    return expected_credentials_path.read_text()
|
665
|
+
|
666
|
+
|
667
|
+
def do_device_flow_with_dex(url: str, ca_path: str, *, timeout: float = 30) -> None:
    """Do the device flow with dex.

    Follows ``url`` to dex's login page, posts the demo's static credentials
    to the page's form, approves the request, and checks that DiracX reports
    the login as complete.

    Args:
        url: Device-flow login URL printed by the CLI.
        ca_path: CA bundle used to verify the demo's TLS certificates.
        timeout: Per-request timeout in seconds, so an unresponsive demo fails
            the test instead of hanging it.

    Raises:
        requests.HTTPError: if any HTTP step returns an error status.
        AssertionError: if the login form is not found or the final page is
            not the expected "login complete" page.
    """

    class DexLoginFormParser(HTMLParser):
        def handle_starttag(self, tag, attrs):
            nonlocal action_url
            # HTMLParser lower-cases tag names, so an exact match suffices
            # (a substring test would also match unrelated tags).
            if tag == "form":
                assert action_url is None
                action_url = urljoin(login_page_url, dict(attrs)["action"])

    # Get the login page
    r = requests.get(url, verify=ca_path, timeout=timeout)
    r.raise_for_status()
    login_page_url = r.url  # This is not the same as URL as we redirect to dex
    login_page_body = r.text

    # Search the page for the login form so we know where to post the credentials
    action_url = None
    DexLoginFormParser().feed(login_page_body)
    assert action_url is not None, login_page_body

    # Do the actual login
    r = requests.post(
        action_url,
        data={"login": "admin@example.com", "password": "password"},
        verify=ca_path,
        timeout=timeout,
    )
    r.raise_for_status()
    # After login we end up on the approval page; its URL carries the "req"
    # query parameter that must be echoed back when approving.
    approval_url = r.url
    approval_req = parse_qs(urlparse(approval_url).query)["req"][0]

    # Do the actual approval
    r = requests.post(
        approval_url,
        {"approval": "approve", "req": approval_req},
        verify=ca_path,
        timeout=timeout,
    )
    r.raise_for_status()

    # This should have redirected to the DiracX page that shows the login is complete
    assert "Please close the window" in r.text
|