diracx-testing 0.0.1a23__py3-none-any.whl → 0.0.1a24__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,694 @@
1
+ """Utilities for testing DiracX."""
2
+
3
+ from __future__ import annotations
4
+
5
+ # TODO: this needs a lot of documentation, in particular what will matter for users
6
+ # are the enabled_dependencies markers
7
+ import asyncio
8
+ import contextlib
9
+ import os
10
+ import re
11
+ import ssl
12
+ import subprocess
13
+ from datetime import datetime, timedelta, timezone
14
+ from functools import partial
15
+ from html.parser import HTMLParser
16
+ from pathlib import Path
17
+ from typing import TYPE_CHECKING, Generator
18
+ from urllib.parse import parse_qs, urljoin, urlparse
19
+ from uuid import uuid4
20
+
21
+ import pytest
22
+ import requests
23
+
24
+ if TYPE_CHECKING:
25
+ from diracx.core.settings import DevelopmentSettings
26
+ from diracx.routers.jobs.sandboxes import SandboxStoreSettings
27
+ from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
28
+
29
+
30
# NOTE(review): the instructions below describe generating an HS256 secret, but
# the test_auth_settings fixture configures "EdDSA" with an Ed25519 key — confirm
# whether ALGORITHM is still used anywhere.
# to get a string like this run:
# openssl rand -hex 32
ALGORITHM = "HS256"  # default JWT signing algorithm constant
ISSUER = "http://lhcbdirac.cern.ch/"  # "iss" claim placed in test tokens
AUDIENCE = "dirac"  # intended "aud" claim for test tokens
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # default access-token lifetime in minutes
36
+
37
+
38
def pytest_addoption(parser):
    """Register the diracx-testing command line options with pytest.

    Adds ``--regenerate-client`` (opt in to AutoREST client regeneration)
    and ``--demo-dir`` (location of a running diracx-charts demo).
    """
    option_specs = (
        (
            ("--regenerate-client",),
            dict(
                action="store_true",
                default=False,
                help="Regenerate the AutoREST client",
            ),
        ),
        (
            ("--demo-dir",),
            dict(
                type=Path,
                default=None,
                help="Path to a diracx-charts directory with the demo running",
            ),
        ),
    )
    for args, kwargs in option_specs:
        parser.addoption(*args, **kwargs)
51
+
52
+
53
def pytest_collection_modifyitems(config, items):
    """Disable the test_regenerate_client if not explicitly asked for."""
    if config.getoption("--regenerate-client"):
        # --regenerate-client given in cli: allow client re-generation
        return
    # Otherwise mark the regeneration test as skipped.
    skip_marker = pytest.mark.skip(reason="need --regenerate-client option to run")
    for collected in items:
        if collected.name == "test_regenerate_client":
            collected.add_marker(skip_marker)
62
+
63
+
64
@pytest.fixture(scope="session")
def private_key_pem() -> str:
    """Return a freshly generated Ed25519 private key as PEM/PKCS8 text."""
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

    key = Ed25519PrivateKey.generate()
    pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return pem_bytes.decode()
75
+
76
+
77
@pytest.fixture(scope="session")
def fernet_key() -> str:
    """Return a newly generated Fernet key decoded to text."""
    from cryptography.fernet import Fernet

    key_bytes = Fernet.generate_key()
    return key_bytes.decode()
82
+
83
+
84
@pytest.fixture(scope="session")
def test_dev_settings() -> Generator[DevelopmentSettings, None, None]:
    """Yield default DiracX development settings for the whole test session."""
    from diracx.core.settings import DevelopmentSettings

    settings = DevelopmentSettings()
    yield settings
89
+
90
+
91
@pytest.fixture(scope="session")
def test_auth_settings(
    private_key_pem, fernet_key
) -> Generator[AuthSettings, None, None]:
    """Yield AuthSettings signing tokens with EdDSA for the test session.

    Uses the session-scoped Ed25519 PEM key and Fernet state key fixtures.
    """
    from diracx.routers.utils.users import AuthSettings

    settings = AuthSettings(
        token_algorithm="EdDSA",
        token_key=private_key_pem,
        state_key=fernet_key,
        allowed_redirects=[
            "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
        ],
    )
    yield settings
105
+
106
+
107
@pytest.fixture(scope="session")
def aio_moto(worker_id):
    """Start the moto server in a separate thread and return the base URL.

    The mocking provided by moto doesn't play nicely with aiobotocore so we use
    the server directly. See https://github.com/aio-libs/aiobotocore/issues/755

    Yields a dict of S3 client kwargs (endpoint URL plus dummy credentials).
    Each pytest-xdist worker gets its own port so parallel runs don't collide.
    """
    from moto.server import ThreadedMotoServer

    port = 27132
    if worker_id != "master":
        # worker_id looks like "gw0", "gw1", ... under pytest-xdist; derive a
        # unique port offset per worker.
        port += int(worker_id.replace("gw", "")) + 1
    server = ThreadedMotoServer(port=port)
    server.start()
    try:
        yield {
            "endpoint_url": f"http://localhost:{port}",
            "aws_access_key_id": "testing",
            "aws_secret_access_key": "testing",
        }
    finally:
        # Previously stop() was not in a finally block, so the server thread
        # could leak if the fixture was torn down via an exception.
        server.stop()
127
+
128
+
129
@pytest.fixture(scope="session")
def test_sandbox_settings(aio_moto) -> Generator[SandboxStoreSettings, None, None]:
    """Yield sandbox-store settings backed by the session moto S3 server.

    Note: this is a generator fixture, so the annotation is a Generator
    (the original ``-> SandboxStoreSettings`` was inaccurate).
    """
    from diracx.routers.jobs.sandboxes import SandboxStoreSettings

    yield SandboxStoreSettings(
        bucket_name="sandboxes",
        # Endpoint URL and dummy credentials provided by the aio_moto fixture.
        s3_client_kwargs=aio_moto,
        # A fresh moto server has no buckets, so create it on first use.
        auto_create_bucket=True,
    )
138
+
139
+
140
class UnavailableDependency:
    """Placeholder dependency override that fails loudly when invoked.

    Installed for every dependency a test did not explicitly enable via the
    ``enabled_dependencies`` marker, so accidental use is caught immediately.
    """

    def __init__(self, key):
        # Name of the dependency class this placeholder stands in for.
        self.key = key

    def __call__(self):
        message = f"{self.key} has not been made available to this test!"
        raise NotImplementedError(message)
148
+
149
+
150
class ClientFactory:
    """Build a DiracX FastAPI app for tests with controllable dependencies.

    The app is created once with in-memory SQLite databases, a git-file config
    source, and a permissive access policy. All dependency overrides and
    lifetime functions are captured at construction and cleared; ``configure``
    reinstates only the subset a test explicitly enables, replacing everything
    else with :class:`UnavailableDependency` so stray use fails loudly.
    """

    def __init__(
        self,
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    ):
        from diracx.core.config import ConfigSource
        from diracx.core.extensions import select_from_extension
        from diracx.core.settings import ServiceSettingsBase
        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.routers import create_app_inner
        from diracx.routers.access_policies import BaseAccessPolicy

        from .mock_osdb import fake_available_osdb_implementations

        class AlwaysAllowAccessPolicy(BaseAccessPolicy):
            """Dummy access policy."""

            @staticmethod
            async def policy(
                policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
            ):
                pass

            @staticmethod
            def enrich_tokens(access_payload: dict, refresh_payload: dict):
                return {"PolicySpecific": "OpenAccessForTest"}, {}

        # Enable every service advertised via entry points.
        enabled_systems = {
            e.name for e in select_from_extension(group="diracx.services")
        }
        # Back every SQL DB with an in-memory SQLite database.
        database_urls = {
            e.name: "sqlite+aiosqlite:///:memory:"
            for e in select_from_extension(group="diracx.db.sql")
        }
        # TODO: Monkeypatch this in a less stupid way
        # TODO: Only use this if opensearch isn't available
        os_database_conn_kwargs = {
            e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
            for e in select_from_extension(group="diracx.db.os")
        }
        BaseOSDB.available_implementations = partial(
            fake_available_osdb_implementations,
            real_available_implementations=BaseOSDB.available_implementations,
        )

        # Cache directory for pre-populated SQLite files; see create_db_schemas.
        self._cache_dir = tmp_path_factory.mktemp("empty-dbs")

        self.test_auth_settings = test_auth_settings
        self.test_dev_settings = test_dev_settings

        # Prepend the allow-all policy so tests pass authorization by default.
        all_access_policies = {
            e.name: [AlwaysAllowAccessPolicy]
            + BaseAccessPolicy.available_implementations(e.name)
            for e in select_from_extension(group="diracx.access_policies")
        }

        self.app = create_app_inner(
            enabled_systems=enabled_systems,
            all_service_settings=[
                test_auth_settings,
                test_sandbox_settings,
                test_dev_settings,
            ],
            database_urls=database_urls,
            os_database_conn_kwargs=os_database_conn_kwargs,
            config_source=ConfigSource.create_from_url(
                backend_url=f"git+file://{with_config_repo}"
            ),
            all_access_policies=all_access_policies,
        )

        # Capture and clear the app's dependency overrides; configure()
        # reinstates just the enabled subset per test. The asserts guard
        # against new kinds of overrides appearing without being handled here.
        self.all_dependency_overrides = self.app.dependency_overrides.copy()
        self.app.dependency_overrides = {}
        for obj in self.all_dependency_overrides:
            assert issubclass(
                obj.__self__,
                (
                    ServiceSettingsBase,
                    BaseSQLDB,
                    BaseOSDB,
                    ConfigSource,
                    BaseAccessPolicy,
                ),
            ), obj

        # Same treatment for the lifetime functions (startup/shutdown hooks).
        self.all_lifetime_functions = self.app.lifetime_functions[:]
        self.app.lifetime_functions = []
        for obj in self.all_lifetime_functions:
            assert isinstance(
                obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
            ), obj

    @contextlib.contextmanager
    def configure(self, enabled_dependencies):
        """Enable the named dependency classes for the duration of the context.

        Everything not listed is overridden with UnavailableDependency so any
        accidental use raises. Cannot be nested.
        """
        assert (
            self.app.dependency_overrides == {} and self.app.lifetime_functions == []
        ), "configure cannot be nested"
        for k, v in self.all_dependency_overrides.items():

            class_name = k.__self__.__name__

            if class_name in enabled_dependencies:
                self.app.dependency_overrides[k] = v
            else:
                self.app.dependency_overrides[k] = UnavailableDependency(class_name)

        for obj in self.all_lifetime_functions:
            # TODO: We should use the name of the entry point instead of the class name
            if obj.__self__.__class__.__name__ in enabled_dependencies:
                self.app.lifetime_functions.append(obj)

        # Add create_db_schemas to the end of the lifetime_functions so that the
        # other lifetime_functions (i.e. those which run db.engine_context) have
        # already been ran
        self.app.lifetime_functions.append(self.create_db_schemas)

        try:
            yield
        finally:
            self.app.dependency_overrides = {}
            self.app.lifetime_functions = []

    @contextlib.asynccontextmanager
    async def create_db_schemas(self):
        """Create DB schema's based on the DBs available in app.dependency_overrides."""
        import aiosqlite
        import sqlalchemy
        from sqlalchemy.util.concurrency import greenlet_spawn

        from diracx.db.sql.utils import BaseSQLDB

        for k, v in self.app.dependency_overrides.items():
            # Ignore dependency overrides which aren't BaseSQLDB.transaction
            if (
                isinstance(v, UnavailableDependency)
                or k.__func__ != BaseSQLDB.transaction.__func__
            ):
                continue
            # The first argument of the overridden BaseSQLDB.transaction is the DB object
            db = v.args[0]
            assert isinstance(db, BaseSQLDB), (k, db)

            # set PRAGMA foreign_keys=ON if sqlite
            if db.engine.url.drivername.startswith("sqlite"):

                def set_sqlite_pragma(dbapi_connection, connection_record):
                    cursor = dbapi_connection.cursor()
                    cursor.execute("PRAGMA foreign_keys=ON")
                    cursor.close()

                sqlalchemy.event.listen(
                    db.engine.sync_engine, "connect", set_sqlite_pragma
                )

            # We maintain a cache of the populated DBs in empty_db_dir so that
            # we don't have to recreate them for every test. This speeds up the
            # tests by a considerable amount.
            ref_db = self._cache_dir / f"{k.__self__.__name__}.db"
            if ref_db.exists():
                # Cache hit: restore the cached schema into the live DB.
                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await ref_conn.backup(conn.driver_connection)
                    await greenlet_spawn(conn.close)
            else:
                # Cache miss: create the schema, then snapshot it to the cache.
                async with db.engine.begin() as conn:
                    await conn.run_sync(db.metadata.create_all)

                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await conn.driver_connection.backup(ref_conn)
                    await greenlet_spawn(conn.close)

        yield

    @contextlib.contextmanager
    def unauthenticated(self):
        """Yield a TestClient for the app without any Authorization header."""
        from fastapi.testclient import TestClient

        with TestClient(self.app) as client:
            yield client

    @contextlib.contextmanager
    def normal_user(self):
        """Yield a TestClient authenticated as a NormalUser test identity."""
        from diracx.core.properties import NORMAL_USER
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            payload = {
                "sub": "testingVO:yellow-sub",
                # BUG FIX: the expiry was previously passed as timedelta's first
                # positional argument, which is *days* — tokens were valid for
                # 30 days instead of 30 minutes.
                "exp": datetime.now(tz=timezone.utc)
                + timedelta(
                    minutes=self.test_auth_settings.access_token_expire_minutes
                ),
                "iss": ISSUER,
                "dirac_properties": [NORMAL_USER],
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)

            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client

    @contextlib.contextmanager
    def admin_user(self):
        """Yield a TestClient authenticated as a JobAdministrator test identity.

        NOTE(review): unlike normal_user, this payload sets no "exp" claim —
        confirm whether that is intentional.
        """
        from diracx.core.properties import JOB_ADMINISTRATOR
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            payload = {
                "sub": "testingVO:yellow-sub",
                "iss": ISSUER,
                "dirac_properties": [JOB_ADMINISTRATOR],
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)
            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client
381
+
382
+
383
@pytest.fixture(scope="session")
def session_client_factory(
    test_auth_settings,
    test_sandbox_settings,
    with_config_repo,
    tmp_path_factory,
    test_dev_settings,
):
    """Yield a session-scoped ClientFactory wired with the test settings.

    Tests should normally use the ``client_factory`` fixture instead, which
    restricts the app to the dependencies each test explicitly enables.
    """
    factory = ClientFactory(
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    )
    yield factory
402
+
403
+
404
@pytest.fixture
def client_factory(session_client_factory, request):
    """Yield the session ClientFactory configured for this test's dependencies.

    Reads the required ``enabled_dependencies`` marker; every dependency class
    named there is enabled, everything else raises if touched.
    """
    if (marker := request.node.get_closest_marker("enabled_dependencies")) is None:
        raise RuntimeError("This test requires the enabled_dependencies marker")
    (enabled_dependencies,) = marker.args
    with session_client_factory.configure(enabled_dependencies=enabled_dependencies):
        yield session_client_factory
412
+
413
+
414
@pytest.fixture(scope="session")
def with_config_repo(tmp_path_factory):
    """Yield the path to a local git repo holding a valid DiracX configuration.

    Builds an example Config (validated with the pydantic model), serialises it
    to ``default.yml`` on the ``master`` branch, and commits it so it can be
    consumed via a ``git+file://`` config source.
    """
    from git import Repo

    from diracx.core.config import Config

    tmp_path = tmp_path_factory.mktemp("cs-repo")

    repo = Repo.init(tmp_path, initial_branch="master")
    cs_file = tmp_path / "default.yml"
    # Validate the example configuration against the Config model before use.
    # NOTE(review): the user IDs below look like UUIDs but contain non-hex
    # characters ("2g0e", "9eg6") — presumably deliberate test values; confirm.
    example_cs = Config.model_validate(
        {
            "DIRAC": {},
            "Registry": {
                "lhcb": {
                    "DefaultGroup": "lhcb_user",
                    "DefaultProxyLifeTime": 432000,
                    "DefaultStorageQuota": 2000,
                    "IdP": {
                        "URL": "https://idp-server.invalid",
                        "ClientID": "test-idp",
                    },
                    "Users": {
                        "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041": {
                            "PreferedUsername": "chaen",
                            "Email": None,
                        },
                        "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152": {
                            "PreferedUsername": "albdr",
                            "Email": None,
                        },
                    },
                    "Groups": {
                        "lhcb_user": {
                            "Properties": ["NormalUser", "PrivateLimitedDelegation"],
                            "Users": [
                                "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041",
                                "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152",
                            ],
                        },
                        "lhcb_prmgr": {
                            "Properties": ["NormalUser", "ProductionManagement"],
                            "Users": ["b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"],
                        },
                        "lhcb_tokenmgr": {
                            "Properties": ["NormalUser", "ProxyManagement"],
                            "Users": ["c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"],
                        },
                    },
                }
            },
            "Operations": {"Defaults": {}},
            # Placeholder DB connection details; tests override the real
            # connections with in-memory SQLite (see ClientFactory).
            "Systems": {
                "WorkloadManagement": {
                    "Production": {
                        "Databases": {
                            "JobDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                                "MaxRescheduling": 3,
                            },
                            "JobLoggingDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                            "PilotAgentsDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                            "SandboxMetadataDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                            "TaskQueueDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                            "ElasticJobParametersDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                            "VirtualMachineDB": {
                                "DBName": "xyz",
                                "Host": "xyz",
                                "Port": 9999,
                            },
                        },
                    },
                },
            },
        }
    )
    cs_file.write_text(example_cs.model_dump_json())
    repo.index.add([cs_file])  # add it to the index
    repo.index.commit("Added a new file")
    yield tmp_path
516
+
517
+
518
@pytest.fixture(scope="session")
def demo_dir(request) -> Generator[Path, None, None]:
    """Yield the ``.demo`` directory of a running diracx-charts demo.

    Skips the test when ``--demo-dir`` was not given on the command line.
    (Annotation corrected: this is a generator fixture, not ``-> Path``.)
    """
    demo_dir = request.config.getoption("--demo-dir")
    if demo_dir is None:
        pytest.skip("Requires a running instance of the DiracX demo")
    demo_dir = (demo_dir / ".demo").resolve()
    yield demo_dir
525
+
526
+
527
@pytest.fixture(scope="session")
def demo_urls(demo_dir):
    """Yield the developer service URLs from the demo's Helm values file."""
    import yaml

    values_text = (demo_dir / "values.yaml").read_text()
    helm_values = yaml.safe_load(values_text)
    yield helm_values["developer"]["urls"]
533
+
534
+
535
@pytest.fixture(scope="session")
def demo_kubectl_env(demo_dir):
    """Get the dictionary of environment variables for kubectl to control the demo."""
    kube_conf = demo_dir / "kube.conf"
    if not kube_conf.exists():
        raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")

    env = dict(os.environ)
    env["KUBECONFIG"] = str(kube_conf)
    env["PATH"] = f"{demo_dir}:{os.environ['PATH']}"

    # Check that we can run kubectl
    pods_result = subprocess.check_output(
        ["kubectl", "get", "pods"], env=env, text=True
    )
    assert "diracx" in pods_result

    yield env
555
+
556
+
557
@pytest.fixture
def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
    """Set up the environment for the CLI.

    Points DIRACX_URL/DIRACX_CA_PATH at the running demo and redirects HOME to
    a temporary directory so credentials are written somewhere disposable.
    """
    import httpx

    from diracx.core.preferences import get_diracx_preferences

    diracx_url = demo_urls["diracx"]
    ca_path = demo_dir / "demo-ca.pem"
    if not ca_path.exists():
        raise RuntimeError(f"Could not find {ca_path}, is the demo running?")

    # Ensure the demo is working
    response = httpx.get(
        f"{diracx_url}/api/openapi.json",
        verify=ssl.create_default_context(cafile=ca_path),
    )
    response.raise_for_status()
    assert response.json()["info"]["title"] == "Dirac"

    env = {
        "DIRACX_URL": diracx_url,
        "DIRACX_CA_PATH": str(ca_path),
        "HOME": str(tmp_path),
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)
    yield env

    # The DiracX preferences are cached however when testing this cache is invalid
    get_diracx_preferences.cache_clear()
589
+
590
+
591
@pytest.fixture
async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
    """Perform a CLI login and expose the credentials via the environment.

    Delegates the actual device flow to test_login; if that fails the current
    test is skipped rather than failed.
    """
    try:
        credentials = await test_login(monkeypatch, capfd, cli_env)
    except Exception as e:
        pytest.skip(f"Login failed, fix test_login to re-enable this test: {e!r}")

    creds_file = tmp_path / "credentials.json"
    creds_file.write_text(credentials)
    monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(creds_file))
    yield
602
+
603
+
604
async def test_login(monkeypatch, capfd, cli_env):
    """Exercise the CLI device-flow login end to end against the demo.

    ``asyncio.sleep`` is monkeypatched so that, while the CLI polls for the
    device-flow result, the test can capture the printed login URL and complete
    the flow with dex on the user's behalf. Returns the credentials file
    contents so the ``with_cli_login`` fixture can reuse them.
    """
    from diracx import cli

    poll_attempts = 0

    def fake_sleep(*args, **kwargs):
        nonlocal poll_attempts

        # Keep track of the number of times this is called
        poll_attempts += 1

        # After polling 5 times, do the actual login
        if poll_attempts == 5:
            # The login URL should have been printed to stdout
            captured = capfd.readouterr()
            match = re.search(rf"{cli_env['DIRACX_URL']}[^\n]+", captured.out)
            assert match, captured

            do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"])

        # Ensure we don't poll forever
        assert poll_attempts <= 100

        # Reduce the sleep duration to zero to speed up the test
        return unpatched_sleep(0)

    # We monkeypatch asyncio.sleep to provide a hook to run the actions that
    # would normally be done by a user. This includes capturing the login URL
    # and doing the actual device flow with dex.
    unpatched_sleep = asyncio.sleep

    expected_credentials_path = Path(
        cli_env["HOME"], ".cache", "diracx", "credentials.json"
    )
    # Ensure the credentials file does not exist before logging in
    assert not expected_credentials_path.exists()

    # Run the login command
    with monkeypatch.context() as m:
        m.setattr("asyncio.sleep", fake_sleep)
        await cli.auth.login(vo="diracAdmin", group=None, property=None)
    captured = capfd.readouterr()
    assert "Login successful!" in captured.out
    assert captured.err == ""

    # Ensure the credentials file exists after logging in
    assert expected_credentials_path.exists()

    # Return the credentials so this test can also be used by the
    # "with_cli_login" fixture
    return expected_credentials_path.read_text()
655
+
656
+
657
def do_device_flow_with_dex(url: str, ca_path: str) -> None:
    """Do the device flow with dex.

    Fetches the login page at *url*, posts the demo admin credentials to the
    form found there, then posts the approval — following redirects all the
    way to the DiracX page confirming the login completed.
    """

    class DexLoginFormParser(HTMLParser):
        # Records the <form action=...> target of the dex login page into the
        # enclosing function's action_url variable (via nonlocal).
        def handle_starttag(self, tag, attrs):
            nonlocal action_url
            if "form" in str(tag):
                assert action_url is None
                action_url = urljoin(login_page_url, dict(attrs)["action"])

    # Get the login page
    r = requests.get(url, verify=ca_path)
    r.raise_for_status()
    login_page_url = r.url  # This is not the same as URL as we redirect to dex
    login_page_body = r.text

    # Search the page for the login form so we know where to post the credentials
    action_url = None
    DexLoginFormParser().feed(login_page_body)
    assert action_url is not None, login_page_body

    # Do the actual login
    r = requests.post(
        action_url,
        data={"login": "admin@example.com", "password": "password"},
        verify=ca_path,
    )
    r.raise_for_status()
    approval_url = r.url  # This is not the same as URL as we redirect to dex
    # Do the actual approval
    r = requests.post(
        approval_url,
        {"approval": "approve", "req": parse_qs(urlparse(r.url).query)["req"][0]},
        verify=ca_path,
    )

    # This should have redirected to the DiracX page that shows the login is complete
    assert "Please close the window" in r.text
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diracx-testing
3
- Version: 0.0.1a23
3
+ Version: 0.0.1a24
4
4
  Summary: TODO
5
5
  License: GPL-3.0-only
6
6
  Classifier: Intended Audience :: Science/Research