kernia-test-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kernia_test_utils/__init__.py +56 -0
- kernia_test_utils/adapter_fixtures.py +143 -0
- kernia_test_utils/asgi_driver.py +107 -0
- kernia_test_utils/containers.py +124 -0
- kernia_test_utils/mock_idp.py +191 -0
- kernia_test_utils/mock_saml_idp.py +268 -0
- kernia_test_utils/mock_sms.py +41 -0
- kernia_test_utils/mock_smtp.py +60 -0
- kernia_test_utils/mock_stripe.py +637 -0
- kernia_test_utils/py.typed +0 -0
- kernia_test_utils/soft_webauthn.py +210 -0
- kernia_test_utils-0.1.0.dist-info/METADATA +72 -0
- kernia_test_utils-0.1.0.dist-info/RECORD +15 -0
- kernia_test_utils-0.1.0.dist-info/WHEEL +4 -0
- kernia_test_utils-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Shared test fixtures.
|
|
2
|
+
|
|
3
|
+
Mirrors `reference/packages/test-utils/`. Exposes:
|
|
4
|
+
|
|
5
|
+
* `ASGIDriver` / `ASGIResponse` — call an ASGI app like a client.
|
|
6
|
+
* `MockIdP` — in-process OIDC IdP with signed id_tokens.
|
|
7
|
+
* `MockSMTP` / `SentEmail` — capture outgoing emails.
|
|
8
|
+
* `MockSMS` / `SentSMS` — capture outgoing SMS.
|
|
9
|
+
* `MockStripe` — Stripe REST mock + signed-webhook helper.
|
|
10
|
+
* `MockSAMLIdP` — minimal signed SAML 2.0 IdP fixture.
|
|
11
|
+
* Container helpers — lazy testcontainers fixtures behind `requires_docker`.
|
|
12
|
+
* `all_adapters_param` — pytest parametrize value covering every backend.
* `SoftAuthenticator` — re-exported from `kernia_test_utils.soft_webauthn`.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from kernia_test_utils.adapter_fixtures import (
|
|
16
|
+
AdapterFactory,
|
|
17
|
+
adapter_cleanup,
|
|
18
|
+
all_adapters_param,
|
|
19
|
+
)
|
|
20
|
+
from kernia_test_utils.asgi_driver import ASGIDriver, ASGIResponse
|
|
21
|
+
from kernia_test_utils.containers import (
|
|
22
|
+
docker_available,
|
|
23
|
+
mongodb_container,
|
|
24
|
+
mysql_container,
|
|
25
|
+
postgres_container,
|
|
26
|
+
redis_container,
|
|
27
|
+
requires_docker,
|
|
28
|
+
)
|
|
29
|
+
from kernia_test_utils.mock_idp import MockIdP
|
|
30
|
+
from kernia_test_utils.mock_saml_idp import MockSAMLIdP
|
|
31
|
+
from kernia_test_utils.mock_sms import MockSMS, SentSMS
|
|
32
|
+
from kernia_test_utils.mock_smtp import MockSMTP, SentEmail
|
|
33
|
+
from kernia_test_utils.mock_stripe import MockStripe
|
|
34
|
+
from kernia_test_utils.soft_webauthn import SoftAuthenticator
|
|
35
|
+
|
|
36
|
+
# Public API of the package: every name re-exported from the submodules
# imported above, listed in case-sensitive ASCII order.
__all__ = [
    "ASGIDriver",
    "ASGIResponse",
    "AdapterFactory",
    "MockIdP",
    "MockSAMLIdP",
    "MockSMS",
    "MockSMTP",
    "MockStripe",
    "SentEmail",
    "SentSMS",
    "SoftAuthenticator",
    "adapter_cleanup",
    "all_adapters_param",
    "docker_available",
    "mongodb_container",
    "mysql_container",
    "postgres_container",
    "redis_container",
    "requires_docker",
]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Cross-adapter parametrization helper.
|
|
2
|
+
|
|
3
|
+
Plugin integration tests do:
|
|
4
|
+
|
|
5
|
+
@pytest.mark.parametrize(*all_adapters_param())
|
|
6
|
+
async def test_thing(adapter_factory):
|
|
7
|
+
adapter = await adapter_factory()
|
|
8
|
+
...
|
|
9
|
+
|
|
10
|
+
Each call returns a new adapter. SQLite and Mongo get a fresh database per
|
|
11
|
+
call; Postgres shares one database per container. No engine disposal happens here.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from collections.abc import Awaitable, Callable
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import pytest
|
|
20
|
+
|
|
21
|
+
from kernia_test_utils.containers import docker_available, postgres_container
|
|
22
|
+
|
|
23
|
+
# Type of the factories handed out by `all_adapters_param`: a zero-argument
# callable whose awaitable result is a backend adapter instance.
AdapterFactory = Callable[[], Awaitable[Any]]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def _memory_factory() -> Any:
|
|
27
|
+
from kernia_memory_adapter import memory_adapter
|
|
28
|
+
|
|
29
|
+
return memory_adapter()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
async def _sqlite_factory() -> Any:
|
|
33
|
+
# Each call gets its own in-memory database — they're isolated even when
|
|
34
|
+
# several tests run concurrently because the URL contains a fresh secret.
|
|
35
|
+
import secrets
|
|
36
|
+
|
|
37
|
+
from kernia_sqlalchemy import sqlalchemy_adapter
|
|
38
|
+
|
|
39
|
+
url = f"sqlite+aiosqlite:///file:{secrets.token_hex(8)}?mode=memory&cache=shared&uri=true"
|
|
40
|
+
return await sqlalchemy_adapter(url=url)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _postgres_url_factory() -> Callable[[], Awaitable[Any]]:
|
|
44
|
+
"""Return an adapter-factory bound to a single Postgres container.
|
|
45
|
+
|
|
46
|
+
The container is started lazily on first use and stopped via an atexit
|
|
47
|
+
handler. Each adapter call gets a fresh database name on that container.
|
|
48
|
+
"""
|
|
49
|
+
state: dict[str, Any] = {}
|
|
50
|
+
|
|
51
|
+
async def factory() -> Any:
|
|
52
|
+
from kernia_sqlalchemy import sqlalchemy_adapter
|
|
53
|
+
|
|
54
|
+
if "url" not in state:
|
|
55
|
+
import atexit
|
|
56
|
+
|
|
57
|
+
ctx = postgres_container()
|
|
58
|
+
url = ctx.__enter__()
|
|
59
|
+
state["ctx"] = ctx
|
|
60
|
+
state["url"] = url
|
|
61
|
+
atexit.register(lambda: ctx.__exit__(None, None, None))
|
|
62
|
+
# NOTE: tests share the same database; rely on per-test transactional
|
|
63
|
+
# rollback at the adapter layer. For now we just hand out adapters
|
|
64
|
+
# against the shared URL — schema is idempotent (CREATE IF NOT EXISTS).
|
|
65
|
+
return await sqlalchemy_adapter(url=state["url"])
|
|
66
|
+
|
|
67
|
+
return factory
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def all_adapters_param() -> tuple[str, list[Any]]:
    """Build the ``("adapter_factory", params)`` pair for ``pytest.mark.parametrize``.

    ``params`` holds one ``pytest.param`` per backend: memory, SQLAlchemy on
    SQLite, SQLAlchemy on Postgres, and Mongo. The two container-backed
    entries carry a ``skipif`` mark that skips them when ``docker_available()``
    returns False.

    NOTE(review): every call builds new Postgres/Mongo factories, and each
    factory lazily starts its own container that stays up until interpreter
    exit. Decorating N tests can therefore start up to N containers per
    backend. Sharing factories would fix that but also make Postgres tests
    share one database, so it is left as is pending a decision.
    """
    docker_missing = not docker_available()

    postgres_param = pytest.param(
        _postgres_url_factory(),
        id="sqlalchemy-postgres",
        marks=pytest.mark.skipif(docker_missing, reason="Docker required"),
    )
    mongo_param = pytest.param(
        _mongo_url_factory(),
        id="mongo",
        marks=pytest.mark.skipif(docker_missing, reason="Docker required for mongo"),
    )
    params = [
        pytest.param(_memory_factory, id="memory"),
        pytest.param(_sqlite_factory, id="sqlalchemy-sqlite"),
        postgres_param,
        mongo_param,
    ]
    return "adapter_factory", params
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _mongo_url_factory() -> Callable[[], Awaitable[Any]]:
|
|
99
|
+
"""Return an adapter-factory bound to a single MongoDB container.
|
|
100
|
+
|
|
101
|
+
Mirrors `_postgres_url_factory`: the container starts lazily on first use
|
|
102
|
+
and is stopped via an atexit handler, so it outlives each test body (a
|
|
103
|
+
`with` block here would tear the container down before the test runs).
|
|
104
|
+
Each call gets a fresh database name on the shared container, keeping
|
|
105
|
+
parametrized tests isolated.
|
|
106
|
+
"""
|
|
107
|
+
state: dict[str, Any] = {}
|
|
108
|
+
|
|
109
|
+
async def factory() -> Any:
|
|
110
|
+
try:
|
|
111
|
+
from kernia_mongo import mongo_adapter
|
|
112
|
+
except ImportError:
|
|
113
|
+
pytest.skip("kernia_mongo is not installed")
|
|
114
|
+
|
|
115
|
+
if "url" not in state:
|
|
116
|
+
import atexit
|
|
117
|
+
|
|
118
|
+
from kernia_test_utils.containers import mongodb_container
|
|
119
|
+
|
|
120
|
+
ctx = mongodb_container()
|
|
121
|
+
state["ctx"] = ctx
|
|
122
|
+
state["url"] = ctx.__enter__()
|
|
123
|
+
atexit.register(lambda: ctx.__exit__(None, None, None))
|
|
124
|
+
|
|
125
|
+
import secrets
|
|
126
|
+
|
|
127
|
+
return await mongo_adapter(url=state["url"], db_name=f"kernia_test_{secrets.token_hex(4)}")
|
|
128
|
+
|
|
129
|
+
return factory
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@pytest.fixture(autouse=False)
async def adapter_cleanup() -> Any:
    """Per-test cleanup hook.

    Tests that use the parametrized factory can opt in by requesting this
    fixture. It is a generator fixture: anything after the ``yield`` runs as
    teardown once the test finishes. For now it yields ``None`` and does
    nothing else — adapter teardown is the factory's responsibility — but
    this is the seam where future global cleanup will live.
    """
    yield
    # Teardown: no global cleanup yet.
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Public names exported by this module.
__all__ = ["AdapterFactory", "adapter_cleanup", "all_adapters_param"]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""ASGIDriver — calls an ASGI app like a client without a real server.
|
|
2
|
+
|
|
3
|
+
Maintains a cookie jar between calls so tests can chain sign-up → get-session.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from collections.abc import Awaitable, Callable, Mapping
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True, slots=True)
|
|
15
|
+
class ASGIResponse:
|
|
16
|
+
status: int
|
|
17
|
+
headers: tuple[tuple[str, str], ...]
|
|
18
|
+
body: bytes
|
|
19
|
+
|
|
20
|
+
def json(self) -> Any:
|
|
21
|
+
if not self.body:
|
|
22
|
+
return None
|
|
23
|
+
return json.loads(self.body.decode("utf-8"))
|
|
24
|
+
|
|
25
|
+
def set_cookies(self) -> dict[str, str]:
|
|
26
|
+
out: dict[str, str] = {}
|
|
27
|
+
for k, v in self.headers:
|
|
28
|
+
if k.lower() != "set-cookie":
|
|
29
|
+
continue
|
|
30
|
+
name, _, rest = v.partition("=")
|
|
31
|
+
value, _, _attrs = rest.partition(";")
|
|
32
|
+
out[name.strip()] = value
|
|
33
|
+
return out
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class ASGIDriver:
    """Call an ASGI ``app`` in-process, one HTTP request at a time.

    ``cookies`` is a minimal cookie jar. Its entries are joined into one
    ``cookie`` request header on every call. Each response's ``Set-Cookie``
    headers update it: a non-empty value stores the cookie and an empty value
    removes it. Cookie attributes (Path, Max-Age, ...) are ignored.
    """

    app: Callable[..., Awaitable[None]]
    cookies: dict[str, str] = field(default_factory=dict)

    async def request(
        self,
        method: str,
        path: str,
        *,
        json_body: Any = None,
        headers: Mapping[str, str] | None = None,
        query: str = "",
    ) -> ASGIResponse:
        """Run one request through ``self.app`` and return what it sent back.

        Args:
            method: HTTP method, placed verbatim in the scope.
            path: Request path, without the query string.
            json_body: When not None, JSON-encoded as the request body. A
                ``content-type: application/json`` header is added unless the
                caller already supplied a content-type.
            headers: Extra request headers. Names are lower-cased; names and
                values are latin-1 encoded.
            query: Raw query string, without the leading ``?``.

        Returns:
            The status, the latin-1-decoded response headers, and the body
            chunks concatenated in order. ``status`` is None if the app never
            sent ``http.response.start``.

        Any exception raised by the app propagates to the caller.
        """
        import asyncio

        body_bytes = b""
        req_headers: list[tuple[bytes, bytes]] = [
            (k.lower().encode("latin-1"), v.encode("latin-1"))
            for k, v in (headers or {}).items()
        ]
        if json_body is not None:
            body_bytes = json.dumps(json_body).encode("utf-8")
            if not any(k == b"content-type" for k, _ in req_headers):
                req_headers.append((b"content-type", b"application/json"))
        if self.cookies:
            cookie_header = "; ".join(f"{k}={v}" for k, v in self.cookies.items())
            req_headers.append((b"cookie", cookie_header.encode("latin-1")))

        scope = {
            "type": "http",
            # `asgi` and `http_version` are required keys of an ASGI HTTP
            # scope; `scheme` and `root_path` are optional but widely read.
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": method,
            "scheme": "http",
            "path": path,
            "root_path": "",
            "query_string": query.encode("utf-8"),
            "headers": req_headers,
        }

        request_delivered = False
        response_complete = asyncio.Event()

        async def receive() -> dict:
            nonlocal request_delivered
            if not request_delivered:
                request_delivered = True
                return {"type": "http.request", "body": body_bytes, "more_body": False}
            # Like a real client, stay "connected" until the response has been
            # fully sent. Reporting a disconnect straight away makes apps that
            # watch receive() while responding (e.g. disconnect listeners
            # running alongside a streaming response) abort early.
            await response_complete.wait()
            return {"type": "http.disconnect"}

        status: Any = None
        raw_headers: Any = []
        chunks: list[bytes] = []

        async def send(msg: dict) -> None:
            nonlocal status, raw_headers
            if msg["type"] == "http.response.start":
                status = msg["status"]
                raw_headers = msg.get("headers", [])
            elif msg["type"] == "http.response.body":
                chunks.append(msg.get("body") or b"")
                if not msg.get("more_body", False):
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        finally:
            # Release any receive() still waiting, even if the app raised or
            # never sent a final body message.
            response_complete.set()

        response = ASGIResponse(
            status=status,
            headers=tuple(
                (k.decode("latin-1"), v.decode("latin-1")) for k, v in raw_headers
            ),
            body=b"".join(chunks),
        )
        # Update the cookie jar: non-empty values are stored, empty ones delete.
        for name, value in response.set_cookies().items():
            if value:
                self.cookies[name] = value
            else:
                self.cookies.pop(name, None)
        return response
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Lazy testcontainers fixtures.
|
|
2
|
+
|
|
3
|
+
Each helper imports `testcontainers` at call-time so the dep stays optional.
|
|
4
|
+
A missing `testcontainers` package raises RuntimeError, and an unreachable Docker daemon makes container start-up fail; use
|
|
5
|
+
`requires_docker()` to skip cleanly at the test layer.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import contextlib
|
|
11
|
+
import shutil
|
|
12
|
+
import subprocess
|
|
13
|
+
from collections.abc import Iterator
|
|
14
|
+
|
|
15
|
+
import pytest
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _testcontainers_installed() -> bool:
|
|
19
|
+
"""True if the optional `testcontainers` package is importable."""
|
|
20
|
+
import importlib.util
|
|
21
|
+
|
|
22
|
+
return importlib.util.find_spec("testcontainers") is not None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def docker_available() -> bool:
    """Best-effort probe for whether container-backed tests can run here.

    Returns True only when all of these hold:

    * the ``testcontainers`` package is importable;
    * a ``docker`` executable is on PATH;
    * ``docker info`` exits with status 0 within 5 seconds.

    A launch failure or timeout counts as "unavailable", so gated suites skip
    instead of erroring. Docker without ``testcontainers`` installed is a
    common local setup and is treated the same way.
    """
    if not _testcontainers_installed() or shutil.which("docker") is None:
        return False
    try:
        probe = subprocess.run(
            ["docker", "info"],
            check=False,
            capture_output=True,
            timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    else:
        return probe.returncode == 0
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def requires_docker() -> pytest.MarkDecorator:
    """Build a ``skipif`` mark that skips a test when ``docker_available()`` is False.

    That covers both an unreachable Docker daemon and a missing
    ``testcontainers`` install. The probe runs once, when this function is called.
    """
    unavailable = not docker_available()
    return pytest.mark.skipif(unavailable, reason="Docker is not available on this host")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@contextlib.contextmanager
def postgres_container(image: str = "postgres:16-alpine") -> Iterator[str]:
    """Start a Postgres container and yield its connection URL for asyncpg.

    The URL is testcontainers' connection URL with the ``postgresql+psycopg2://``
    prefix replaced by ``postgresql+asyncpg://``. The container is stopped when
    the ``with`` block exits. If ``start()`` itself raises, ``stop()`` is still
    attempted, so a partially started container is not left behind; the
    original start-up error is re-raised unchanged.

    Raises:
        RuntimeError: if the optional ``testcontainers`` package is missing.
    """
    try:
        from testcontainers.postgres import PostgresContainer  # type: ignore[import-not-found]
    except ImportError as e:  # pragma: no cover
        raise RuntimeError("testcontainers extra is not installed") from e
    container = PostgresContainer(image)
    try:
        container.start()
    except BaseException:
        # Best-effort cleanup; never let it mask the start-up failure.
        with contextlib.suppress(Exception):
            container.stop()
        raise
    try:
        # Convert default psycopg2 URL to asyncpg-friendly form.
        url = container.get_connection_url().replace(
            "postgresql+psycopg2://", "postgresql+asyncpg://"
        )
        yield url
    finally:
        container.stop()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@contextlib.contextmanager
def mysql_container(image: str = "mysql:8") -> Iterator[str]:
    """Start a MySQL container and yield its connection URL for aiomysql.

    The URL is testcontainers' connection URL with the ``mysql+pymysql://``
    prefix replaced by ``mysql+aiomysql://``. The container is stopped when the
    ``with`` block exits. If ``start()`` itself raises, ``stop()`` is still
    attempted, so a partially started container is not left behind; the
    original start-up error is re-raised unchanged.

    Raises:
        RuntimeError: if the optional ``testcontainers`` package is missing.
    """
    try:
        from testcontainers.mysql import MySqlContainer  # type: ignore[import-not-found]
    except ImportError as e:  # pragma: no cover
        raise RuntimeError("testcontainers extra is not installed") from e
    container = MySqlContainer(image)
    try:
        container.start()
    except BaseException:
        # Best-effort cleanup; never let it mask the start-up failure.
        with contextlib.suppress(Exception):
            container.stop()
        raise
    try:
        url = container.get_connection_url().replace("mysql+pymysql://", "mysql+aiomysql://")
        yield url
    finally:
        container.stop()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@contextlib.contextmanager
def mongodb_container(image: str = "mongo:7") -> Iterator[str]:
    """Start a MongoDB container and yield its connection URL.

    The container is stopped when the ``with`` block exits. If ``start()``
    itself raises, ``stop()`` is still attempted, so a partially started
    container is not left behind; the original start-up error is re-raised
    unchanged.

    Raises:
        RuntimeError: if the optional ``testcontainers`` package is missing.
    """
    try:
        from testcontainers.mongodb import MongoDbContainer  # type: ignore[import-not-found]
    except ImportError as e:  # pragma: no cover
        raise RuntimeError("testcontainers extra is not installed") from e
    container = MongoDbContainer(image)
    try:
        container.start()
    except BaseException:
        # Best-effort cleanup; never let it mask the start-up failure.
        with contextlib.suppress(Exception):
            container.stop()
        raise
    try:
        yield container.get_connection_url()
    finally:
        container.stop()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@contextlib.contextmanager
def redis_container(image: str = "redis:7-alpine") -> Iterator[str]:
    """Start a Redis container and yield ``redis://<host>:<port>/0``.

    Host and port are the container's host IP and the host port mapped to
    6379. The container is stopped when the ``with`` block exits. If
    ``start()`` itself raises, ``stop()`` is still attempted, so a partially
    started container is not left behind; the original start-up error is
    re-raised unchanged.

    Raises:
        RuntimeError: if the optional ``testcontainers`` package is missing.
    """
    try:
        from testcontainers.redis import RedisContainer  # type: ignore[import-not-found]
    except ImportError as e:  # pragma: no cover
        raise RuntimeError("testcontainers extra is not installed") from e
    container = RedisContainer(image)
    try:
        container.start()
    except BaseException:
        # Best-effort cleanup; never let it mask the start-up failure.
        with contextlib.suppress(Exception):
            container.stop()
        raise
    try:
        host = container.get_container_host_ip()
        port = container.get_exposed_port(6379)
        yield f"redis://{host}:{port}/0"
    finally:
        container.stop()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Public helpers exported by this module.
__all__ = [
    "docker_available",
    "mongodb_container",
    "mysql_container",
    "postgres_container",
    "redis_container",
    "requires_docker",
]
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""In-process OIDC IdP for OAuth/OIDC tests.
|
|
2
|
+
|
|
3
|
+
Generates an RSA keypair (via `cryptography`), signs id_tokens, and serves a
|
|
4
|
+
JWKS / discovery / token / userinfo surface via `httpx.MockTransport`. Tests can
|
|
5
|
+
plug the transport into an `httpx.AsyncClient` and exercise the full OIDC
|
|
6
|
+
authorization-code flow without a live IdP.
|
|
7
|
+
|
|
8
|
+
Mirrors what real providers (Google, Microsoft, Apple) expose so the core
|
|
9
|
+
`oauth2.verify_id_token` path is exercised end-to-end.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import base64
|
|
15
|
+
import json
|
|
16
|
+
import secrets
|
|
17
|
+
import time
|
|
18
|
+
from collections import deque
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from cryptography.hazmat.primitives import hashes
|
|
24
|
+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _b64url(b: bytes) -> str:
|
|
28
|
+
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _b64url_int(i: int) -> str:
    """Encode a non-negative integer as unpadded base64url of its big-endian bytes.

    Uses the minimal number of bytes, but never fewer than one, so 0 encodes
    as a single zero byte. This is the form the JWKS uses for ``n`` and ``e``.
    """
    byte_count = max(1, (i.bit_length() + 7) // 8)
    return _b64url(i.to_bytes(byte_count, "big"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(slots=True)
class _UserProfile:
    """A user profile that MockIdP hands out on a token exchange."""

    # Subject identifier; emitted as the `sub` claim.
    sub: str
    # Optional `email` claim; when set, MockIdP also emits `email_verified: True`.
    email: str | None = None
    # Optional `name` claim.
    name: str | None = None
    # Additional claims, merged over the standard ones when claims are built.
    extra: dict[str, Any] = field(default_factory=dict)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class MockIdP:
    """In-memory OpenID Connect provider for tests.

    Each instance generates its own 2048-bit RSA key in ``__post_init__`` and
    signs RS256 id_tokens with it. ``mock_transport()`` serves the JWKS,
    discovery, token and userinfo endpoints to an ``httpx`` client; any other
    path gets a 404.

    Usage:
        idp = MockIdP(issuer="https://test-idp", audience="client-1")
        idp.create_user(sub="u1", email="a@b.c", name="A")
        transport = idp.mock_transport()
        async with httpx.AsyncClient(transport=transport) as client:
            ...
    """

    issuer: str = "https://test-idp"
    audience: str = "client-1"
    kid: str = "test-key-1"
    # Seconds until issued id_tokens expire; also reported as `expires_in`.
    token_ttl: int = 3600

    _key: rsa.RSAPrivateKey = field(init=False)
    # Profiles enqueued by create_user(); each token exchange pops the oldest.
    _queue: deque[_UserProfile] = field(init=False, default_factory=deque)
    # Access tokens issued by the token endpoint, mapped to their profile.
    _access_tokens: dict[str, _UserProfile] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    # ----- public API -----

    @property
    def jwks(self) -> dict[str, Any]:
        """JWK Set publishing the RSA public key under ``self.kid``."""
        pub = self._key.public_key().public_numbers()
        return {
            "keys": [
                {
                    "kid": self.kid,
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "n": _b64url_int(pub.n),
                    "e": _b64url_int(pub.e),
                }
            ]
        }

    @property
    def discovery(self) -> dict[str, Any]:
        """OpenID discovery document with every endpoint rooted at ``self.issuer``."""
        return {
            "issuer": self.issuer,
            "authorization_endpoint": f"{self.issuer}/authorize",
            "token_endpoint": f"{self.issuer}/token",
            "userinfo_endpoint": f"{self.issuer}/userinfo",
            "jwks_uri": f"{self.issuer}/.well-known/jwks.json",
            "response_types_supported": ["code"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
        }

    def create_user(
        self,
        sub: str,
        email: str | None = None,
        name: str | None = None,
        **extra: Any,
    ) -> None:
        """Enqueue a profile for the next sign-in (token exchange).

        Profiles are consumed first-in, first-out. Extra keyword arguments
        become additional claims in both the id_token and the userinfo response.
        """
        self._queue.append(_UserProfile(sub=sub, email=email, name=name, extra=dict(extra)))

    def id_token_for(self, sub: str, **claims: Any) -> str:
        """Return a signed id_token for direct verification tests.

        The token carries ``iss``, ``aud``, ``sub``, ``iat`` and ``exp``
        (``iat`` + ``token_ttl``). ``claims`` are applied last, so they can add
        claims or override any of those (e.g. to mint an expired token).
        """
        now = int(time.time())
        payload: dict[str, Any] = {
            "iss": self.issuer,
            "aud": self.audience,
            "sub": sub,
            "iat": now,
            "exp": now + self.token_ttl,
        }
        payload.update(claims)
        return self._sign_jwt(payload)

    def mock_transport(self) -> httpx.MockTransport:
        """Return an ``httpx.MockTransport`` that routes every request to this IdP."""
        return httpx.MockTransport(self._handle)

    # ----- internals -----

    def _sign_jwt(self, claims: dict[str, Any]) -> str:
        """Serialize *claims* as a compact RS256 JWS (PKCS#1 v1.5 with SHA-256)."""
        header = {"alg": "RS256", "kid": self.kid, "typ": "JWT"}
        signing_input = (
            _b64url(json.dumps(header, separators=(",", ":")).encode())
            + "."
            + _b64url(json.dumps(claims, separators=(",", ":")).encode())
        )
        sig = self._key.sign(
            signing_input.encode("ascii"),
            padding.PKCS1v15(),
            hashes.SHA256(),
        )
        return signing_input + "." + _b64url(sig)

    def _next_profile(self) -> _UserProfile:
        """Pop the oldest queued profile, or return a fixed anonymous one."""
        if self._queue:
            return self._queue.popleft()
        # Fallback so the IdP never errors out — tests that don't pre-enqueue
        # still get a deterministic profile.
        return _UserProfile(sub="anonymous", email="anonymous@test", name="anonymous")

    def _claims_for(self, profile: _UserProfile) -> dict[str, Any]:
        """Build the claim set for *profile*; ``extra`` claims win on conflicts."""
        claims: dict[str, Any] = {"sub": profile.sub}
        if profile.email is not None:
            claims["email"] = profile.email
            claims["email_verified"] = True
        if profile.name is not None:
            claims["name"] = profile.name
        claims.update(profile.extra)
        return claims

    def _handle(self, request: httpx.Request) -> httpx.Response:
        """Route a request by path suffix.

        - JWKS and discovery documents are served for any method.
        - POST ``/token`` performs a token exchange.
        - ``/userinfo`` requires an access token issued by that exchange.
        - Any other path returns 404.
        """
        path = request.url.path
        if path.endswith("/.well-known/jwks.json"):
            return httpx.Response(200, json=self.jwks)
        if path.endswith("/.well-known/openid-configuration"):
            return httpx.Response(200, json=self.discovery)
        if path.endswith("/token") and request.method == "POST":
            profile = self._next_profile()
            access_token = "at_" + secrets.token_urlsafe(16)
            self._access_tokens[access_token] = profile
            extra_claims = self._claims_for(profile)
            extra_claims.pop("sub", None)
            id_token = self.id_token_for(profile.sub, **extra_claims)
            return httpx.Response(
                200,
                json={
                    "access_token": access_token,
                    "token_type": "Bearer",
                    "expires_in": self.token_ttl,
                    "id_token": id_token,
                    "scope": "openid email profile",
                },
            )
        if path.endswith("/userinfo"):
            auth = request.headers.get("authorization", "").strip()
            scheme, sep, credentials = auth.partition(" ")
            # Auth schemes are case-insensitive (RFC 7235 §2.1), so any casing
            # of "Bearer" is accepted. A bare token with no scheme is also
            # looked up as-is.
            token = credentials.strip() if sep and scheme.lower() == "bearer" else auth
            profile = self._access_tokens.get(token)
            if profile is None:
                return httpx.Response(401, json={"error": "invalid_token"})
            return httpx.Response(200, json=self._claims_for(profile))
        return httpx.Response(404, json={"error": "not_found", "path": path})
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Public names exported by this module.
__all__ = ["MockIdP"]
|