diracx-testing 0.0.1a23__py3-none-any.whl → 0.0.1a25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,704 @@
+ """Utilities for testing DiracX."""
+
+ from __future__ import annotations
+
+ # TODO: this needs a lot of documentation; in particular, what will matter for users
+ # is the enabled_dependencies marker (see the client_factory fixture below)
+ import asyncio
+ import contextlib
+ import os
+ import re
+ import ssl
+ import subprocess
+ from datetime import datetime, timedelta, timezone
+ from functools import partial
+ from html.parser import HTMLParser
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Generator
+ from urllib.parse import parse_qs, urljoin, urlparse
+ from uuid import uuid4
+
+ import pytest
+ import requests
+
+ if TYPE_CHECKING:
+     from diracx.core.settings import DevelopmentSettings
+     from diracx.routers.jobs.sandboxes import SandboxStoreSettings
+     from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
+
+
+ # Constants used when building test tokens; the signing key itself is
+ # generated by the private_key_pem fixture below
+ ALGORITHM = "HS256"
+ ISSUER = "http://lhcbdirac.cern.ch/"
+ AUDIENCE = "dirac"
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+
+ def pytest_addoption(parser):
+     parser.addoption(
+         "--regenerate-client",
+         action="store_true",
+         default=False,
+         help="Regenerate the AutoREST client",
+     )
+     parser.addoption(
+         "--demo-dir",
+         type=Path,
+         default=None,
+         help="Path to a diracx-charts directory with the demo running",
+     )
+
+
+ def pytest_collection_modifyitems(config, items):
+     """Disable the test_regenerate_client if not explicitly asked for."""
+     if config.getoption("--regenerate-client"):
+         # --regenerate-client given in cli: allow client re-generation
+         return
+     skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run")
+     for item in items:
+         if item.name == "test_regenerate_client":
+             item.add_marker(skip_regen)
+
+
+ @pytest.fixture(scope="session")
+ def private_key_pem() -> str:
+     from cryptography.hazmat.primitives import serialization
+     from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
+
+     private_key = Ed25519PrivateKey.generate()
+     return private_key.private_bytes(
+         encoding=serialization.Encoding.PEM,
+         format=serialization.PrivateFormat.PKCS8,
+         encryption_algorithm=serialization.NoEncryption(),
+     ).decode()
+
+
+ @pytest.fixture(scope="session")
+ def fernet_key() -> str:
+     from cryptography.fernet import Fernet
+
+     return Fernet.generate_key().decode()
+
+
+ @pytest.fixture(scope="session")
+ def test_dev_settings() -> Generator[DevelopmentSettings, None, None]:
+     from diracx.core.settings import DevelopmentSettings
+
+     yield DevelopmentSettings()
+
+
+ @pytest.fixture(scope="session")
+ def test_auth_settings(
+     private_key_pem, fernet_key
+ ) -> Generator[AuthSettings, None, None]:
+     from diracx.routers.utils.users import AuthSettings
+
+     yield AuthSettings(
+         token_algorithm="EdDSA",
+         token_key=private_key_pem,
+         state_key=fernet_key,
+         allowed_redirects=[
+             "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
+         ],
+     )
+
+
+ @pytest.fixture(scope="session")
+ def aio_moto(worker_id):
+     """Start the moto server in a separate thread and yield the client arguments.
+
+     The mocking provided by moto doesn't play nicely with aiobotocore so we use
+     the server directly. See https://github.com/aio-libs/aiobotocore/issues/755
+     """
+     from moto.server import ThreadedMotoServer
+
+     port = 27132
+     if worker_id != "master":
+         port += int(worker_id.replace("gw", "")) + 1
+     server = ThreadedMotoServer(port=port)
+     server.start()
+     yield {
+         "endpoint_url": f"http://localhost:{port}",
+         "aws_access_key_id": "testing",
+         "aws_secret_access_key": "testing",
+     }
+     server.stop()
+
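The dictionary yielded by aio_moto is simply a set of S3 client keyword arguments pointing at the local moto server. As a rough, illustrative sketch (not part of the package), a test could open its own aiobotocore client with it; the helper name list_sandbox_buckets and the explicit region_name are assumptions made for this example:

# Hypothetical helper showing how the aio_moto kwargs can be consumed directly.
import aiobotocore.session

async def list_sandbox_buckets(aio_moto: dict) -> list[str]:
    session = aiobotocore.session.get_session()
    # region_name is added because botocore needs some region configured;
    # the fixture itself leaves this to SandboxStoreSettings.
    async with session.create_client("s3", region_name="us-east-1", **aio_moto) as s3:
        response = await s3.list_buckets()
        return [bucket["Name"] for bucket in response["Buckets"]]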
+
+ @pytest.fixture(scope="session")
+ def test_sandbox_settings(aio_moto) -> Generator[SandboxStoreSettings, None, None]:
+     from diracx.routers.jobs.sandboxes import SandboxStoreSettings
+
+     yield SandboxStoreSettings(
+         bucket_name="sandboxes",
+         s3_client_kwargs=aio_moto,
+         auto_create_bucket=True,
+     )
+
+
+ class UnavailableDependency:
+     def __init__(self, key):
+         self.key = key
+
+     def __call__(self):
+         raise NotImplementedError(
+             f"{self.key} has not been made available to this test!"
+         )
+
+
+ class ClientFactory:
+
+     def __init__(
+         self,
+         tmp_path_factory,
+         with_config_repo,
+         test_auth_settings,
+         test_sandbox_settings,
+         test_dev_settings,
+     ):
+         from diracx.core.config import ConfigSource
+         from diracx.core.extensions import select_from_extension
+         from diracx.core.settings import ServiceSettingsBase
+         from diracx.db.os.utils import BaseOSDB
+         from diracx.db.sql.utils import BaseSQLDB
+         from diracx.routers import create_app_inner
+         from diracx.routers.access_policies import BaseAccessPolicy
+
+         from .mock_osdb import fake_available_osdb_implementations
+
+         class AlwaysAllowAccessPolicy(BaseAccessPolicy):
+             """Dummy access policy."""
+
+             @staticmethod
+             async def policy(
+                 policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
+             ):
+                 pass
+
+             @staticmethod
+             def enrich_tokens(access_payload: dict, refresh_payload: dict):
+
+                 return {"PolicySpecific": "OpenAccessForTest"}, {}
+
+         enabled_systems = {
+             e.name for e in select_from_extension(group="diracx.services")
+         }
+         database_urls = {
+             e.name: "sqlite+aiosqlite:///:memory:"
+             for e in select_from_extension(group="diracx.db.sql")
+         }
+         # TODO: Monkeypatch this in a less stupid way
+         # TODO: Only use this if opensearch isn't available
+         os_database_conn_kwargs = {
+             e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
+             for e in select_from_extension(group="diracx.db.os")
+         }
+         BaseOSDB.available_implementations = partial(
+             fake_available_osdb_implementations,
+             real_available_implementations=BaseOSDB.available_implementations,
+         )
+
+         self._cache_dir = tmp_path_factory.mktemp("empty-dbs")
+
+         self.test_auth_settings = test_auth_settings
+         self.test_dev_settings = test_dev_settings
+
+         all_access_policies = {
+             e.name: [AlwaysAllowAccessPolicy]
+             + BaseAccessPolicy.available_implementations(e.name)
+             for e in select_from_extension(group="diracx.access_policies")
+         }
+
+         self.app = create_app_inner(
+             enabled_systems=enabled_systems,
+             all_service_settings=[
+                 test_auth_settings,
+                 test_sandbox_settings,
+                 test_dev_settings,
+             ],
+             database_urls=database_urls,
+             os_database_conn_kwargs=os_database_conn_kwargs,
+             config_source=ConfigSource.create_from_url(
+                 backend_url=f"git+file://{with_config_repo}"
+             ),
+             all_access_policies=all_access_policies,
+         )
+
+         self.all_dependency_overrides = self.app.dependency_overrides.copy()
+         self.app.dependency_overrides = {}
+         for obj in self.all_dependency_overrides:
+             assert issubclass(
+                 obj.__self__,
+                 (
+                     ServiceSettingsBase,
+                     BaseSQLDB,
+                     BaseOSDB,
+                     ConfigSource,
+                     BaseAccessPolicy,
+                 ),
+             ), obj
+
+         self.all_lifetime_functions = self.app.lifetime_functions[:]
+         self.app.lifetime_functions = []
+         for obj in self.all_lifetime_functions:
+             assert isinstance(
+                 obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
+             ), obj
+
+     @contextlib.contextmanager
+     def configure(self, enabled_dependencies):
+
+         assert (
+             self.app.dependency_overrides == {} and self.app.lifetime_functions == []
+         ), "configure cannot be nested"
+
+         for k, v in self.all_dependency_overrides.items():
+
+             class_name = k.__self__.__name__
+
+             if class_name in enabled_dependencies:
+                 self.app.dependency_overrides[k] = v
+             else:
+                 self.app.dependency_overrides[k] = UnavailableDependency(class_name)
+
+         for obj in self.all_lifetime_functions:
+             # TODO: We should use the name of the entry point instead of the class name
+             if obj.__self__.__class__.__name__ in enabled_dependencies:
+                 self.app.lifetime_functions.append(obj)
+
+         # Add create_db_schemas to the end of the lifetime_functions so that the
+         # other lifetime_functions (i.e. those which run db.engine_context) have
+         # already been run
+         self.app.lifetime_functions.append(self.create_db_schemas)
+
+         try:
+             yield
+         finally:
+             self.app.dependency_overrides = {}
+             self.app.lifetime_functions = []
+
+     @contextlib.asynccontextmanager
+     async def create_db_schemas(self):
+         """Create DB schemas based on the DBs available in app.dependency_overrides."""
+         import aiosqlite
+         import sqlalchemy
+         from sqlalchemy.util.concurrency import greenlet_spawn
+
+         from diracx.db.os.utils import BaseOSDB
+         from diracx.db.sql.utils import BaseSQLDB
+         from diracx.testing.mock_osdb import MockOSDBMixin
+
+         for k, v in self.app.dependency_overrides.items():
+             # Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
+             if isinstance(v, UnavailableDependency) or k.__func__ not in (
+                 BaseSQLDB.transaction.__func__,
+                 BaseOSDB.session.__func__,
+             ):
+
+                 continue
+
+             # The first argument of the overridden BaseSQLDB.transaction is the DB object
+             db = v.args[0]
+             # We expect the OS DB to be mocked with sqlite, so use the
+             # internal DB
+             if isinstance(db, MockOSDBMixin):
+                 db = db._sql_db
+
+             assert isinstance(db, BaseSQLDB), (k, db)
+
+             # set PRAGMA foreign_keys=ON if sqlite
+             if db.engine.url.drivername.startswith("sqlite"):
+
+                 def set_sqlite_pragma(dbapi_connection, connection_record):
+                     cursor = dbapi_connection.cursor()
+                     cursor.execute("PRAGMA foreign_keys=ON")
+                     cursor.close()
+
+                 sqlalchemy.event.listen(
+                     db.engine.sync_engine, "connect", set_sqlite_pragma
+                 )
+
+             # We maintain a cache of the populated DBs in self._cache_dir (the
+             # "empty-dbs" temporary directory) so that we don't have to recreate
+             # them for every test. This speeds up the tests by a considerable amount.
+             ref_db = self._cache_dir / f"{k.__self__.__name__}.db"
+             if ref_db.exists():
+                 async with aiosqlite.connect(ref_db) as ref_conn:
+                     conn = await db.engine.raw_connection()
+                     await ref_conn.backup(conn.driver_connection)
+                     await greenlet_spawn(conn.close)
+             else:
+                 async with db.engine.begin() as conn:
+                     await conn.run_sync(db.metadata.create_all)
+
+                 async with aiosqlite.connect(ref_db) as ref_conn:
+                     conn = await db.engine.raw_connection()
+                     await conn.driver_connection.backup(ref_conn)
+                     await greenlet_spawn(conn.close)
+
+         yield
+
+     @contextlib.contextmanager
+     def unauthenticated(self):
+         from fastapi.testclient import TestClient
+
+         with TestClient(self.app) as client:
+             yield client
+
+     @contextlib.contextmanager
+     def normal_user(self):
+         from diracx.core.properties import NORMAL_USER
+         from diracx.routers.auth.token import create_token
+
+         with self.unauthenticated() as client:
+             payload = {
+                 "sub": "testingVO:yellow-sub",
+                 "exp": datetime.now(tz=timezone.utc)
+                 + timedelta(
+                     minutes=self.test_auth_settings.access_token_expire_minutes
+                 ),
+                 "iss": ISSUER,
+                 "dirac_properties": [NORMAL_USER],
+                 "jti": str(uuid4()),
+                 "preferred_username": "preferred_username",
+                 "dirac_group": "test_group",
+                 "vo": "lhcb",
+             }
+             token = create_token(payload, self.test_auth_settings)
+
+             client.headers["Authorization"] = f"Bearer {token}"
+             client.dirac_token_payload = payload
+             yield client
+
+     @contextlib.contextmanager
+     def admin_user(self):
+         from diracx.core.properties import JOB_ADMINISTRATOR
+         from diracx.routers.auth.token import create_token
+
+         with self.unauthenticated() as client:
+             payload = {
+                 "sub": "testingVO:yellow-sub",
+                 "iss": ISSUER,
+                 "dirac_properties": [JOB_ADMINISTRATOR],
+                 "jti": str(uuid4()),
+                 "preferred_username": "preferred_username",
+                 "dirac_group": "test_group",
+                 "vo": "lhcb",
+             }
+             token = create_token(payload, self.test_auth_settings)
+             client.headers["Authorization"] = f"Bearer {token}"
+             client.dirac_token_payload = payload
+             yield client
+
+
+ @pytest.fixture(scope="session")
+ def session_client_factory(
+     test_auth_settings,
+     test_sandbox_settings,
+     with_config_repo,
+     tmp_path_factory,
+     test_dev_settings,
+ ):
+     """Create a ClientFactory that is shared by the whole test session.
+
+     It is built from the session-scoped settings and configuration fixtures;
+     use the client_factory fixture below to get a configured instance.
+     """
+     yield ClientFactory(
+         tmp_path_factory,
+         with_config_repo,
+         test_auth_settings,
+         test_sandbox_settings,
+         test_dev_settings,
+     )
+
+
+ @pytest.fixture
+ def client_factory(session_client_factory, request):
+     marker = request.node.get_closest_marker("enabled_dependencies")
+     if marker is None:
+         raise RuntimeError("This test requires the enabled_dependencies marker")
+     (enabled_dependencies,) = marker.args
+     with session_client_factory.configure(enabled_dependencies=enabled_dependencies):
+         yield session_client_factory
+
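For reference, a consumer of client_factory looks roughly like the sketch below. It is illustrative only: the dependency class names and the route are assumptions and must match whatever the test actually exercises.

# Hypothetical test module using the fixtures above.
import pytest

@pytest.mark.enabled_dependencies(["ConfigSource", "AuthSettings", "AuthDB"])
def test_something(client_factory):
    # Dependencies not listed in the marker resolve to UnavailableDependency
    # and raise if the endpoint under test tries to use them.
    with client_factory.normal_user() as client:
        response = client.get("/api/auth/userinfo")  # example route
        assert response.status_code == 200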
+
+ @pytest.fixture(scope="session")
+ def with_config_repo(tmp_path_factory):
+     from git import Repo
+
+     from diracx.core.config import Config
+
+     tmp_path = tmp_path_factory.mktemp("cs-repo")
+
+     repo = Repo.init(tmp_path, initial_branch="master")
+     cs_file = tmp_path / "default.yml"
+     example_cs = Config.model_validate(
+         {
+             "DIRAC": {},
+             "Registry": {
+                 "lhcb": {
+                     "DefaultGroup": "lhcb_user",
+                     "DefaultProxyLifeTime": 432000,
+                     "DefaultStorageQuota": 2000,
+                     "IdP": {
+                         "URL": "https://idp-server.invalid",
+                         "ClientID": "test-idp",
+                     },
+                     "Users": {
+                         "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041": {
+                             "PreferedUsername": "chaen",
+                             "Email": None,
+                         },
+                         "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152": {
+                             "PreferedUsername": "albdr",
+                             "Email": None,
+                         },
+                     },
+                     "Groups": {
+                         "lhcb_user": {
+                             "Properties": ["NormalUser", "PrivateLimitedDelegation"],
+                             "Users": [
+                                 "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041",
+                                 "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152",
+                             ],
+                         },
+                         "lhcb_prmgr": {
+                             "Properties": ["NormalUser", "ProductionManagement"],
+                             "Users": ["b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"],
+                         },
+                         "lhcb_tokenmgr": {
+                             "Properties": ["NormalUser", "ProxyManagement"],
+                             "Users": ["c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"],
+                         },
+                     },
+                 }
+             },
+             "Operations": {"Defaults": {}},
+             "Systems": {
+                 "WorkloadManagement": {
+                     "Production": {
+                         "Databases": {
+                             "JobDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                                 "MaxRescheduling": 3,
+                             },
+                             "JobLoggingDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                             "PilotAgentsDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                             "SandboxMetadataDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                             "TaskQueueDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                             "ElasticJobParametersDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                             "VirtualMachineDB": {
+                                 "DBName": "xyz",
+                                 "Host": "xyz",
+                                 "Port": 9999,
+                             },
+                         },
+                     },
+                 },
+             },
+         }
+     )
+     cs_file.write_text(example_cs.model_dump_json())
+     repo.index.add([cs_file])  # add it to the index
+     repo.index.commit("Added a new file")
+     yield tmp_path
+
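Because with_config_repo stores the configuration as model_dump_json() in default.yml, a test can read the committed file back without going through the git backend. A minimal sketch, relying only on APIs already used in this module plus PyYAML (read_back_config is a hypothetical helper):

# Hypothetical helper for inspecting the repository created by with_config_repo.
import yaml
from diracx.core.config import Config

def read_back_config(repo_path) -> Config:
    # default.yml is written as JSON, which yaml.safe_load also parses.
    raw = yaml.safe_load((repo_path / "default.yml").read_text())
    return Config.model_validate(raw)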
+
+ @pytest.fixture(scope="session")
+ def demo_dir(request) -> Path:
+     demo_dir = request.config.getoption("--demo-dir")
+     if demo_dir is None:
+         pytest.skip("Requires a running instance of the DiracX demo")
+     demo_dir = (demo_dir / ".demo").resolve()
+     yield demo_dir
+
+
+ @pytest.fixture(scope="session")
+ def demo_urls(demo_dir):
+     import yaml
+
+     helm_values = yaml.safe_load((demo_dir / "values.yaml").read_text())
+     yield helm_values["developer"]["urls"]
+
+
+ @pytest.fixture(scope="session")
+ def demo_kubectl_env(demo_dir):
+     """Get the dictionary of environment variables for kubectl to control the demo."""
+     kube_conf = demo_dir / "kube.conf"
+     if not kube_conf.exists():
+         raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")
+
+     env = {
+         **os.environ,
+         "KUBECONFIG": str(kube_conf),
+         "PATH": f"{demo_dir}:{os.environ['PATH']}",
+     }
+
+     # Check that we can run kubectl
+     pods_result = subprocess.check_output(
+         ["kubectl", "get", "pods"], env=env, text=True
+     )
+     assert "diracx" in pods_result
+
+     yield env
+
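A test that needs to poke at the demo cluster can pass this environment straight to subprocess, in the same way the fixture's own sanity check does. A rough sketch (the pod-name filter and the use of "kubectl logs" are assumptions about the demo deployment, not part of this package):

# Hypothetical test driving kubectl against the demo cluster.
import subprocess

def test_diracx_pod_has_logs(demo_kubectl_env):
    pods = subprocess.check_output(
        ["kubectl", "get", "pods", "-o", "name"], env=demo_kubectl_env, text=True
    )
    diracx_pods = [p for p in pods.splitlines() if "diracx" in p]
    assert diracx_pods
    logs = subprocess.check_output(
        ["kubectl", "logs", diracx_pods[0]], env=demo_kubectl_env, text=True
    )
    assert logs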
+
+ @pytest.fixture
+ def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
+     """Set up the environment for the CLI."""
+     import httpx
+
+     from diracx.core.preferences import get_diracx_preferences
+
+     diracx_url = demo_urls["diracx"]
+     ca_path = demo_dir / "demo-ca.pem"
+     if not ca_path.exists():
+         raise RuntimeError(f"Could not find {ca_path}, is the demo running?")
+
+     # Ensure the demo is working
+
+     r = httpx.get(
+         f"{diracx_url}/api/openapi.json",
+         verify=ssl.create_default_context(cafile=ca_path),
+     )
+     r.raise_for_status()
+     assert r.json()["info"]["title"] == "Dirac"
+
+     env = {
+         "DIRACX_URL": diracx_url,
+         "DIRACX_CA_PATH": str(ca_path),
+         "HOME": str(tmp_path),
+     }
+     for key, value in env.items():
+         monkeypatch.setenv(key, value)
+     yield env
+
+     # The DiracX preferences are cached; when testing, that cache would leak
+     # stale values between tests, so clear it here.
+     get_diracx_preferences.cache_clear()
+
+
+ @pytest.fixture
+ async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path):
+     try:
+         credentials = await test_login(monkeypatch, capfd, cli_env)
+     except Exception as e:
+         pytest.skip(f"Login failed, fix test_login to re-enable this test: {e!r}")
+
+     credentials_path = tmp_path / "credentials.json"
+     credentials_path.write_text(credentials)
+     monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_path))
+     yield
+
+
+ async def test_login(monkeypatch, capfd, cli_env):
+     from diracx import cli
+
+     poll_attempts = 0
+
+     def fake_sleep(*args, **kwargs):
+         nonlocal poll_attempts
+
+         # Keep track of the number of times this is called
+         poll_attempts += 1
+
+         # After polling 5 times, do the actual login
+         if poll_attempts == 5:
+             # The login URL should have been printed to stdout
+             captured = capfd.readouterr()
+             match = re.search(rf"{cli_env['DIRACX_URL']}[^\n]+", captured.out)
+             assert match, captured
+
+             do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"])
+
+         # Ensure we don't poll forever
+         assert poll_attempts <= 100
+
+         # Reduce the sleep duration to zero to speed up the test
+         return unpatched_sleep(0)
+
+     # We monkeypatch asyncio.sleep to provide a hook to run the actions that
+     # would normally be done by a user. This includes capturing the login URL
+     # and doing the actual device flow with dex.
+     unpatched_sleep = asyncio.sleep
+
+     expected_credentials_path = Path(
+         cli_env["HOME"], ".cache", "diracx", "credentials.json"
+     )
+     # Ensure the credentials file does not exist before logging in
+     assert not expected_credentials_path.exists()
+
+     # Run the login command
+     with monkeypatch.context() as m:
+         m.setattr("asyncio.sleep", fake_sleep)
+         await cli.auth.login(vo="diracAdmin", group=None, property=None)
+     captured = capfd.readouterr()
+     assert "Login successful!" in captured.out
+     assert captured.err == ""
+
+     # Ensure the credentials file exists after logging in
+     assert expected_credentials_path.exists()
+
+     # Return the credentials so this test can also be used by the
+     # "with_cli_login" fixture
+     return expected_credentials_path.read_text()
+
+
+ def do_device_flow_with_dex(url: str, ca_path: str) -> None:
+     """Do the device flow with dex."""
+
+     class DexLoginFormParser(HTMLParser):
+         def handle_starttag(self, tag, attrs):
+             nonlocal action_url
+             if "form" in str(tag):
+                 assert action_url is None
+                 action_url = urljoin(login_page_url, dict(attrs)["action"])
+
+     # Get the login page
+     r = requests.get(url, verify=ca_path)
+     r.raise_for_status()
+     login_page_url = r.url  # Not the same as `url` because we were redirected to dex
+     login_page_body = r.text
+
+     # Search the page for the login form so we know where to post the credentials
+     action_url = None
+     DexLoginFormParser().feed(login_page_body)
+     assert action_url is not None, login_page_body
+
+     # Do the actual login
+     r = requests.post(
+         action_url,
+         data={"login": "admin@example.com", "password": "password"},
+         verify=ca_path,
+     )
+     r.raise_for_status()
+     approval_url = r.url  # Again not the original URL because of the dex redirects
+     # Do the actual approval
+     r = requests.post(
+         approval_url,
+         {"approval": "approve", "req": parse_qs(urlparse(r.url).query)["req"][0]},
+         verify=ca_path,
+     )
+
+     # This should have redirected to the DiracX page that shows the login is complete
+     assert "Please close the window" in r.text