diracx-testing 0.0.1a23__py3-none-any.whl → 0.0.1a25__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,704 @@
1
+ """Utilities for testing DiracX."""
2
+
3
+ from __future__ import annotations
4
+
5
+ # TODO: this needs a lot of documentation, in particular what will matter for users
6
+ # are the enabled_dependencies markers
7
+ import asyncio
8
+ import contextlib
9
+ import os
10
+ import re
11
+ import ssl
12
+ import subprocess
13
+ from datetime import datetime, timedelta, timezone
14
+ from functools import partial
15
+ from html.parser import HTMLParser
16
+ from pathlib import Path
17
+ from typing import TYPE_CHECKING, Generator
18
+ from urllib.parse import parse_qs, urljoin, urlparse
19
+ from uuid import uuid4
20
+
21
+ import pytest
22
+ import requests
23
+
24
+ if TYPE_CHECKING:
25
+ from diracx.core.settings import DevelopmentSettings
26
+ from diracx.routers.jobs.sandboxes import SandboxStoreSettings
27
+ from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
28
+
29
+
30
# Token-related constants.
# ISSUER is used as the "iss" claim of the test tokens minted by
# ClientFactory. Note that those tokens are signed with the EdDSA key
# configured in the ``test_auth_settings`` fixture, not with ALGORITHM.
# ALGORITHM, AUDIENCE and ACCESS_TOKEN_EXPIRE_MINUTES are not used in this
# module — presumably kept for code importing them (verify before removing).
ALGORITHM = "HS256"
ISSUER = "http://lhcbdirac.cern.ch/"
AUDIENCE = "dirac"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
36
+
37
+
38
def pytest_addoption(parser):
    """Register the command-line options of the DiracX test plugin.

    ``--regenerate-client`` opts in to running ``test_regenerate_client``.
    ``--demo-dir`` points at a diracx-charts checkout whose demo is running.
    """
    option_specs = (
        (
            "--regenerate-client",
            dict(
                action="store_true",
                default=False,
                help="Regenerate the AutoREST client",
            ),
        ),
        (
            "--demo-dir",
            dict(
                type=Path,
                default=None,
                help="Path to a diracx-charts directory with the demo running",
            ),
        ),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
51
+
52
+
53
def pytest_collection_modifyitems(config, items):
    """Disable the test_regenerate_client if not explicitly asked for.

    When ``--regenerate-client`` is absent, every collected item named
    ``test_regenerate_client`` gets a skip marker; otherwise nothing changes.
    """
    regenerate_requested = config.getoption("--regenerate-client")
    if not regenerate_requested:
        skip_marker = pytest.mark.skip(reason="need --regenerate-client option to run")
        regen_items = (it for it in items if it.name == "test_regenerate_client")
        for regen_item in regen_items:
            regen_item.add_marker(skip_marker)
62
+
63
+
64
@pytest.fixture(scope="session")
def private_key_pem() -> str:
    """Generate a throwaway Ed25519 private key for the session.

    Returns the key as an unencrypted PKCS#8 PEM string.
    """
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

    pem_bytes = Ed25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return pem_bytes.decode()
75
+
76
+
77
@pytest.fixture(scope="session")
def fernet_key() -> str:
    """Return a freshly generated Fernet key, decoded to ``str``."""
    from cryptography.fernet import Fernet

    raw_key = Fernet.generate_key()
    return raw_key.decode()
82
+
83
+
84
@pytest.fixture(scope="session")
def test_dev_settings() -> Generator[DevelopmentSettings, None, None]:
    """Provide a default-constructed ``DevelopmentSettings`` for the session."""
    from diracx.core.settings import DevelopmentSettings

    dev_settings = DevelopmentSettings()
    yield dev_settings
89
+
90
+
91
@pytest.fixture(scope="session")
def test_auth_settings(
    private_key_pem, fernet_key
) -> Generator[AuthSettings, None, None]:
    """Build ``AuthSettings`` for the session.

    Tokens are signed with EdDSA using the session's Ed25519 key. The Fernet
    key is used as the state key. The only allowed redirect is the
    ``/api/docs/oauth2-redirect`` URL on the test host.
    """
    from diracx.routers.utils.users import AuthSettings

    redirect_whitelist = [
        "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
    ]
    auth_settings = AuthSettings(
        token_algorithm="EdDSA",
        token_key=private_key_pem,
        state_key=fernet_key,
        allowed_redirects=redirect_whitelist,
    )
    yield auth_settings
105
+
106
+
107
@pytest.fixture(scope="session")
def aio_moto(worker_id):
    """Run a moto server in a background thread for the whole session.

    Yields the keyword arguments (endpoint URL and dummy credentials) needed
    to point an S3 client at it. The mocking provided by moto doesn't play
    nicely with aiobotocore, so the server is used directly. See
    https://github.com/aio-libs/aiobotocore/issues/755
    """
    from moto.server import ThreadedMotoServer

    base_port = 27132
    # "master" = not running under a distributed worker; otherwise the id is
    # "gw<N>" (presumably pytest-xdist), and each worker gets its own port.
    if worker_id == "master":
        port_offset = 0
    else:
        port_offset = int(worker_id.replace("gw", "")) + 1
    port = base_port + port_offset

    moto_server = ThreadedMotoServer(port=port)
    moto_server.start()
    client_kwargs = {
        "endpoint_url": f"http://localhost:{port}",
        "aws_access_key_id": "testing",
        "aws_secret_access_key": "testing",
    }
    yield client_kwargs
    moto_server.stop()
127
+
128
+
129
@pytest.fixture(scope="session")
def test_sandbox_settings(aio_moto) -> Generator[SandboxStoreSettings, None, None]:
    """Provide sandbox-store settings backed by the session's moto S3 server.

    Uses the ``sandboxes`` bucket with ``auto_create_bucket`` enabled. The
    fixture yields, so it is annotated as a Generator like the other
    settings fixtures.
    """
    from diracx.routers.jobs.sandboxes import SandboxStoreSettings

    yield SandboxStoreSettings(
        bucket_name="sandboxes",
        s3_client_kwargs=aio_moto,
        auto_create_bucket=True,
    )
138
+
139
+
140
class UnavailableDependency:
    """Stand-in dependency override for a dependency a test did not enable.

    Calling it raises ``NotImplementedError`` naming the missing dependency.
    """

    def __init__(self, key):
        # Name of the dependency this placeholder replaces (used in the error)
        self.key = key

    def __call__(self):
        message = f"{self.key} has not been made available to this test!"
        raise NotImplementedError(message)
148
+
149
+
150
class ClientFactory:
    """Build the DiracX app once and hand out TestClients configured per test.

    Construction:
    - Creates the app with every installed system and every SQL DB (backed by
      in-memory SQLite).
    - Includes every OpenSearch DB (mocked with SQLite).
    - Includes every access policy, each with an always-allow policy
      prepended.
    - Removes and stashes all dependency overrides and lifetime functions.

    ``configure`` then re-enables only the ones a test asks for.
    """

    def __init__(
        self,
        tmp_path_factory,
        with_config_repo,
        test_auth_settings,
        test_sandbox_settings,
        test_dev_settings,
    ):
        """Create the app with everything available, then stash its overrides.

        Args:
            tmp_path_factory: pytest factory used to create the empty-DB cache dir.
            with_config_repo: path of the git repository holding the test CS.
            test_auth_settings: AuthSettings for the app; also used to sign tokens.
            test_sandbox_settings: SandboxStoreSettings for the app.
            test_dev_settings: DevelopmentSettings for the app.
        """
        from diracx.core.config import ConfigSource
        from diracx.core.extensions import select_from_extension
        from diracx.core.settings import ServiceSettingsBase
        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.routers import create_app_inner
        from diracx.routers.access_policies import BaseAccessPolicy

        from .mock_osdb import fake_available_osdb_implementations

        class AlwaysAllowAccessPolicy(BaseAccessPolicy):
            """Dummy access policy that never rejects a request."""

            @staticmethod
            async def policy(
                policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
            ):
                pass

            @staticmethod
            def enrich_tokens(access_payload: dict, refresh_payload: dict):
                # Extra claim for the access token; no extra refresh-token claims
                return {"PolicySpecific": "OpenAccessForTest"}, {}

        enabled_systems = {
            e.name for e in select_from_extension(group="diracx.services")
        }
        # Every SQL DB is backed by in-memory SQLite
        database_urls = {
            e.name: "sqlite+aiosqlite:///:memory:"
            for e in select_from_extension(group="diracx.db.sql")
        }
        # TODO: Monkeypatch this in a less stupid way
        # TODO: Only use this if opensearch isn't available
        # OS DBs are swapped for SQLite-backed mocks via fake_available_osdb_implementations
        os_database_conn_kwargs = {
            e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
            for e in select_from_extension(group="diracx.db.os")
        }
        BaseOSDB.available_implementations = partial(
            fake_available_osdb_implementations,
            real_available_implementations=BaseOSDB.available_implementations,
        )

        # Holds one pre-populated (empty-schema) DB file per DB class; see create_db_schemas
        self._cache_dir = tmp_path_factory.mktemp("empty-dbs")

        self.test_auth_settings = test_auth_settings
        self.test_dev_settings = test_dev_settings

        # AlwaysAllowAccessPolicy comes first for every access-policy entry point
        all_access_policies = {
            e.name: [AlwaysAllowAccessPolicy]
            + BaseAccessPolicy.available_implementations(e.name)
            for e in select_from_extension(group="diracx.access_policies")
        }

        self.app = create_app_inner(
            enabled_systems=enabled_systems,
            all_service_settings=[
                test_auth_settings,
                test_sandbox_settings,
                test_dev_settings,
            ],
            database_urls=database_urls,
            os_database_conn_kwargs=os_database_conn_kwargs,
            config_source=ConfigSource.create_from_url(
                backend_url=f"git+file://{with_config_repo}"
            ),
            all_access_policies=all_access_policies,
        )

        # Stash every dependency override; configure() re-installs the enabled ones
        self.all_dependency_overrides = self.app.dependency_overrides.copy()
        self.app.dependency_overrides = {}
        for obj in self.all_dependency_overrides:
            assert issubclass(
                obj.__self__,
                (
                    ServiceSettingsBase,
                    BaseSQLDB,
                    BaseOSDB,
                    ConfigSource,
                    BaseAccessPolicy,
                ),
            ), obj

        # Same for lifetime functions (bound methods of settings/DB/config objects)
        self.all_lifetime_functions = self.app.lifetime_functions[:]
        self.app.lifetime_functions = []
        for obj in self.all_lifetime_functions:
            assert isinstance(
                obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
            ), obj

    @contextlib.contextmanager
    def configure(self, enabled_dependencies):
        """Enable only ``enabled_dependencies`` on the app within the ``with`` block.

        ``enabled_dependencies`` is a collection of class names. For matching
        names:
        - the stashed dependency overrides are re-installed;
        - the stashed lifetime functions are re-added, followed by
          ``create_db_schemas``.

        Every other override is replaced by an ``UnavailableDependency``,
        which raises when used. Everything is cleared on exit. Calls must not
        be nested.
        """
        assert (
            self.app.dependency_overrides == {} and self.app.lifetime_functions == []
        ), "configure cannot be nested"

        for k, v in self.all_dependency_overrides.items():
            # k is bound to a class (e.g. a classmethod); match on that class's name
            class_name = k.__self__.__name__

            if class_name in enabled_dependencies:
                self.app.dependency_overrides[k] = v
            else:
                self.app.dependency_overrides[k] = UnavailableDependency(class_name)

        for obj in self.all_lifetime_functions:
            # TODO: We should use the name of the entry point instead of the class name
            if obj.__self__.__class__.__name__ in enabled_dependencies:
                self.app.lifetime_functions.append(obj)

        # Add create_db_schemas to the end of the lifetime_functions so that the
        # other lifetime_functions (i.e. those which run db.engine_context) have
        # already run
        self.app.lifetime_functions.append(self.create_db_schemas)

        try:
            yield
        finally:
            self.app.dependency_overrides = {}
            self.app.lifetime_functions = []

    @contextlib.asynccontextmanager
    async def create_db_schemas(self):
        """Create DB schemas for the DBs enabled in app.dependency_overrides.

        For SQLite engines, foreign-key enforcement is switched on for every
        new connection. The first time a DB class is seen, its schema is
        created with ``metadata.create_all`` and copied into a file in
        ``self._cache_dir``. Later, that file is restored via the SQLite
        backup API instead of recreating the schema.
        """
        import aiosqlite
        import sqlalchemy
        from sqlalchemy.util.concurrency import greenlet_spawn

        from diracx.db.os.utils import BaseOSDB
        from diracx.db.sql.utils import BaseSQLDB
        from diracx.testing.mock_osdb import MockOSDBMixin

        for k, v in self.app.dependency_overrides.items():
            # Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
            if isinstance(v, UnavailableDependency) or k.__func__ not in (
                BaseSQLDB.transaction.__func__,
                BaseOSDB.session.__func__,
            ):
                continue

            # The first argument of the overridden BaseSQLDB.transaction is the DB object
            db = v.args[0]
            # We expect the OS DB to be mocked with sqlite, so use the
            # internal DB
            if isinstance(db, MockOSDBMixin):
                db = db._sql_db

            assert isinstance(db, BaseSQLDB), (k, db)

            # set PRAGMA foreign_keys=ON if sqlite
            if db.engine.url.drivername.startswith("sqlite"):

                def set_sqlite_pragma(dbapi_connection, connection_record):
                    cursor = dbapi_connection.cursor()
                    cursor.execute("PRAGMA foreign_keys=ON")
                    cursor.close()

                sqlalchemy.event.listen(
                    db.engine.sync_engine, "connect", set_sqlite_pragma
                )

            # We maintain a cache of the populated DBs in self._cache_dir so that
            # we don't have to recreate them for every test. This speeds up the
            # tests by a considerable amount.
            ref_db = self._cache_dir / f"{k.__self__.__name__}.db"
            if ref_db.exists():
                # Restore: copy the cached file into this engine's connection
                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await ref_conn.backup(conn.driver_connection)
                    await greenlet_spawn(conn.close)
            else:
                async with db.engine.begin() as conn:
                    await conn.run_sync(db.metadata.create_all)

                # Populate the cache: copy the fresh schema out to the file
                async with aiosqlite.connect(ref_db) as ref_conn:
                    conn = await db.engine.raw_connection()
                    await conn.driver_connection.backup(ref_conn)
                    await greenlet_spawn(conn.close)

        yield

    @contextlib.contextmanager
    def unauthenticated(self):
        """Yield a ``fastapi.testclient.TestClient`` for the app, with no credentials."""
        from fastapi.testclient import TestClient

        with TestClient(self.app) as client:
            yield client

    @contextlib.contextmanager
    def _authenticated_client(self, dirac_properties):
        """Yield a TestClient carrying a bearer token with ``dirac_properties``.

        The token expires ``test_auth_settings.access_token_expire_minutes``
        minutes from now. The signed claims are exposed as
        ``client.dirac_token_payload`` so tests can assert against them.
        """
        from diracx.routers.auth.token import create_token

        with self.unauthenticated() as client:
            payload = {
                "sub": "testingVO:yellow-sub",
                # timedelta's first positional argument is *days*: name the unit
                "exp": datetime.now(tz=timezone.utc)
                + timedelta(
                    minutes=self.test_auth_settings.access_token_expire_minutes
                ),
                "iss": ISSUER,
                "dirac_properties": list(dirac_properties),
                "jti": str(uuid4()),
                "preferred_username": "preferred_username",
                "dirac_group": "test_group",
                "vo": "lhcb",
            }
            token = create_token(payload, self.test_auth_settings)

            client.headers["Authorization"] = f"Bearer {token}"
            client.dirac_token_payload = payload
            yield client

    @contextlib.contextmanager
    def normal_user(self):
        """Yield a TestClient authenticated with the ``NORMAL_USER`` property."""
        from diracx.core.properties import NORMAL_USER

        with self._authenticated_client([NORMAL_USER]) as client:
            yield client

    @contextlib.contextmanager
    def admin_user(self):
        """Yield a TestClient authenticated with the ``JOB_ADMINISTRATOR`` property."""
        from diracx.core.properties import JOB_ADMINISTRATOR

        with self._authenticated_client([JOB_ADMINISTRATOR]) as client:
            yield client
391
+
392
+
393
@pytest.fixture(scope="session")
def session_client_factory(
    test_auth_settings,
    test_sandbox_settings,
    with_config_repo,
    tmp_path_factory,
    test_dev_settings,
):
    """Create the single ``ClientFactory`` shared by the whole test session.

    Tests normally go through the function-scoped ``client_factory``
    fixture. That fixture calls ``configure`` on this instance with the
    dependencies named in the test's ``enabled_dependencies`` marker.
    """
    shared_factory = ClientFactory(
        tmp_path_factory=tmp_path_factory,
        with_config_repo=with_config_repo,
        test_auth_settings=test_auth_settings,
        test_sandbox_settings=test_sandbox_settings,
        test_dev_settings=test_dev_settings,
    )
    yield shared_factory
412
+
413
+
414
@pytest.fixture
def client_factory(session_client_factory, request):
    """Yield the shared ClientFactory restricted to this test's dependencies.

    The test must carry ``@pytest.mark.enabled_dependencies(names)``, whose
    single positional argument lists the dependency class names to enable.

    Raises:
        RuntimeError: if no ``enabled_dependencies`` marker applies to the test.
    """
    marker = request.node.get_closest_marker("enabled_dependencies")
    if marker is not None:
        [enabled_dependencies] = marker.args
        with session_client_factory.configure(
            enabled_dependencies=enabled_dependencies
        ):
            yield session_client_factory
    else:
        raise RuntimeError("This test requires the enabled_dependencies marker")
422
+
423
+
424
@pytest.fixture(scope="session")
def with_config_repo(tmp_path_factory):
    """Create a git repository containing a minimal DiracX CS; yield its path.

    The configuration is validated with ``Config`` and written, as JSON, to
    ``default.yml``. That file is committed on the ``master`` branch.
    """
    from git import Repo

    from diracx.core.config import Config

    repo_dir = tmp_path_factory.mktemp("cs-repo")
    repo = Repo.init(repo_dir, initial_branch="master")

    chaen_id = "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"
    albdr_id = "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"

    # All WorkloadManagement DBs share the same dummy connection settings;
    # JobDB additionally carries MaxRescheduling.
    dummy_db_conn = {"DBName": "xyz", "Host": "xyz", "Port": 9999}
    db_names = (
        "JobDB",
        "JobLoggingDB",
        "PilotAgentsDB",
        "SandboxMetadataDB",
        "TaskQueueDB",
        "ElasticJobParametersDB",
        "VirtualMachineDB",
    )
    databases = {name: dict(dummy_db_conn) for name in db_names}
    databases["JobDB"]["MaxRescheduling"] = 3

    lhcb_registry = {
        "DefaultGroup": "lhcb_user",
        "DefaultProxyLifeTime": 432000,
        "DefaultStorageQuota": 2000,
        "IdP": {"URL": "https://idp-server.invalid", "ClientID": "test-idp"},
        "Users": {
            chaen_id: {"PreferedUsername": "chaen", "Email": None},
            albdr_id: {"PreferedUsername": "albdr", "Email": None},
        },
        "Groups": {
            "lhcb_user": {
                "Properties": ["NormalUser", "PrivateLimitedDelegation"],
                "Users": [chaen_id, albdr_id],
            },
            "lhcb_prmgr": {
                "Properties": ["NormalUser", "ProductionManagement"],
                "Users": [chaen_id],
            },
            "lhcb_tokenmgr": {
                "Properties": ["NormalUser", "ProxyManagement"],
                "Users": [albdr_id],
            },
        },
    }

    example_cs = Config.model_validate(
        {
            "DIRAC": {},
            "Registry": {"lhcb": lhcb_registry},
            "Operations": {"Defaults": {}},
            "Systems": {
                "WorkloadManagement": {"Production": {"Databases": databases}},
            },
        }
    )

    cs_file = repo_dir / "default.yml"
    cs_file.write_text(example_cs.model_dump_json())
    repo.index.add([cs_file])
    repo.index.commit("Added a new file")
    yield repo_dir
526
+
527
+
528
@pytest.fixture(scope="session")
def demo_dir(request) -> Generator[Path, None, None]:
    """Yield the resolved ``.demo`` directory inside the ``--demo-dir`` checkout.

    Skips the requesting tests when ``--demo-dir`` was not given. The fixture
    yields, so it is annotated as a Generator like the other fixtures.
    """
    charts_dir = request.config.getoption("--demo-dir")
    if charts_dir is None:
        pytest.skip("Requires a running instance of the DiracX demo")
    yield (charts_dir / ".demo").resolve()
535
+
536
+
537
@pytest.fixture(scope="session")
def demo_urls(demo_dir):
    """Yield the ``developer.urls`` section of the demo's ``values.yaml``."""
    import yaml

    values_text = (demo_dir / "values.yaml").read_text()
    developer_section = yaml.safe_load(values_text)["developer"]
    yield developer_section["urls"]
543
+
544
+
545
@pytest.fixture(scope="session")
def demo_kubectl_env(demo_dir):
    """Get the dictionary of environment variables for kubectl to control the demo.

    Yields a copy of ``os.environ`` with two changes:
    - ``KUBECONFIG`` is set to the demo's ``kube.conf``;
    - the demo directory is prepended to ``PATH``.

    Before yielding, ``kubectl get pods`` is run to check that the demo is
    reachable and has diracx pods.

    Raises:
        RuntimeError: if ``kube.conf`` does not exist.
        subprocess.CalledProcessError: if ``kubectl get pods`` fails.
    """
    kube_conf = demo_dir / "kube.conf"
    if not kube_conf.exists():
        raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")

    # Join with the platform separator. Skip an unset or empty PATH instead
    # of leaving an empty entry, which POSIX treats as the current directory.
    path_entries = [str(demo_dir)]
    inherited_path = os.environ.get("PATH", "")
    if inherited_path:
        path_entries.append(inherited_path)

    env = {
        **os.environ,
        "KUBECONFIG": str(kube_conf),
        "PATH": os.pathsep.join(path_entries),
    }

    # Check that we can run kubectl
    pods_result = subprocess.check_output(
        ["kubectl", "get", "pods"], env=env, text=True
    )
    assert "diracx" in pods_result

    yield env
565
+
566
+
567
@pytest.fixture
def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
    """Set up the environment for the CLI.

    Checks that the demo's OpenAPI endpoint answers over TLS, verified with
    the demo CA. Then sets ``DIRACX_URL``, ``DIRACX_CA_PATH`` and ``HOME``
    (the per-test ``tmp_path``) via ``monkeypatch``, and yields those
    variables as a dict.

    Raises:
        RuntimeError: if the demo CA file is missing.
    """
    import httpx

    from diracx.core.preferences import get_diracx_preferences

    diracx_url = demo_urls["diracx"]
    ca_path = demo_dir / "demo-ca.pem"
    if not ca_path.exists():
        raise RuntimeError(f"Could not find {ca_path}, is the demo running?")

    # Ensure the demo is working
    r = httpx.get(
        f"{diracx_url}/api/openapi.json",
        verify=ssl.create_default_context(cafile=ca_path),
    )
    r.raise_for_status()
    assert r.json()["info"]["title"] == "Dirac"

    env = {
        "DIRACX_URL": diracx_url,
        "DIRACX_CA_PATH": str(ca_path),
        "HOME": str(tmp_path),
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    # The DiracX preferences are cached. Drop anything computed before the
    # variables above were set; otherwise this test could read stale values.
    get_diracx_preferences.cache_clear()
    yield env

    # Clear again so later tests don't reuse preferences built from this env
    get_diracx_preferences.cache_clear()
599
+
600
+
601
@pytest.fixture
async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
    """Log in once via the CLI and point ``DIRACX_CREDENTIALS_PATH`` at the result.

    The credentials returned by ``test_login`` are written to
    ``tmp_path / "credentials.json"``. The requesting test is skipped if the
    login fails for any reason.
    """
    try:
        credentials = await test_login(monkeypatch, capfd, cli_env)
    except Exception as e:
        pytest.skip(f"Login failed, fix test_login to re-enable this test: {e!r}")

    creds_file = tmp_path / "credentials.json"
    creds_file.write_text(credentials)
    monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(creds_file))
    yield
612
+
613
+
614
async def test_login(monkeypatch, capfd, cli_env):
    """Log in to the demo through ``diracx.cli`` and return the credentials text.

    ``asyncio.sleep`` is patched so that, on the CLI's 5th poll:
    - the login URL printed to stdout is captured;
    - the device flow is completed against dex.

    Sleeps are reduced to zero and polling is capped at 100 attempts. Also
    used by the ``with_cli_login`` fixture.
    """
    from diracx import cli

    # Keep the real sleep: asyncio.sleep is monkeypatched below.
    unpatched_sleep = asyncio.sleep
    # The URL must be matched literally, so escape regex metacharacters
    # such as "." (and possibly "?" or "+").
    login_url_re = re.compile(re.escape(cli_env["DIRACX_URL"]) + r"[^\n]+")
    poll_attempts = 0

    def fake_sleep(*args, **kwargs):
        nonlocal poll_attempts

        # Keep track of the number of times this is called
        poll_attempts += 1

        # After polling 5 times, do the actual login
        if poll_attempts == 5:
            # The login URL should have been printed to stdout
            captured = capfd.readouterr()
            match = login_url_re.search(captured.out)
            assert match, captured

            do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"])

        # Ensure we don't poll forever
        assert poll_attempts <= 100

        # Reduce the sleep duration to zero to speed up the test
        return unpatched_sleep(0)

    expected_credentials_path = Path(
        cli_env["HOME"], ".cache", "diracx", "credentials.json"
    )
    # Ensure the credentials file does not exist before logging in
    assert not expected_credentials_path.exists()

    # Run the login command
    with monkeypatch.context() as m:
        m.setattr("asyncio.sleep", fake_sleep)
        await cli.auth.login(vo="diracAdmin", group=None, property=None)
    captured = capfd.readouterr()
    assert "Login successful!" in captured.out
    assert captured.err == ""

    # Ensure the credentials file exists after logging in
    assert expected_credentials_path.exists()

    # Return the credentials so this test can also be used by the
    # "with_cli_login" fixture
    return expected_credentials_path.read_text()
665
+
666
+
667
def do_device_flow_with_dex(url: str, ca_path: str, *, timeout: float = 30) -> None:
    """Do the device flow with dex.

    Follows ``url`` to dex's login page and submits the login form as
    ``admin@example.com`` / ``password``. It then approves the request and
    checks that DiracX reports the login as complete.

    Args:
        url: verification URL printed by the CLI; it redirects to dex.
        ca_path: CA bundle used to verify the demo's TLS certificates.
        timeout: per-request timeout in seconds. ``requests`` has no default
            timeout, so a stuck demo would otherwise hang the test forever.

    Raises:
        requests.HTTPError: if fetching the login page or posting credentials fails.
        AssertionError: if the login form is missing or approval does not complete.
    """
    action_url = None

    class DexLoginFormParser(HTMLParser):
        def handle_starttag(self, tag, attrs):
            nonlocal action_url
            # HTMLParser lower-cases tag names, so an exact match is sufficient
            if tag == "form":
                assert action_url is None, "more than one form on the login page"
                action_url = urljoin(login_page_url, dict(attrs)["action"])

    # Get the login page
    r = requests.get(url, verify=ca_path, timeout=timeout)
    r.raise_for_status()
    login_page_url = r.url  # Differs from `url` because we are redirected to dex
    login_page_body = r.text

    # Search the page for the login form so we know where to post the credentials
    DexLoginFormParser().feed(login_page_body)
    assert action_url is not None, login_page_body

    # Do the actual login
    r = requests.post(
        action_url,
        data={"login": "admin@example.com", "password": "password"},
        verify=ca_path,
        timeout=timeout,
    )
    r.raise_for_status()
    approval_url = r.url  # Differs from `action_url` because dex redirects us

    # Do the actual approval
    approval_req = parse_qs(urlparse(approval_url).query)["req"][0]
    r = requests.post(
        approval_url,
        data={"approval": "approve", "req": approval_req},
        verify=ca_path,
        timeout=timeout,
    )

    # This should have redirected to the DiracX page that shows the login is complete
    assert "Please close the window" in r.text, (r.status_code, r.text)