diracx-testing 0.0.1a17__py3-none-any.whl → 0.0.1a18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/testing/__init__.py +63 -28
- diracx/testing/dummy_osdb.py +30 -0
- diracx/testing/mock_osdb.py +163 -0
- diracx/testing/osdb.py +18 -27
- diracx/testing/routers.py +36 -0
- {diracx_testing-0.0.1a17.dist-info → diracx_testing-0.0.1a18.dist-info}/METADATA +3 -3
- diracx_testing-0.0.1a18.dist-info/RECORD +9 -0
- {diracx_testing-0.0.1a17.dist-info → diracx_testing-0.0.1a18.dist-info}/WHEEL +1 -1
- diracx_testing-0.0.1a17.dist-info/RECORD +0 -6
- {diracx_testing-0.0.1a17.dist-info → diracx_testing-0.0.1a18.dist-info}/top_level.txt +0 -0
    
        diracx/testing/__init__.py
    CHANGED
    
    | @@ -8,6 +8,7 @@ import os | |
| 8 8 | 
             
            import re
         | 
| 9 9 | 
             
            import subprocess
         | 
| 10 10 | 
             
            from datetime import datetime, timedelta, timezone
         | 
| 11 | 
            +
            from functools import partial
         | 
| 11 12 | 
             
            from html.parser import HTMLParser
         | 
| 12 13 | 
             
            from pathlib import Path
         | 
| 13 14 | 
             
            from typing import TYPE_CHECKING
         | 
| @@ -18,6 +19,7 @@ import pytest | |
| 18 19 | 
             
            import requests
         | 
| 19 20 |  | 
| 20 21 | 
             
            if TYPE_CHECKING:
         | 
| 22 | 
            +
                from diracx.core.settings import DevelopmentSettings
         | 
| 21 23 | 
             
                from diracx.routers.job_manager.sandboxes import SandboxStoreSettings
         | 
| 22 24 | 
             
                from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings
         | 
| 23 25 |  | 
| @@ -46,9 +48,7 @@ def pytest_addoption(parser): | |
| 46 48 |  | 
| 47 49 |  | 
| 48 50 | 
             
            def pytest_collection_modifyitems(config, items):
         | 
| 49 | 
            -
                """
         | 
| 50 | 
            -
                Disable the test_regenerate_client if not explicitly asked for
         | 
| 51 | 
            -
                """
         | 
| 51 | 
            +
                """Disable the test_regenerate_client if not explicitly asked for."""
         | 
| 52 52 | 
             
                if config.getoption("--regenerate-client"):
         | 
| 53 53 | 
             
                    # --regenerate-client given in cli: allow client re-generation
         | 
| 54 54 | 
             
                    return
         | 
| @@ -59,11 +59,11 @@ def pytest_collection_modifyitems(config, items): | |
| 59 59 |  | 
| 60 60 |  | 
| 61 61 | 
             
            @pytest.fixture(scope="session")
         | 
| 62 | 
            -
            def  | 
| 62 | 
            +
            def private_key_pem() -> str:
         | 
| 63 63 | 
             
                from cryptography.hazmat.primitives import serialization
         | 
| 64 | 
            -
                from cryptography.hazmat.primitives.asymmetric import  | 
| 64 | 
            +
                from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
         | 
| 65 65 |  | 
| 66 | 
            -
                private_key =  | 
| 66 | 
            +
                private_key = Ed25519PrivateKey.generate()
         | 
| 67 67 | 
             
                return private_key.private_bytes(
         | 
| 68 68 | 
             
                    encoding=serialization.Encoding.PEM,
         | 
| 69 69 | 
             
                    format=serialization.PrivateFormat.PKCS8,
         | 
| @@ -79,11 +79,19 @@ def fernet_key() -> str: | |
| 79 79 |  | 
| 80 80 |  | 
| 81 81 | 
             
            @pytest.fixture(scope="session")
         | 
| 82 | 
            -
            def  | 
| 82 | 
            +
            def test_dev_settings() -> DevelopmentSettings:
         | 
| 83 | 
            +
                from diracx.core.settings import DevelopmentSettings
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                yield DevelopmentSettings()
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            @pytest.fixture(scope="session")
         | 
| 89 | 
            +
            def test_auth_settings(private_key_pem, fernet_key) -> AuthSettings:
         | 
| 83 90 | 
             
                from diracx.routers.utils.users import AuthSettings
         | 
| 84 91 |  | 
| 85 92 | 
             
                yield AuthSettings(
         | 
| 86 | 
            -
                     | 
| 93 | 
            +
                    token_algorithm="EdDSA",
         | 
| 94 | 
            +
                    token_key=private_key_pem,
         | 
| 87 95 | 
             
                    state_key=fernet_key,
         | 
| 88 96 | 
             
                    allowed_redirects=[
         | 
| 89 97 | 
             
                        "http://diracx.test.invalid:8000/api/docs/oauth2-redirect",
         | 
| @@ -93,7 +101,7 @@ def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings: | |
| 93 101 |  | 
| 94 102 | 
             
            @pytest.fixture(scope="session")
         | 
| 95 103 | 
             
            def aio_moto(worker_id):
         | 
| 96 | 
            -
                """Start the moto server in a separate thread and return the base URL
         | 
| 104 | 
            +
                """Start the moto server in a separate thread and return the base URL.
         | 
| 97 105 |  | 
| 98 106 | 
             
                The mocking provided by moto doesn't play nicely with aiobotocore so we use
         | 
| 99 107 | 
             
                the server directly. See https://github.com/aio-libs/aiobotocore/issues/755
         | 
| @@ -142,18 +150,20 @@ class ClientFactory: | |
| 142 150 | 
             
                    with_config_repo,
         | 
| 143 151 | 
             
                    test_auth_settings,
         | 
| 144 152 | 
             
                    test_sandbox_settings,
         | 
| 153 | 
            +
                    test_dev_settings,
         | 
| 145 154 | 
             
                ):
         | 
| 146 155 | 
             
                    from diracx.core.config import ConfigSource
         | 
| 147 156 | 
             
                    from diracx.core.extensions import select_from_extension
         | 
| 148 157 | 
             
                    from diracx.core.settings import ServiceSettingsBase
         | 
| 158 | 
            +
                    from diracx.db.os.utils import BaseOSDB
         | 
| 149 159 | 
             
                    from diracx.db.sql.utils import BaseSQLDB
         | 
| 150 160 | 
             
                    from diracx.routers import create_app_inner
         | 
| 151 161 | 
             
                    from diracx.routers.access_policies import BaseAccessPolicy
         | 
| 152 162 |  | 
| 163 | 
            +
                    from .mock_osdb import fake_available_osdb_implementations
         | 
| 164 | 
            +
             | 
| 153 165 | 
             
                    class AlwaysAllowAccessPolicy(BaseAccessPolicy):
         | 
| 154 | 
            -
                        """
         | 
| 155 | 
            -
                        Dummy access policy
         | 
| 156 | 
            -
                        """
         | 
| 166 | 
            +
                        """Dummy access policy."""
         | 
| 157 167 |  | 
| 158 168 | 
             
                        async def policy(
         | 
| 159 169 | 
             
                            policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
         | 
| @@ -171,9 +181,21 @@ class ClientFactory: | |
| 171 181 | 
             
                        e.name: "sqlite+aiosqlite:///:memory:"
         | 
| 172 182 | 
             
                        for e in select_from_extension(group="diracx.db.sql")
         | 
| 173 183 | 
             
                    }
         | 
| 184 | 
            +
                    # TODO: Monkeypatch this in a less stupid way
         | 
| 185 | 
            +
                    # TODO: Only use this if opensearch isn't available
         | 
| 186 | 
            +
                    os_database_conn_kwargs = {
         | 
| 187 | 
            +
                        e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
         | 
| 188 | 
            +
                        for e in select_from_extension(group="diracx.db.os")
         | 
| 189 | 
            +
                    }
         | 
| 190 | 
            +
                    BaseOSDB.available_implementations = partial(
         | 
| 191 | 
            +
                        fake_available_osdb_implementations,
         | 
| 192 | 
            +
                        real_available_implementations=BaseOSDB.available_implementations,
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
             | 
| 174 195 | 
             
                    self._cache_dir = tmp_path_factory.mktemp("empty-dbs")
         | 
| 175 196 |  | 
| 176 197 | 
             
                    self.test_auth_settings = test_auth_settings
         | 
| 198 | 
            +
                    self.test_dev_settings = test_dev_settings
         | 
| 177 199 |  | 
| 178 200 | 
             
                    all_access_policies = {
         | 
| 179 201 | 
             
                        e.name: [AlwaysAllowAccessPolicy]
         | 
| @@ -186,11 +208,10 @@ class ClientFactory: | |
| 186 208 | 
             
                        all_service_settings=[
         | 
| 187 209 | 
             
                            test_auth_settings,
         | 
| 188 210 | 
             
                            test_sandbox_settings,
         | 
| 211 | 
            +
                            test_dev_settings,
         | 
| 189 212 | 
             
                        ],
         | 
| 190 213 | 
             
                        database_urls=database_urls,
         | 
| 191 | 
            -
                        os_database_conn_kwargs= | 
| 192 | 
            -
                            # TODO: JobParametersDB
         | 
| 193 | 
            -
                        },
         | 
| 214 | 
            +
                        os_database_conn_kwargs=os_database_conn_kwargs,
         | 
| 194 215 | 
             
                        config_source=ConfigSource.create_from_url(
         | 
| 195 216 | 
             
                            backend_url=f"git+file://{with_config_repo}"
         | 
| 196 217 | 
             
                        ),
         | 
| @@ -202,14 +223,20 @@ class ClientFactory: | |
| 202 223 | 
             
                    for obj in self.all_dependency_overrides:
         | 
| 203 224 | 
             
                        assert issubclass(
         | 
| 204 225 | 
             
                            obj.__self__,
         | 
| 205 | 
            -
                            ( | 
| 226 | 
            +
                            (
         | 
| 227 | 
            +
                                ServiceSettingsBase,
         | 
| 228 | 
            +
                                BaseSQLDB,
         | 
| 229 | 
            +
                                BaseOSDB,
         | 
| 230 | 
            +
                                ConfigSource,
         | 
| 231 | 
            +
                                BaseAccessPolicy,
         | 
| 232 | 
            +
                            ),
         | 
| 206 233 | 
             
                        ), obj
         | 
| 207 234 |  | 
| 208 235 | 
             
                    self.all_lifetime_functions = self.app.lifetime_functions[:]
         | 
| 209 236 | 
             
                    self.app.lifetime_functions = []
         | 
| 210 237 | 
             
                    for obj in self.all_lifetime_functions:
         | 
| 211 238 | 
             
                        assert isinstance(
         | 
| 212 | 
            -
                            obj.__self__, (ServiceSettingsBase, BaseSQLDB, ConfigSource)
         | 
| 239 | 
            +
                            obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource)
         | 
| 213 240 | 
             
                        ), obj
         | 
| 214 241 |  | 
| 215 242 | 
             
                @contextlib.contextmanager
         | 
| @@ -227,6 +254,7 @@ class ClientFactory: | |
| 227 254 | 
             
                            self.app.dependency_overrides[k] = UnavailableDependency(class_name)
         | 
| 228 255 |  | 
| 229 256 | 
             
                    for obj in self.all_lifetime_functions:
         | 
| 257 | 
            +
                        # TODO: We should use the name of the entry point instead of the class name
         | 
| 230 258 | 
             
                        if obj.__self__.__class__.__name__ in enabled_dependencies:
         | 
| 231 259 | 
             
                            self.app.lifetime_functions.append(obj)
         | 
| 232 260 |  | 
| @@ -235,14 +263,15 @@ class ClientFactory: | |
| 235 263 | 
             
                    # already been ran
         | 
| 236 264 | 
             
                    self.app.lifetime_functions.append(self.create_db_schemas)
         | 
| 237 265 |  | 
| 238 | 
            -
                     | 
| 239 | 
            -
             | 
| 240 | 
            -
                     | 
| 241 | 
            -
             | 
| 266 | 
            +
                    try:
         | 
| 267 | 
            +
                        yield
         | 
| 268 | 
            +
                    finally:
         | 
| 269 | 
            +
                        self.app.dependency_overrides = {}
         | 
| 270 | 
            +
                        self.app.lifetime_functions = []
         | 
| 242 271 |  | 
| 243 272 | 
             
                @contextlib.asynccontextmanager
         | 
| 244 273 | 
             
                async def create_db_schemas(self):
         | 
| 245 | 
            -
                    """Create DB schema's based on the DBs available in app.dependency_overrides"""
         | 
| 274 | 
            +
                    """Create DB schema's based on the DBs available in app.dependency_overrides."""
         | 
| 246 275 | 
             
                    import aiosqlite
         | 
| 247 276 | 
             
                    import sqlalchemy
         | 
| 248 277 | 
             
                    from sqlalchemy.util.concurrency import greenlet_spawn
         | 
| @@ -349,12 +378,18 @@ def session_client_factory( | |
| 349 378 | 
             
                test_sandbox_settings,
         | 
| 350 379 | 
             
                with_config_repo,
         | 
| 351 380 | 
             
                tmp_path_factory,
         | 
| 381 | 
            +
                test_dev_settings,
         | 
| 352 382 | 
             
            ):
         | 
| 353 | 
            -
                """
         | 
| 354 | 
            -
                 | 
| 383 | 
            +
                """TODO.
         | 
| 384 | 
            +
                ----
         | 
| 385 | 
            +
             | 
| 355 386 | 
             
                """
         | 
| 356 387 | 
             
                yield ClientFactory(
         | 
| 357 | 
            -
                    tmp_path_factory, | 
| 388 | 
            +
                    tmp_path_factory,
         | 
| 389 | 
            +
                    with_config_repo,
         | 
| 390 | 
            +
                    test_auth_settings,
         | 
| 391 | 
            +
                    test_sandbox_settings,
         | 
| 392 | 
            +
                    test_dev_settings,
         | 
| 358 393 | 
             
                )
         | 
| 359 394 |  | 
| 360 395 |  | 
| @@ -443,7 +478,7 @@ def demo_urls(demo_dir): | |
| 443 478 |  | 
| 444 479 | 
             
            @pytest.fixture(scope="session")
         | 
| 445 480 | 
             
            def demo_kubectl_env(demo_dir):
         | 
| 446 | 
            -
                """Get the dictionary of environment variables for kubectl to control the demo"""
         | 
| 481 | 
            +
                """Get the dictionary of environment variables for kubectl to control the demo."""
         | 
| 447 482 | 
             
                kube_conf = demo_dir / "kube.conf"
         | 
| 448 483 | 
             
                if not kube_conf.exists():
         | 
| 449 484 | 
             
                    raise RuntimeError(f"Could not find {kube_conf}, is the demo running?")
         | 
| @@ -465,7 +500,7 @@ def demo_kubectl_env(demo_dir): | |
| 465 500 |  | 
| 466 501 | 
             
            @pytest.fixture
         | 
| 467 502 | 
             
            def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir):
         | 
| 468 | 
            -
                """Set up the environment for the CLI"""
         | 
| 503 | 
            +
                """Set up the environment for the CLI."""
         | 
| 469 504 | 
             
                import httpx
         | 
| 470 505 |  | 
| 471 506 | 
             
                from diracx.core.preferences import get_diracx_preferences
         | 
| @@ -561,7 +596,7 @@ async def test_login(monkeypatch, capfd, cli_env): | |
| 561 596 |  | 
| 562 597 |  | 
| 563 598 | 
             
            def do_device_flow_with_dex(url: str, ca_path: str) -> None:
         | 
| 564 | 
            -
                """Do the device flow with dex"""
         | 
| 599 | 
            +
                """Do the device flow with dex."""
         | 
| 565 600 |  | 
| 566 601 | 
             
                class DexLoginFormParser(HTMLParser):
         | 
| 567 602 | 
             
                    def handle_starttag(self, tag, attrs):
         | 
| @@ -0,0 +1,30 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import secrets
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from diracx.db.os.utils import BaseOSDB
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class DummyOSDB(BaseOSDB):
                """Example DiracX OpenSearch database class for testing.

                A new random index prefix is generated each time the class is
                instantiated, so that indices created by independent test runs (or by
                independent instances) do not collide.
                """

                # Field mappings: field name -> OpenSearch mapping properties.
                fields = {
                    "DateField": {"type": "date"},
                    "IntField": {"type": "long"},
                    "KeywordField0": {"type": "keyword"},
                    "KeywordField1": {"type": "keyword"},
                    "KeywordField2": {"type": "keyword"},
                    "TextField": {"type": "text"},
                }

                def __init__(self, *args, **kwargs):
                    # Randomize the index prefix to ensure tests are independent.
                    # It is set before super().__init__ in case the base initializer
                    # reads it (NOTE(review): confirm against BaseOSDB.__init__).
                    self.index_prefix = f"dummy_{secrets.token_hex(8)}"
                    super().__init__(*args, **kwargs)

                def index_name(self, doc_id: int) -> str:
                    """Return the index holding ``doc_id``: one index per million IDs.

                    e.g. ``doc_id=1_234_567`` maps to ``"<index_prefix>-1m"``.
                    """
                    return f"{self.index_prefix}-{doc_id // 1e6:.0f}m"
         | 
| @@ -0,0 +1,163 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            __all__ = (
         | 
| 4 | 
            +
                "MockOSDBMixin",
         | 
| 5 | 
            +
                "fake_available_osdb_implementations",
         | 
| 6 | 
            +
            )
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import contextlib
         | 
| 9 | 
            +
            from datetime import datetime, timezone
         | 
| 10 | 
            +
            from functools import partial
         | 
| 11 | 
            +
            from typing import Any, AsyncIterator
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from sqlalchemy import select
         | 
| 14 | 
            +
            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from diracx.core.models import SearchSpec, SortSpec
         | 
| 17 | 
            +
            from diracx.db.sql import utils as sql_utils
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class MockOSDBMixin:
         | 
| 21 | 
            +
                """A subclass of DummyOSDB that hacks it to use sqlite as a backed.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                This is only intended for testing and development purposes to avoid the
         | 
| 24 | 
            +
                need to run a full OpenSearch instance. This class is used by defining a
         | 
| 25 | 
            +
                new class that inherits from this mixin as well the real DB class, i.e.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                .. code-block:: python
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    class JobParametersDB(MockOSDBMixin, JobParametersDB):
         | 
| 30 | 
            +
                        pass
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                or
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                .. code-block:: python
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    JobParametersDB = type("JobParametersDB", (MockOSDBMixin, JobParametersDB), {})
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         | 
| 40 | 
            +
                    from sqlalchemy import JSON, Column, Integer, MetaData, String, Table
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    from diracx.db.sql.utils import DateNowColumn
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # Dynamically create a subclass of BaseSQLDB so we get clearer errors
         | 
| 45 | 
            +
                    MockedDB = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
         | 
| 46 | 
            +
                    self._sql_db = MockedDB(connection_kwargs["sqlalchemy_dsn"])
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    # Dynamically create the table definition based on the fields
         | 
| 49 | 
            +
                    columns = [
         | 
| 50 | 
            +
                        Column("doc_id", Integer, primary_key=True),
         | 
| 51 | 
            +
                        Column("extra", JSON, default={}, nullable=False),
         | 
| 52 | 
            +
                    ]
         | 
| 53 | 
            +
                    for field, field_type in self.fields.items():
         | 
| 54 | 
            +
                        match field_type["type"]:
         | 
| 55 | 
            +
                            case "date":
         | 
| 56 | 
            +
                                ColumnType = DateNowColumn
         | 
| 57 | 
            +
                            case "long":
         | 
| 58 | 
            +
                                ColumnType = partial(Column, type_=Integer)
         | 
| 59 | 
            +
                            case "keyword":
         | 
| 60 | 
            +
                                ColumnType = partial(Column, type_=String(255))
         | 
| 61 | 
            +
                            case "text":
         | 
| 62 | 
            +
                                ColumnType = partial(Column, type_=String(64 * 1024))
         | 
| 63 | 
            +
                            case _:
         | 
| 64 | 
            +
                                raise NotImplementedError(f"Unknown field type: {field_type=}")
         | 
| 65 | 
            +
                        columns.append(ColumnType(field, default=None))
         | 
| 66 | 
            +
                    self._sql_db.metadata = MetaData()
         | 
| 67 | 
            +
                    self._table = Table("dummy", self._sql_db.metadata, *columns)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                @contextlib.asynccontextmanager
         | 
| 70 | 
            +
                async def client_context(self) -> AsyncIterator[None]:
         | 
| 71 | 
            +
                    async with self._sql_db.engine_context():
         | 
| 72 | 
            +
                        yield
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                async def __aenter__(self):
         | 
| 75 | 
            +
                    await self._sql_db.__aenter__()
         | 
| 76 | 
            +
                    return self
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                async def __aexit__(self, exc_type, exc_value, traceback):
         | 
| 79 | 
            +
                    await self._sql_db.__aexit__(exc_type, exc_value, traceback)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                async def create_index_template(self) -> None:
         | 
| 82 | 
            +
                    async with self._sql_db.engine.begin() as conn:
         | 
| 83 | 
            +
                        await conn.run_sync(self._sql_db.metadata.create_all)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                async def upsert(self, doc_id, document) -> None:
         | 
| 86 | 
            +
                    async with self:
         | 
| 87 | 
            +
                        values = {}
         | 
| 88 | 
            +
                        for key, value in document.items():
         | 
| 89 | 
            +
                            if key in self.fields:
         | 
| 90 | 
            +
                                values[key] = value
         | 
| 91 | 
            +
                            else:
         | 
| 92 | 
            +
                                values.setdefault("extra", {})[key] = value
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                        stmt = sqlite_insert(self._table).values(doc_id=doc_id, **values)
         | 
| 95 | 
            +
                        # TODO: Upsert the JSON blob properly
         | 
| 96 | 
            +
                        stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values)
         | 
| 97 | 
            +
                        await self._sql_db.conn.execute(stmt)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                async def search(
         | 
| 100 | 
            +
                    self,
         | 
| 101 | 
            +
                    parameters: list[str] | None,
         | 
| 102 | 
            +
                    search: list[SearchSpec],
         | 
| 103 | 
            +
                    sorts: list[SortSpec],
         | 
| 104 | 
            +
                    *,
         | 
| 105 | 
            +
                    distinct: bool = False,
         | 
| 106 | 
            +
                    per_page: int = 100,
         | 
| 107 | 
            +
                    page: int | None = None,
         | 
| 108 | 
            +
                ) -> tuple[int, list[dict[Any, Any]]]:
         | 
| 109 | 
            +
                    async with self:
         | 
| 110 | 
            +
                        # Apply selection
         | 
| 111 | 
            +
                        if parameters:
         | 
| 112 | 
            +
                            columns = []
         | 
| 113 | 
            +
                            for p in parameters:
         | 
| 114 | 
            +
                                if p in self.fields:
         | 
| 115 | 
            +
                                    columns.append(self._table.columns[p])
         | 
| 116 | 
            +
                                else:
         | 
| 117 | 
            +
                                    columns.append(self._table.columns["extra"][p].label(p))
         | 
| 118 | 
            +
                        else:
         | 
| 119 | 
            +
                            columns = self._table.columns
         | 
| 120 | 
            +
                        stmt = select(*columns)
         | 
| 121 | 
            +
                        if distinct:
         | 
| 122 | 
            +
                            stmt = stmt.distinct()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        # Apply filtering
         | 
| 125 | 
            +
                        stmt = sql_utils.apply_search_filters(
         | 
| 126 | 
            +
                            self._table.columns.__getitem__, stmt, search
         | 
| 127 | 
            +
                        )
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        # Apply sorting
         | 
| 130 | 
            +
                        stmt = sql_utils.apply_sort_constraints(
         | 
| 131 | 
            +
                            self._table.columns.__getitem__, stmt, sorts
         | 
| 132 | 
            +
                        )
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                        # Apply pagination
         | 
| 135 | 
            +
                        if page is not None:
         | 
| 136 | 
            +
                            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                        results = []
         | 
| 139 | 
            +
                        async for row in await self._sql_db.conn.stream(stmt):
         | 
| 140 | 
            +
                            result = dict(row._mapping)
         | 
| 141 | 
            +
                            result.pop("doc_id", None)
         | 
| 142 | 
            +
                            if "extra" in result:
         | 
| 143 | 
            +
                                result.update(result.pop("extra"))
         | 
| 144 | 
            +
                            for k, v in list(result.items()):
         | 
| 145 | 
            +
                                if isinstance(v, datetime) and v.tzinfo is None:
         | 
| 146 | 
            +
                                    result[k] = v.replace(tzinfo=timezone.utc)
         | 
| 147 | 
            +
                                if v is None:
         | 
| 148 | 
            +
                                    result.pop(k)
         | 
| 149 | 
            +
                            results.append(result)
         | 
| 150 | 
            +
                    return results
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                async def ping(self):
         | 
| 153 | 
            +
                    return await self._sql_db.ping()
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
            def fake_available_osdb_implementations(name, *, real_available_implementations):
         | 
| 157 | 
            +
                implementations = real_available_implementations(name)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # Dynamically generate a class that inherits from the first implementation
         | 
| 160 | 
            +
                # but that also has the MockOSDBMixin
         | 
| 161 | 
            +
                MockParameterDB = type(name, (MockOSDBMixin, implementations[0]), {})
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                return [MockParameterDB] + implementations
         | 
    
        diracx/testing/osdb.py
    CHANGED
    
    | @@ -1,12 +1,12 @@ | |
| 1 1 | 
             
            from __future__ import annotations
         | 
| 2 2 |  | 
| 3 | 
            -
            import secrets
         | 
| 4 3 | 
             
            import socket
         | 
| 5 4 | 
             
            from subprocess import PIPE, Popen, check_output
         | 
| 6 5 |  | 
| 7 6 | 
             
            import pytest
         | 
| 8 7 |  | 
| 9 | 
            -
            from  | 
| 8 | 
            +
            from .dummy_osdb import DummyOSDB
         | 
| 9 | 
            +
            from .mock_osdb import MockOSDBMixin
         | 
| 10 10 |  | 
| 11 11 | 
             
            OPENSEARCH_PORT = 28000
         | 
| 12 12 |  | 
| @@ -18,31 +18,6 @@ def require_port_availability(port: int) -> bool: | |
| 18 18 | 
             
                        raise RuntimeError(f"This test requires port {port} to be available")
         | 
| 19 19 |  | 
| 20 20 |  | 
| 21 | 
            -
            class DummyOSDB(BaseOSDB):
         | 
| 22 | 
            -
                """Example DiracX OpenSearch database class for testing.
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                A new random prefix is created each time the class is defined to ensure
         | 
| 25 | 
            -
                test runs are independent of each other.
         | 
| 26 | 
            -
                """
         | 
| 27 | 
            -
             | 
| 28 | 
            -
                fields = {
         | 
| 29 | 
            -
                    "DateField": {"type": "date"},
         | 
| 30 | 
            -
                    "IntField": {"type": "long"},
         | 
| 31 | 
            -
                    "KeywordField0": {"type": "keyword"},
         | 
| 32 | 
            -
                    "KeywordField1": {"type": "keyword"},
         | 
| 33 | 
            -
                    "KeywordField2": {"type": "keyword"},
         | 
| 34 | 
            -
                    "TextField": {"type": "text"},
         | 
| 35 | 
            -
                }
         | 
| 36 | 
            -
             | 
| 37 | 
            -
                def __init__(self, *args, **kwargs):
         | 
| 38 | 
            -
                    # Randomize the index prefix to ensure tests are independent
         | 
| 39 | 
            -
                    self.index_prefix = f"dummy_{secrets.token_hex(8)}"
         | 
| 40 | 
            -
                    super().__init__(*args, **kwargs)
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                def index_name(self, doc_id: int) -> str:
         | 
| 43 | 
            -
                    return f"{self.index_prefix}-{doc_id // 1e6:.0f}m"
         | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 21 | 
             
            @pytest.fixture(scope="session")
         | 
| 47 22 | 
             
            def opensearch_conn_kwargs(demo_kubectl_env):
         | 
| 48 23 | 
             
                """Fixture to get the OpenSearch connection kwargs.
         | 
| @@ -108,3 +83,19 @@ async def dummy_opensearch_db(dummy_opensearch_db_without_template): | |
| 108 83 | 
             
                await db.create_index_template()
         | 
| 109 84 | 
             
                yield db
         | 
| 110 85 | 
             
                await db.client.indices.delete_index_template(name=db.index_prefix)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            @pytest.fixture
         | 
| 89 | 
            +
            async def sql_opensearch_db():
         | 
| 90 | 
            +
                """Fixture which returns a SQLOSDB object."""
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                class MockDummyOSDB(MockOSDBMixin, DummyOSDB):
         | 
| 93 | 
            +
                    pass
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                db = MockDummyOSDB(
         | 
| 96 | 
            +
                    connection_kwargs={"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
         | 
| 97 | 
            +
                )
         | 
| 98 | 
            +
                async with db.client_context():
         | 
| 99 | 
            +
                    await db.create_index_template()
         | 
| 100 | 
            +
                    yield db
         | 
| 101 | 
            +
                    # No need to cleanup as this uses an in-memory sqlite database
         | 
| @@ -0,0 +1,36 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from contextlib import asynccontextmanager
         | 
| 4 | 
            +
            from functools import partial
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from diracx.testing.mock_osdb import fake_available_osdb_implementations
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            @asynccontextmanager
         | 
| 10 | 
            +
            async def ensure_dbs_exist():
         | 
| 11 | 
            +
                from diracx.db.__main__ import init_os, init_sql
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                await init_sql()
         | 
| 14 | 
            +
                await init_os()
         | 
| 15 | 
            +
                yield
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def create_app():
         | 
| 19 | 
            +
                """Create a FastAPI application for testing purposes.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                This is a wrapper around diracx.routers.create_app that:
         | 
| 22 | 
            +
                 * adds a lifetime function to ensure the DB schemas are initialized
         | 
| 23 | 
            +
                 * replaces the parameter DBs with sqlite-backed versions
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                from diracx.db.os.utils import BaseOSDB
         | 
| 26 | 
            +
                from diracx.routers import create_app
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                BaseOSDB.available_implementations = partial(
         | 
| 29 | 
            +
                    fake_available_osdb_implementations,
         | 
| 30 | 
            +
                    real_available_implementations=BaseOSDB.available_implementations,
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                app = create_app()
         | 
| 34 | 
            +
                app.lifetime_functions.append(ensure_dbs_exist)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                return app
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.1
         | 
| 2 2 | 
             
            Name: diracx-testing
         | 
| 3 | 
            -
            Version: 0.0. | 
| 3 | 
            +
            Version: 0.0.1a18
         | 
| 4 4 | 
             
            Summary: TODO
         | 
| 5 5 | 
             
            License: GPL-3.0-only
         | 
| 6 6 | 
             
            Classifier: Intended Audience :: Science/Research
         | 
| @@ -8,12 +8,12 @@ Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) | |
| 8 8 | 
             
            Classifier: Programming Language :: Python :: 3
         | 
| 9 9 | 
             
            Classifier: Topic :: Scientific/Engineering
         | 
| 10 10 | 
             
            Classifier: Topic :: System :: Distributed Computing
         | 
| 11 | 
            -
            Requires-Python: >=3. | 
| 11 | 
            +
            Requires-Python: >=3.11
         | 
| 12 12 | 
             
            Description-Content-Type: text/markdown
         | 
| 13 13 | 
             
            Requires-Dist: pytest
         | 
| 14 14 | 
             
            Requires-Dist: pytest-asyncio
         | 
| 15 15 | 
             
            Requires-Dist: pytest-cov
         | 
| 16 16 | 
             
            Requires-Dist: pytest-xdist
         | 
| 17 17 | 
             
            Provides-Extra: testing
         | 
| 18 | 
            -
            Requires-Dist: diracx-testing | 
| 18 | 
            +
            Requires-Dist: diracx-testing; extra == "testing"
         | 
| 19 19 |  | 
| @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            diracx/testing/__init__.py,sha256=nC59kW0jD8na0LFf73m7xmRB5uBP0cuvO4N2qGPmRyM,21420
         | 
| 2 | 
            +
            diracx/testing/dummy_osdb.py,sha256=bNk3LF8KgMuQx3RVFNYuw4hMmpG2A80sZ58rEZqHo7M,907
         | 
| 3 | 
            +
            diracx/testing/mock_osdb.py,sha256=1TFb3b0xDb2vIy4Q4V23VtrsWoT3RE5kOZmOs8n541g,5862
         | 
| 4 | 
            +
            diracx/testing/osdb.py,sha256=m6mUBLnGOoQLTCIBie9P2GhmLMybrgzIrlIYfhF1_Ss,3230
         | 
| 5 | 
            +
            diracx/testing/routers.py,sha256=UW-TnikMQgcNxF5sUZD5DWoucGiCpP6s8mYmuahDiSc,979
         | 
| 6 | 
            +
            diracx_testing-0.0.1a18.dist-info/METADATA,sha256=C5zZ-VjMY64BqNPZArdorlO8tGAIaWwi5L46RfpmrhU,614
         | 
| 7 | 
            +
            diracx_testing-0.0.1a18.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
         | 
| 8 | 
            +
            diracx_testing-0.0.1a18.dist-info/top_level.txt,sha256=vJx10tdRlBX3rF2Psgk5jlwVGZNcL3m_7iQWwgPXt-U,7
         | 
| 9 | 
            +
            diracx_testing-0.0.1a18.dist-info/RECORD,,
         | 
| @@ -1,6 +0,0 @@ | |
| 1 | 
            -
            diracx/testing/__init__.py,sha256=JFoMvc7VwOImm0CuGiJU9Lb-xwKxNx2WZdV2FuLu1N4,20145
         | 
| 2 | 
            -
            diracx/testing/osdb.py,sha256=-EFZNyEY07Zq7HdQGZxS3H808Y94aaUhmo0x-Y8xo3Q,3592
         | 
| 3 | 
            -
            diracx_testing-0.0.1a17.dist-info/METADATA,sha256=uesAYJrA2K61yXPBgUran91pIbmlBOK4gyTVKbIlung,615
         | 
| 4 | 
            -
            diracx_testing-0.0.1a17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
         | 
| 5 | 
            -
            diracx_testing-0.0.1a17.dist-info/top_level.txt,sha256=vJx10tdRlBX3rF2Psgk5jlwVGZNcL3m_7iQWwgPXt-U,7
         | 
| 6 | 
            -
            diracx_testing-0.0.1a17.dist-info/RECORD,,
         | 
| 
            File without changes
         |