diracx-testing 0.0.1a23__py3-none-any.whl → 0.0.1a24__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diracx/testing/__init__.py +42 -746
- diracx/testing/entrypoints.py +67 -0
- diracx/testing/mock_osdb.py +9 -9
- diracx/testing/utils.py +694 -0
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a24.dist-info}/METADATA +2 -2
- diracx_testing-0.0.1a24.dist-info/RECORD +11 -0
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a24.dist-info}/WHEEL +1 -1
- diracx_testing-0.0.1a23.dist-info/RECORD +0 -9
- {diracx_testing-0.0.1a23.dist-info → diracx_testing-0.0.1a24.dist-info}/top_level.txt +0 -0
diracx/testing/utils.py
ADDED
@@ -0,0 +1,694 @@
|
|
1
|
+
"""Utilities for testing DiracX."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
# TODO: this needs a lot of documentation, in particular what will matter for users
|
6
|
+
# are the enabled_dependencies markers
|
7
|
+
import asyncio
|
8
|
+
import contextlib
|
9
|
+
import os
|
10
|
+
import re
|
11
|
+
import ssl
|
12
|
+
import subprocess
|
13
|
+
from datetime import datetime, timedelta, timezone
|
14
|
+
from functools import partial
|
15
|
+
from html.parser import HTMLParser
|
16
|
+
from pathlib import Path
|
17
|
+
from typing import TYPE_CHECKING, Generator
|
18
|
+
from urllib.parse import parse_qs, urljoin, urlparse
|
19
|
+
from uuid import uuid4
|
20
|
+
|
21
|
+
import pytest
|
22
|
+
import requests
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from diracx.core.settings import DevelopmentSettings
|
26
|
+
from diracx.routers.jobs.sandboxes import SandboxStoreSettings
|
27
|
+
from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
|
28
|
+
|
29
|
+
|
30
|
+
# Token-related constants for the DiracX test helpers.
# ISSUER is used as the "iss" claim of the tokens built by ClientFactory.
# NOTE(review): ALGORITHM, AUDIENCE and ACCESS_TOKEN_EXPIRE_MINUTES are not
# referenced in this module (the test AuthSettings sign tokens with "EdDSA");
# presumably they are kept for importers of this module -- TODO confirm.
ALGORITHM = "HS256"
ISSUER = "http://lhcbdirac.cern.ch/"
AUDIENCE = "dirac"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
36
|
+
|
37
|
+
|
38
|
+
def pytest_addoption(parser):
    """Register the command-line options understood by the DiracX test suite.

    ``--regenerate-client`` is a boolean flag that opts in to client
    regeneration, and ``--demo-dir`` takes the path of a diracx-charts
    checkout whose demo is running.
    """
    option_specs = [
        (
            "--regenerate-client",
            dict(
                action="store_true",
                default=False,
                help="Regenerate the AutoREST client",
            ),
        ),
        (
            "--demo-dir",
            dict(
                type=Path,
                default=None,
                help="Path to a diracx-charts directory with the demo running",
            ),
        ),
    ]
    for flag, option_kwargs in option_specs:
        parser.addoption(flag, **option_kwargs)
|
51
|
+
|
52
|
+
|
53
|
+
def pytest_collection_modifyitems(config, items):
    """Disable the test_regenerate_client if not explicitly asked for."""
    if not config.getoption("--regenerate-client"):
        # Without --regenerate-client on the command line, client re-generation
        # must not happen, so the corresponding test is marked as skipped.
        skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run")
        for regen_item in (it for it in items if it.name == "test_regenerate_client"):
            regen_item.add_marker(skip_regen)
|
62
|
+
|
63
|
+
|
64
|
+
@pytest.fixture(scope="session")
def private_key_pem() -> str:
    """Generate a throwaway Ed25519 private key for the session.

    The key is serialised as unencrypted PKCS#8 PEM and returned as text.
    """
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

    pem_bytes = Ed25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return pem_bytes.decode()
|
75
|
+
|
76
|
+
|
77
|
+
@pytest.fixture(scope="session")
def fernet_key() -> str:
    """Generate a fresh Fernet key for the session, returned as text."""
    from cryptography.fernet import Fernet

    key_bytes = Fernet.generate_key()
    return key_bytes.decode()
|
82
|
+
|
83
|
+
|
84
|
+
@pytest.fixture(scope="session")
def test_dev_settings() -> Generator[DevelopmentSettings, None, None]:
    """Provide a default-constructed ``DevelopmentSettings`` for the session."""
    from diracx.core.settings import DevelopmentSettings

    dev_settings = DevelopmentSettings()
    yield dev_settings
|
89
|
+
|
90
|
+
|
91
|
+
@pytest.fixture(scope="session")
def test_auth_settings(
    private_key_pem, fernet_key
) -> Generator[AuthSettings, None, None]:
    """Provide ``AuthSettings`` wired to the session's generated keys.

    Tokens use the ``EdDSA`` algorithm with ``private_key_pem`` as the token
    key, ``fernet_key`` is the state key, and only the Swagger UI OAuth2
    redirect on ``diracx.test.invalid`` is allowed.
    """
    from diracx.routers.utils.users import AuthSettings

    redirect_targets = ["http://diracx.test.invalid:8000/api/docs/oauth2-redirect"]
    auth_settings = AuthSettings(
        token_algorithm="EdDSA",
        token_key=private_key_pem,
        state_key=fernet_key,
        allowed_redirects=redirect_targets,
    )
    yield auth_settings
|
105
|
+
|
106
|
+
|
107
|
+
@pytest.fixture(scope="session")
def aio_moto(worker_id):
    """Run a moto server in a separate thread for the whole test session.

    The mocking provided by moto doesn't play nicely with aiobotocore so we use
    the server directly. See https://github.com/aio-libs/aiobotocore/issues/755

    Yields the keyword arguments (endpoint URL plus dummy credentials) for an
    S3 client that talks to the server. The server is stopped at teardown.
    """
    from moto.server import ThreadedMotoServer

    # Worker ids other than "master" look like "gw0", "gw1", ... (pytest-xdist).
    # Each such worker gets its own port so parallel workers don't share a server.
    base_port = 27132
    if worker_id == "master":
        port_offset = 0
    else:
        port_offset = int(worker_id.replace("gw", "")) + 1
    port = base_port + port_offset

    moto_server = ThreadedMotoServer(port=port)
    moto_server.start()
    s3_client_kwargs = {
        "endpoint_url": f"http://localhost:{port}",
        "aws_access_key_id": "testing",
        "aws_secret_access_key": "testing",
    }
    yield s3_client_kwargs
    moto_server.stop()
|
127
|
+
|
128
|
+
|
129
|
+
@pytest.fixture(scope="session")
def test_sandbox_settings(aio_moto) -> Generator[SandboxStoreSettings, None, None]:
    """Provide sandbox-store settings that point at the moto S3 server.

    The settings use the ``sandboxes`` bucket, pass the ``aio_moto`` client
    kwargs through as ``s3_client_kwargs`` and set ``auto_create_bucket=True``.
    This is a generator fixture, hence the ``Generator`` return annotation.
    """
    from diracx.routers.jobs.sandboxes import SandboxStoreSettings

    yield SandboxStoreSettings(
        bucket_name="sandboxes",
        s3_client_kwargs=aio_moto,
        auto_create_bucket=True,
    )
|
138
|
+
|
139
|
+
|
140
|
+
class UnavailableDependency:
    """Dependency override that fails loudly as soon as it is resolved.

    ``ClientFactory.configure`` installs one of these for every dependency
    whose class name is not listed in the test's enabled dependencies.
    """

    def __init__(self, key):
        # Name of the blocked dependency, reported in the error message.
        self.key = key

    def __call__(self):
        message = f"{self.key} has not been made available to this test!"
        raise NotImplementedError(message)
|
148
|
+
|
149
|
+
|
150
|
+
class ClientFactory:
    """Build the DiracX FastAPI app once and hand out test clients for it.

    The app is created with every system, SQL DB, OpenSearch DB and access
    policy found via ``select_from_extension``. SQL and OpenSearch DBs are
    pointed at in-memory SQLite, and the configuration comes from a local
    ``git+file://`` repository. All dependency overrides and lifetime
    functions are stripped from the app at construction time. They are only
    re-installed, per test, by :meth:`configure`.
    """

    def __init__(
        self,
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    ):
        from diracx.core.config import ConfigSource
        from diracx.core.extensions import select_from_extension
        from diracx.core.settings import ServiceSettingsBase
        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.routers import create_app_inner
        from diracx.routers.access_policies import BaseAccessPolicy

        from .mock_osdb import fake_available_osdb_implementations

        class AlwaysAllowAccessPolicy(BaseAccessPolicy):
            """Dummy access policy."""

            @staticmethod
            async def policy(
                policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
            ):
                pass

            @staticmethod
            def enrich_tokens(access_payload: dict, refresh_payload: dict):

                return {"PolicySpecific": "OpenAccessForTest"}, {}

        enabled_systems = {
            e.name for e in select_from_extension(group="diracx.services")
        }
        database_urls = {
            e.name: "sqlite+aiosqlite:///:memory:"
            for e in select_from_extension(group="diracx.db.sql")
        }
        # TODO: Monkeypatch this in a less stupid way
        # TODO: Only use this if opensearch isn't available
        os_database_conn_kwargs = {
            e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
            for e in select_from_extension(group="diracx.db.os")
        }
        BaseOSDB.available_implementations = partial(
            fake_available_osdb_implementations,
            real_available_implementations=BaseOSDB.available_implementations,
        )

        # Directory holding the cached, already-populated SQLite reference DBs
        # (see create_db_schemas).
        self._cache_dir = tmp_path_factory.mktemp("empty-dbs")

        self.test_auth_settings = test_auth_settings
        self.test_dev_settings = test_dev_settings

        # AlwaysAllowAccessPolicy is put first for every access-policy entry point.
        all_access_policies = {
            e.name: [AlwaysAllowAccessPolicy]
            + BaseAccessPolicy.available_implementations(e.name)
            for e in select_from_extension(group="diracx.access_policies")
        }

        self.app = create_app_inner(
            enabled_systems=enabled_systems,
            all_service_settings=[
                test_auth_settings,
                test_sandbox_settings,
                test_dev_settings,
            ],
            database_urls=database_urls,
            os_database_conn_kwargs=os_database_conn_kwargs,
            config_source=ConfigSource.create_from_url(
                backend_url=f"git+file://{with_config_repo}"
            ),
            all_access_policies=all_access_policies,
        )

        # Keep the app's overrides aside and start from an empty set; configure()
        # re-installs the ones a test asks for.
        self.all_dependency_overrides = self.app.dependency_overrides.copy()
        self.app.dependency_overrides = {}
        for obj in self.all_dependency_overrides:
            assert issubclass(
                obj.__self__,
                (
                    ServiceSettingsBase,
                    BaseSQLDB,
                    BaseOSDB,
                    ConfigSource,
                    BaseAccessPolicy,
                ),
            ), obj

        # Same for the lifetime functions (bound methods of instances).
        self.all_lifetime_functions = self.app.lifetime_functions[:]
        self.app.lifetime_functions = []
        for obj in self.all_lifetime_functions:
            assert isinstance(
                obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
            ), obj

    @contextlib.contextmanager
    def configure(self, enabled_dependencies):
        """Enable only ``enabled_dependencies`` on the app inside the context.

        ``enabled_dependencies`` is a collection of class names.
        Dependency overrides whose owning class is listed are installed
        unchanged. Every other override is replaced with
        :class:`UnavailableDependency`, so using it raises. Lifetime functions
        are enabled by the same class-name check, followed by
        :meth:`create_db_schemas`. Everything is reset on exit.
        """
        assert (
            self.app.dependency_overrides == {} and self.app.lifetime_functions == []
        ), "configure cannot be nested"
        for k, v in self.all_dependency_overrides.items():

            class_name = k.__self__.__name__

            if class_name in enabled_dependencies:
                self.app.dependency_overrides[k] = v
            else:
                self.app.dependency_overrides[k] = UnavailableDependency(class_name)

        for obj in self.all_lifetime_functions:
            # TODO: We should use the name of the entry point instead of the class name
            if obj.__self__.__class__.__name__ in enabled_dependencies:
                self.app.lifetime_functions.append(obj)

        # Add create_db_schemas to the end of the lifetime_functions so that the
        # other lifetime_functions (i.e. those which run db.engine_context) have
        # already been run
        self.app.lifetime_functions.append(self.create_db_schemas)

        try:
            yield
        finally:
            self.app.dependency_overrides = {}
            self.app.lifetime_functions = []

    @contextlib.asynccontextmanager
    async def create_db_schemas(self):
        """Create DB schemas based on the DBs available in app.dependency_overrides.

        Only enabled overrides of ``BaseSQLDB.transaction`` are considered.
        For SQLite engines, ``PRAGMA foreign_keys=ON`` is set on each new
        connection. The first time a DB class is seen, its schema is created
        with ``metadata.create_all`` and a copy is saved in the cache
        directory. Later calls restore that copy with SQLite's backup API
        instead of rebuilding the schema.
        """
        import aiosqlite
        import sqlalchemy
        from sqlalchemy.util.concurrency import greenlet_spawn

        from diracx.db.sql.utils import BaseSQLDB

        for k, v in self.app.dependency_overrides.items():
            # Ignore dependency overrides which aren't BaseSQLDB.transaction
            if (
                isinstance(v, UnavailableDependency)
                or k.__func__ != BaseSQLDB.transaction.__func__
            ):
                continue
            # The first argument of the overridden BaseSQLDB.transaction is the DB object
            db = v.args[0]
            assert isinstance(db, BaseSQLDB), (k, db)

            # set PRAGMA foreign_keys=ON if sqlite
            if db.engine.url.drivername.startswith("sqlite"):

                def set_sqlite_pragma(dbapi_connection, connection_record):
                    cursor = dbapi_connection.cursor()
                    cursor.execute("PRAGMA foreign_keys=ON")
                    cursor.close()

                sqlalchemy.event.listen(
                    db.engine.sync_engine, "connect", set_sqlite_pragma
                )

            # We maintain a cache of the populated DBs in empty_db_dir so that
            # we don't have to recreate them for every test. This speeds up the
            # tests by a considerable amount.
            ref_db = self._cache_dir / f"{k.__self__.__name__}.db"
            if ref_db.exists():
                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    try:
                        await ref_conn.backup(conn.driver_connection)
                    finally:
                        # Release the pooled connection even if the backup fails.
                        await greenlet_spawn(conn.close)
            else:
                async with db.engine.begin() as conn:
                    await conn.run_sync(db.metadata.create_all)

                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    try:
                        await conn.driver_connection.backup(ref_conn)
                    finally:
                        # Release the pooled connection even if the backup fails.
                        await greenlet_spawn(conn.close)

        yield

    @contextlib.contextmanager
    def unauthenticated(self):
        """Yield a FastAPI ``TestClient`` for the app without credentials.

        The client is entered as a context manager, so the app's lifespan
        handling runs for the duration of the ``with`` block.
        """
        from fastapi.testclient import TestClient

        with TestClient(self.app) as client:
            yield client

    @contextlib.contextmanager
    def normal_user(self):
        """Yield a client authenticated as a ``NORMAL_USER`` of the ``lhcb`` VO.

        The signed token payload is also exposed as ``client.dirac_token_payload``.
        """
        from diracx.core.properties import NORMAL_USER
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            payload = {
                "sub": "testingVO:yellow-sub",
                # timedelta's first positional argument is *days*; the setting
                # is expressed in minutes, so pass it by keyword.
                "exp": datetime.now(tz=timezone.utc)
                + timedelta(
                    minutes=self.test_auth_settings.access_token_expire_minutes
                ),
                "iss": ISSUER,
                "dirac_properties": [NORMAL_USER],
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)

            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client

    @contextlib.contextmanager
    def admin_user(self):
        """Yield a client authenticated with the ``JOB_ADMINISTRATOR`` property.

        The signed token payload is also exposed as ``client.dirac_token_payload``.
        """
        from diracx.core.properties import JOB_ADMINISTRATOR
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            # NOTE(review): unlike normal_user, no "exp" claim is set here.
            payload = {
                "sub": "testingVO:yellow-sub",
                "iss": ISSUER,
                "dirac_properties": [JOB_ADMINISTRATOR],
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)
            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client
|
381
|
+
|
382
|
+
|
383
|
+
@pytest.fixture(scope="session")
def session_client_factory(
    test_auth_settings,
    test_sandbox_settings,
    with_config_repo,
    tmp_path_factory,
    test_dev_settings,
):
    """Build the single :class:`ClientFactory` shared by the whole session.

    Constructing the app is expensive, so it is done once here. The
    function-scoped ``client_factory`` fixture then configures this shared
    instance for each individual test.
    """
    shared_factory = ClientFactory(
        tmp_path_factory=tmp_path_factory,
        with_config_repo=with_config_repo,
        test_auth_settings=test_auth_settings,
        test_sandbox_settings=test_sandbox_settings,
        test_dev_settings=test_dev_settings,
    )
    yield shared_factory
|
402
|
+
|
403
|
+
|
404
|
+
@pytest.fixture
def client_factory(session_client_factory, request):
    """Yield the shared ClientFactory configured for the requesting test.

    The test must carry an ``enabled_dependencies`` marker whose single
    positional argument is the collection of dependency class names to enable.
    The configuration is undone when the test finishes.

    Raises:
        RuntimeError: if the test has no ``enabled_dependencies`` marker.
    """
    deps_marker = request.node.get_closest_marker("enabled_dependencies")
    if deps_marker is None:
        raise RuntimeError("This test requires the enabled_dependencies marker")
    [enabled_dependencies] = deps_marker.args
    with session_client_factory.configure(enabled_dependencies=enabled_dependencies):
        yield session_client_factory
|
412
|
+
|
413
|
+
|
414
|
+
@pytest.fixture(scope="session")
def with_config_repo(tmp_path_factory):
    """Create a git repository holding a minimal DiracX configuration.

    A temporary directory is initialised as a git repository on ``master``.
    A single commit adds ``default.yml``, which holds the JSON dump of a
    validated ``Config``:

    - one ``lhcb`` VO with two users and three groups;
    - placeholder WorkloadManagement database entries.

    The repository directory is yielded.
    """
    from git import Repo

    from diracx.core.config import Config

    repo_dir = tmp_path_factory.mktemp("cs-repo")
    repo = Repo.init(repo_dir, initial_branch="master")

    user_chaen = "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"
    user_albdr = "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"

    registry = {
        "lhcb": {
            "DefaultGroup": "lhcb_user",
            "DefaultProxyLifeTime": 432000,
            "DefaultStorageQuota": 2000,
            "IdP": {
                "URL": "https://idp-server.invalid",
                "ClientID": "test-idp",
            },
            "Users": {
                user_chaen: {"PreferedUsername": "chaen", "Email": None},
                user_albdr: {"PreferedUsername": "albdr", "Email": None},
            },
            "Groups": {
                "lhcb_user": {
                    "Properties": ["NormalUser", "PrivateLimitedDelegation"],
                    "Users": [user_chaen, user_albdr],
                },
                "lhcb_prmgr": {
                    "Properties": ["NormalUser", "ProductionManagement"],
                    "Users": [user_chaen],
                },
                "lhcb_tokenmgr": {
                    "Properties": ["NormalUser", "ProxyManagement"],
                    "Users": [user_albdr],
                },
            },
        }
    }

    # Every WorkloadManagement database gets the same placeholder connection
    # details; JobDB additionally sets MaxRescheduling.
    wms_db_names = (
        "JobDB",
        "JobLoggingDB",
        "PilotAgentsDB",
        "SandboxMetadataDB",
        "TaskQueueDB",
        "ElasticJobParametersDB",
        "VirtualMachineDB",
    )
    wms_databases = {
        db_name: {"DBName": "xyz", "Host": "xyz", "Port": 9999}
        for db_name in wms_db_names
    }
    wms_databases["JobDB"]["MaxRescheduling"] = 3

    example_cs = Config.model_validate(
        {
            "DIRAC": {},
            "Registry": registry,
            "Operations": {"Defaults": {}},
            "Systems": {
                "WorkloadManagement": {
                    "Production": {"Databases": wms_databases},
                },
            },
        }
    )

    cs_file = repo_dir / "default.yml"
    cs_file.write_text(example_cs.model_dump_json())
    repo.index.add([cs_file])  # add it to the index
    repo.index.commit("Added a new file")
    yield repo_dir
|
516
|
+
|
517
|
+
|
518
|
+
@pytest.fixture(scope="session")
def demo_dir(request) -> Generator[Path, None, None]:
    """Yield the resolved ``.demo`` directory of a running DiracX demo.

    The diracx-charts checkout is taken from the ``--demo-dir`` option.
    Tests using this fixture are skipped when the option is not given.
    This is a generator fixture, hence the ``Generator`` return annotation.
    """
    demo_dir = request.config.getoption("--demo-dir")
    if demo_dir is None:
        pytest.skip("Requires a running instance of the DiracX demo")
    demo_dir = (demo_dir / ".demo").resolve()
    yield demo_dir
|
525
|
+
|
526
|
+
|
527
|
+
@pytest.fixture(scope="session")
def demo_urls(demo_dir):
    """Yield the ``developer.urls`` mapping from the demo's ``values.yaml``."""
    import yaml

    values_text = (demo_dir / "values.yaml").read_text()
    developer_urls = yaml.safe_load(values_text)["developer"]["urls"]
    yield developer_urls
|
533
|
+
|
534
|
+
|
535
|
+
@pytest.fixture(scope="session")
def demo_kubectl_env(demo_dir):
    """Get the dictionary of environment variables for kubectl to control the demo.

    The environment is the current one with ``KUBECONFIG`` pointing at the
    demo's ``kube.conf`` and the demo directory prepended to ``PATH``. The
    fixture checks that ``kubectl get pods`` works and mentions ``diracx``.

    Raises:
        RuntimeError: if the demo's ``kube.conf`` does not exist.
    """
    kube_conf = demo_dir / "kube.conf"
    if not kube_conf.exists():
        raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")

    kubectl_env = dict(os.environ)
    kubectl_env["KUBECONFIG"] = str(kube_conf)
    kubectl_env["PATH"] = f"{demo_dir}:{os.environ['PATH']}"

    # Sanity check: kubectl must be runnable and list the demo's pods.
    pod_listing = subprocess.check_output(
        ["kubectl", "get", "pods"], env=kubectl_env, text=True
    )
    assert "diracx" in pod_listing

    yield kubectl_env
|
555
|
+
|
556
|
+
|
557
|
+
@pytest.fixture
def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
    """Set up the environment for the CLI.

    First checks that the demo serves its OpenAPI document over TLS using the
    demo CA. It then sets ``DIRACX_URL``, ``DIRACX_CA_PATH`` and ``HOME`` (a
    per-test temporary directory) and yields those variables as a dict.

    Raises:
        RuntimeError: if the demo CA file cannot be found.
    """
    import httpx

    from diracx.core.preferences import get_diracx_preferences

    diracx_url = demo_urls["diracx"]
    ca_path = demo_dir / "demo-ca.pem"
    if not ca_path.exists():
        raise RuntimeError(f"Could not find {ca_path}, is the demo running?")

    # Ensure the demo is working
    r = httpx.get(
        f"{diracx_url}/api/openapi.json",
        verify=ssl.create_default_context(cafile=ca_path),
    )
    r.raise_for_status()
    assert r.json()["info"]["title"] == "Dirac"

    env = {
        "DIRACX_URL": diracx_url,
        "DIRACX_CA_PATH": str(ca_path),
        "HOME": str(tmp_path),
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    # The DiracX preferences are cached, but the cache is invalid whenever the
    # environment changes. Clear it now, so the test sees the variables just
    # set rather than a value cached earlier in the session...
    get_diracx_preferences.cache_clear()

    yield env

    # ...and again afterwards, so later tests don't see this test's values.
    get_diracx_preferences.cache_clear()
|
589
|
+
|
590
|
+
|
591
|
+
@pytest.fixture
async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
    """Log in via the CLI and point ``DIRACX_CREDENTIALS_PATH`` at the result.

    The credentials returned by ``test_login`` are written to
    ``credentials.json`` in the test's temporary directory. If the login
    raises, the requesting test is skipped instead of failing.
    """
    try:
        credentials = await test_login(monkeypatch, capfd, cli_env)
    except Exception as exc:
        pytest.skip(f"Login failed, fix test_login to re-enable this test: {exc!r}")

    creds_file = tmp_path / "credentials.json"
    creds_file.write_text(credentials)
    monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(creds_file))
    yield
|
602
|
+
|
603
|
+
|
604
|
+
async def test_login(monkeypatch, capfd, cli_env):
    """Log in with the CLI against the running demo and return the credentials.

    ``asyncio.sleep`` is patched so that, while the CLI polls, the login URL
    it printed is captured and the dex device flow is completed
    programmatically. The test asserts that the credentials file does not
    exist before the login and does exist afterwards.

    Returns:
        The text of the credentials file written by the CLI.
    """
    from diracx import cli

    # We monkeypatch asyncio.sleep to provide a hook to run the actions that
    # would normally be done by a user. This includes capturing the login URL
    # and doing the actual device flow with dex.
    unpatched_sleep = asyncio.sleep

    # The DiracX URL is literal text: escape it so characters such as "." are
    # not interpreted as regex metacharacters.
    login_url_pattern = re.compile(re.escape(cli_env["DIRACX_URL"]) + r"[^\n]+")

    poll_attempts = 0

    def fake_sleep(*args, **kwargs):
        nonlocal poll_attempts

        # Keep track of the number of times this is called
        poll_attempts += 1

        # After polling 5 times, do the actual login
        if poll_attempts == 5:
            # The login URL should have been printed to stdout
            captured = capfd.readouterr()
            match = login_url_pattern.search(captured.out)
            assert match, captured

            do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"])

        # Ensure we don't poll forever
        assert poll_attempts <= 100

        # Reduce the sleep duration to zero to speed up the test
        return unpatched_sleep(0)

    expected_credentials_path = Path(
        cli_env["HOME"], ".cache", "diracx", "credentials.json"
    )
    # Ensure the credentials file does not exist before logging in
    assert not expected_credentials_path.exists()

    # Run the login command
    with monkeypatch.context() as m:
        m.setattr("asyncio.sleep", fake_sleep)
        await cli.auth.login(vo="diracAdmin", group=None, property=None)
    captured = capfd.readouterr()
    assert "Login successful!" in captured.out
    assert captured.err == ""

    # Ensure the credentials file exists after logging in
    assert expected_credentials_path.exists()

    # Return the credentials so this test can also be used by the
    # "with_cli_login" fixture
    return expected_credentials_path.read_text()
|
655
|
+
|
656
|
+
|
657
|
+
def do_device_flow_with_dex(url: str, ca_path: str, timeout: float = 30.0) -> None:
    """Do the device flow with dex.

    The function:

    1. follows ``url`` (which redirects to the dex login page);
    2. posts the demo credentials to the page's login form;
    3. approves the request;
    4. checks that the final page reports the login as complete.

    Args:
        url: The login URL printed by the CLI login command.
        ca_path: CA bundle used to verify the demo's TLS certificates.
        timeout: Per-request timeout in seconds, so that an unresponsive demo
            fails the test instead of hanging it.

    Raises:
        requests.HTTPError: if any of the HTTP requests returns an error status.
        AssertionError: if the login form cannot be found or the final page
            does not confirm the login.
    """
    # Get the login page
    r = requests.get(url, verify=ca_path, timeout=timeout)
    r.raise_for_status()
    login_page_url = r.url  # This is not the same as URL as we redirect to dex
    login_page_body = r.text

    class DexLoginFormParser(HTMLParser):
        """Record the absolute action URL of the (single) form on the page."""

        def __init__(self):
            super().__init__()
            self.action_url = None

        def handle_starttag(self, tag, attrs):
            # HTMLParser passes tag names in lower case.
            if tag == "form":
                assert self.action_url is None, "Found more than one form"
                self.action_url = urljoin(login_page_url, dict(attrs)["action"])

    # Search the page for the login form so we know where to post the credentials
    form_parser = DexLoginFormParser()
    form_parser.feed(login_page_body)
    action_url = form_parser.action_url
    assert action_url is not None, login_page_body

    # Do the actual login
    r = requests.post(
        action_url,
        data={"login": "admin@example.com", "password": "password"},
        verify=ca_path,
        timeout=timeout,
    )
    r.raise_for_status()
    approval_url = r.url  # This is not the same as URL as we redirect to dex

    # Do the actual approval
    r = requests.post(
        approval_url,
        {
            "approval": "approve",
            "req": parse_qs(urlparse(approval_url).query)["req"][0],
        },
        verify=ca_path,
        timeout=timeout,
    )
    r.raise_for_status()

    # This should have redirected to the DiracX page that shows the login is complete
    assert "Please close the window" in r.text
|