simple-module-auth 0.0.16__tar.gz → 0.0.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/PKG-INFO +3 -3
- simple_module_auth-0.0.18/auth/__init__.py +7 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/contracts/__init__.py +2 -1
- simple_module_auth-0.0.18/auth/contracts/provider.py +34 -0
- simple_module_auth-0.0.18/auth/middleware.py +113 -0
- simple_module_auth-0.0.18/auth/module.py +55 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/state.py +9 -4
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/pyproject.toml +3 -3
- simple_module_auth-0.0.18/tests/test_auth_middleware.py +205 -0
- simple_module_auth-0.0.18/tests/test_auth_provider_protocol.py +56 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/tests/test_resolver_registry.py +63 -0
- simple_module_auth-0.0.16/auth/__init__.py +0 -6
- simple_module_auth-0.0.16/auth/module.py +0 -39
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/.gitignore +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/LICENSE +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/README.md +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/contracts/resolver.py +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/contracts/schemas.py +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/deps.py +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/locales/en.json +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/locales/es.json +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/auth/py.typed +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/package.json +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/tests/test_deps.py +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/tests/test_module.py +0 -0
- {simple_module_auth-0.0.16 → simple_module_auth-0.0.18}/tests/test_user_context.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: simple_module_auth
|
|
3
|
-
Version: 0.0.16
|
|
3
|
+
Version: 0.0.18
|
|
4
4
|
Summary: Session-cookie authentication primitives — middleware, login/logout, redirect helpers for simple_module
|
|
5
5
|
Project-URL: Homepage, https://github.com/antosubash/simple_module_python
|
|
6
6
|
Project-URL: Repository, https://github.com/antosubash/simple_module_python
|
|
@@ -22,8 +22,8 @@ Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
|
|
|
22
22
|
Classifier: Typing :: Typed
|
|
23
23
|
Requires-Python: >=3.12
|
|
24
24
|
Requires-Dist: itsdangerous>=2.2
|
|
25
|
-
Requires-Dist: simple-module-core==0.0.
|
|
26
|
-
Requires-Dist: simple-module-db==0.0.
|
|
25
|
+
Requires-Dist: simple-module-core==0.0.18
|
|
26
|
+
Requires-Dist: simple-module-db==0.0.18
|
|
27
27
|
Description-Content-Type: text/markdown
|
|
28
28
|
|
|
29
29
|
# simple_module_auth
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Auth module — shared contracts (UserContext, AuthProvider, PrincipalResolver, deps)."""
|
|
2
|
+
|
|
3
|
+
from auth.contracts.provider import AuthProvider
|
|
4
|
+
from auth.contracts.resolver import PrincipalResolver
|
|
5
|
+
from auth.contracts.schemas import UserContext
|
|
6
|
+
|
|
7
|
+
# Names re-exported as the package's public contract surface.
__all__ = ["AuthProvider", "PrincipalResolver", "UserContext"]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""AuthProvider protocol — the contract both users and keycloak modules implement."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
from starlette.requests import Request
|
|
8
|
+
|
|
9
|
+
from auth.contracts.schemas import UserContext
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class AuthProvider(Protocol):
    """Structural contract for a swappable authentication backend.

    Exactly one provider module (``users`` or ``keycloak``) places an
    implementation on ``app.state.auth.auth_provider`` during
    ``register_settings``; ``AuthMiddleware`` consults it on every request.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, AuthProvider)``
    only verifies that every member below is present on ``obj`` -- it does not
    check signatures, return types, or that ``resolve_user`` is a coroutine.
    """

    # Backend identifier.
    name: str

    async def resolve_user(self, request: Request) -> UserContext | None:
        """Return the principal for *request*, or ``None`` if it is anonymous.

        ``None`` makes the middleware fall through to the principal-resolver chain.
        """

    def get_login_url(self, request: Request, next_url: str | None = None) -> str:
        """Return the URL unauthenticated browser requests are redirected to."""

    def get_logout_url(self, request: Request) -> str:
        """Return the provider's logout URL."""

    def get_public_paths(self) -> tuple[tuple[str, ...], tuple[str, ...]]:
        """Return ``(prefix_paths, exact_paths)`` that skip authentication.

        The middleware matches these method-agnostically: prefixes with
        ``str.startswith`` and exact paths with equality.
        """

    def is_bearer_request(self, request: Request) -> bool:
        """Return ``True`` when an anonymous *request* should get 401 JSON, not a redirect."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Only the protocol itself is exported from this module.
__all__ = ["AuthProvider"]
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Provider-agnostic authentication middleware.
|
|
2
|
+
|
|
3
|
+
Delegates user resolution to the ``AuthProvider`` registered on
|
|
4
|
+
``app.state.auth.auth_provider``, then falls through to the
|
|
5
|
+
principal-resolver chain. Sets ``request.state.user`` and the
|
|
6
|
+
``current_user_id`` ContextVar for audit listeners.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
from simple_module_db.listeners import current_user_id
|
|
14
|
+
from starlette.requests import Request
|
|
15
|
+
from starlette.responses import JSONResponse, RedirectResponse
|
|
16
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)

# Framework-owned path prefixes that never require a principal, for any HTTP
# method. They are plain ``str.startswith`` prefixes, so "/health" also
# matches e.g. "/healthz".
_FRAMEWORK_PUBLIC_PREFIXES: tuple[str, ...] = (
    "/health", "/static/", "/api/docs", "/api/redoc", "/openapi.json", "/i18n/",
)

# Paths that are public only on an exact match.
_FRAMEWORK_PUBLIC_EXACT: tuple[str, ...] = ("/",)

# Session key that stores the originally requested URL before a login redirect.
_SESSION_NEXT_KEY: str = "next"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AuthMiddleware:
    """Pure-ASGI middleware that authenticates requests via the registered AuthProvider.

    For each HTTP request it:

    1. decides whether the path is public (framework paths, then the
       ``app.state.public_routes`` registry, then the provider's legacy paths);
    2. asks the provider for a principal and, if it returns ``None``, tries the
       callables in ``app.state.auth.principal_resolvers`` in order;
    3. rejects an unauthenticated request to a non-public path -- 401 JSON for
       ``/api/`` paths and bearer requests, otherwise a 302 redirect to the
       provider's login URL after remembering the requested URL in the session;
    4. otherwise sets ``request.state.user`` (when a principal was found) and the
       ``current_user_id`` ContextVar for the duration of the downstream call.

    Non-HTTP scopes, and apps whose ``auth_provider`` is ``None``, pass straight
    through. NOTE(review): websocket connections are therefore never
    authenticated here and get no ``state.user``.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        path: str = scope["path"]
        method: str = scope.get("method", "GET")
        app_state = scope["app"].state
        auth_state = app_state.auth
        provider = auth_state.auth_provider

        if provider is None:
            # No provider module registered: authentication is disabled.
            await self.app(scope, receive, send)
            return

        is_public = self._is_public(app_state, provider, method, path)

        request = Request(scope)
        user_ctx = await self._resolve_principal(request, provider, auth_state)

        if user_ctx is None:
            if not is_public:
                await self._reject(scope, receive, send, request, provider, path)
                return
            # Anonymous access to a public path: no principal, no ContextVar.
            await self.app(scope, receive, send)
            return

        request.state.user = user_ctx
        token = current_user_id.set(user_ctx.id)
        try:
            await self.app(scope, receive, send)
        finally:
            # Restore the previous value so the id cannot leak past this request.
            current_user_id.reset(token)

    @staticmethod
    def _is_public(app_state, provider, method: str, path: str) -> bool:
        """Return ``True`` if *path* requested with *method* needs no principal."""
        # Framework-owned paths: any method.
        if any(path.startswith(p) for p in _FRAMEWORK_PUBLIC_PREFIXES):
            return True
        if path in _FRAMEWORK_PUBLIC_EXACT:
            return True
        # Module-contributed public routes (register_public_routes hook). Method
        # -aware, so a GET read route can be exempted without opening sibling
        # POST/PATCH mutations under the same prefix.
        public_routes = getattr(app_state, "public_routes", None)
        if public_routes is not None and public_routes.matches(method, path):
            return True
        # Legacy provider-declared paths (prefix-only, method-agnostic). Kept for
        # back-compat with AuthProvider implementations.
        prefix_paths, exact_paths = provider.get_public_paths()
        return any(path.startswith(p) for p in prefix_paths) or path in exact_paths

    @staticmethod
    async def _resolve_principal(request: Request, provider, auth_state):
        """Return the provider's principal, else the first resolver match, else ``None``.

        A resolver that raises is logged with its traceback and treated as a
        non-match, so one faulty credential source cannot break every request.
        Exceptions from ``provider.resolve_user`` itself are not caught.
        """
        user_ctx = await provider.resolve_user(request)
        if user_ctx is not None:
            return user_ctx
        for resolver in auth_state.principal_resolvers:
            try:
                candidate = await resolver(request)
            except Exception:
                logger.exception(
                    "Principal resolver %r raised; treating as no-match",
                    resolver,
                )
                continue
            if candidate is not None:
                return candidate
        return None

    @staticmethod
    async def _reject(
        scope: Scope,
        receive: Receive,
        send: Send,
        request: Request,
        provider,
        path: str,
    ) -> None:
        """Send the unauthenticated response: 401 JSON for API/bearer calls, else a login redirect."""
        if path.startswith("/api/") or provider.is_bearer_request(request):
            response = JSONResponse({"detail": "Not authenticated"}, status_code=401)
        else:
            # Remember the requested URL so the login flow can send the browser back.
            # Without a session middleware outside this one there is no
            # scope["session"], and the write lands in a throwaway dict.
            # NOTE(review): str(request.url) is absolute and built from the request's
            # scheme/Host as seen by this server; consumers should validate it before
            # redirecting to it.
            session = scope.get("session", {})
            session[_SESSION_NEXT_KEY] = str(request.url)
            response = RedirectResponse(provider.get_login_url(request), status_code=302)
        await response(scope, receive, send)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Only the middleware class is exported; the underscore-prefixed constants stay private.
__all__ = ["AuthMiddleware"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Auth module — shared contracts (UserContext, AuthProvider, deps).
|
|
2
|
+
|
|
3
|
+
Intentionally minimal: this module owns the PUBLIC interface (UserContext,
|
|
4
|
+
AuthProvider, PrincipalResolver, get_current_user, CurrentUser, require_permission)
|
|
5
|
+
that every other module imports. Keeping it stable prevents churn when auth
|
|
6
|
+
internals change.
|
|
7
|
+
|
|
8
|
+
The ``auth_provider`` slot on ``app.state.auth`` is the extension point
|
|
9
|
+
auth-provider modules (``users``, ``keycloak``) use to register themselves.
|
|
10
|
+
The ``principal_resolvers`` registry lets downstream modules add extra
|
|
11
|
+
credential sources (PAT bearer tokens, API keys, etc.).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import importlib.resources
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
from simple_module_core.module import ModuleBase, ModuleMeta
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from fastapi import FastAPI
|
|
24
|
+
|
|
25
|
+
from auth.contracts.schemas import UserContext
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _serialize_principal(user: UserContext) -> dict:
    """Flatten *user* into the plain dict installed as ``app.state.principal_serializer``.

    Keys appear in the fixed order ``id``, ``name``, ``email``, ``roles``. Values
    are the attribute objects themselves; ``roles`` is not copied.
    """
    exported = ("id", "name", "email", "roles")
    return {attr: getattr(user, attr) for attr in exported}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AuthModule(ModuleBase):
    """Framework module that installs auth state, the principal serializer and AuthMiddleware."""

    meta = ModuleMeta(name="Auth", route_prefix="/auth")

    def register_settings(self, app: FastAPI) -> None:
        """Attach a fresh ``AuthState`` and the principal serializer to ``app.state``."""
        from auth.state import AuthState

        state = app.state
        state.auth = AuthState()
        state.principal_serializer = _serialize_principal

    def register_middleware(self, app: FastAPI) -> None:
        """Add the provider-agnostic ``AuthMiddleware`` to *app*."""
        from auth.middleware import AuthMiddleware

        app.add_middleware(AuthMiddleware)

    def locale_dirs(self) -> dict[str, Path]:
        """Map the ``auth`` key to this package's bundled ``locales`` directory."""
        locales = importlib.resources.files(__package__).joinpath("locales")
        return {"auth": Path(str(locales))}
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
"""Module-owned state attached to ``app.state.auth`` by ``AuthModule.register_settings``.
|
|
2
2
|
|
|
3
|
-
Holds the
|
|
4
|
-
|
|
5
|
-
|
|
3
|
+
Holds the auth provider (set by one of ``users`` or ``keycloak``) and the
|
|
4
|
+
principal-resolver registry. Apps register additional resolvers from their
|
|
5
|
+
``on_startup`` hook::
|
|
6
6
|
|
|
7
7
|
app.state.auth.principal_resolvers.append(my_pat_resolver)
|
|
8
8
|
"""
|
|
@@ -10,14 +10,19 @@ resolvers from their ``on_startup`` hook::
|
|
|
10
10
|
from __future__ import annotations
|
|
11
11
|
|
|
12
12
|
from dataclasses import dataclass, field
|
|
13
|
+
from typing import TYPE_CHECKING
|
|
13
14
|
|
|
14
15
|
from auth.contracts.resolver import PrincipalResolver
|
|
15
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from auth.contracts.provider import AuthProvider
|
|
19
|
+
|
|
16
20
|
|
|
17
21
|
@dataclass
class AuthState:
    """Mutable per-app auth registry stored on ``app.state.auth``.

    ``AuthModule.register_settings`` creates it empty; a provider module then
    fills ``auth_provider``, and apps may append extra resolvers at startup.
    """

    # The single active provider; None until a provider module registers one.
    auth_provider: AuthProvider | None = None
    # Extra credential sources, tried in order when the provider yields no principal.
    # default_factory gives every instance its own list.
    principal_resolvers: list[PrincipalResolver] = field(default_factory=list)
|
|
22
27
|
|
|
23
28
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "simple_module_auth"
|
|
3
|
-
version = "0.0.16"
|
|
3
|
+
version = "0.0.18"
|
|
4
4
|
description = "Session-cookie authentication primitives — middleware, login/logout, redirect helpers for simple_module"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "MIT"
|
|
@@ -22,8 +22,8 @@ classifiers = [
|
|
|
22
22
|
]
|
|
23
23
|
dependencies = [
|
|
24
24
|
"itsdangerous>=2.2",
|
|
25
|
-
"simple_module_core==0.0.
|
|
26
|
-
"simple_module_db==0.0.
|
|
25
|
+
"simple_module_core==0.0.18",
|
|
26
|
+
"simple_module_db==0.0.18",
|
|
27
27
|
]
|
|
28
28
|
|
|
29
29
|
[project.entry-points.simple_module]
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Tests for the provider-agnostic AuthMiddleware."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
import pytest
|
|
7
|
+
from auth.contracts.schemas import UserContext
|
|
8
|
+
from auth.middleware import AuthMiddleware
|
|
9
|
+
from auth.state import AuthState
|
|
10
|
+
from fastapi import FastAPI, Request
|
|
11
|
+
from starlette.middleware.sessions import SessionMiddleware
|
|
12
|
+
from starlette.responses import JSONResponse
|
|
13
|
+
|
|
14
|
+
SECRET = "test-middleware-secret"
|
|
15
|
+
|
|
16
|
+
# Principal returned by the "authenticated" stub provider and by test resolvers.
_TEST_USER = UserContext(
    id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
    name="Test User",
    email="test@example.com",
    roles=["admin"],
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _StubProvider:
    """AuthProvider test double with a fixed principal and fixed URLs."""

    name = "stub"

    def __init__(self, *, user: UserContext | None = None):
        self._user = user

    async def resolve_user(self, request):
        # Same answer for every request; None simulates an anonymous visitor.
        return self._user

    def get_login_url(self, request, next_url=None):
        return "/stub/login"

    def get_logout_url(self, request):
        return "/stub/logout"

    def get_public_paths(self):
        prefixes = ("/stub/login", "/stub/public/")
        exact: tuple[str, ...] = ()
        return prefixes, exact

    def is_bearer_request(self, request):
        header = request.headers.get("authorization", "")
        return header.startswith("Bearer ")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _build_app(provider, *, principal_resolvers=None, public_routes=None):
    """Build a minimal FastAPI app with AuthMiddleware running inside SessionMiddleware.

    A catch-all GET/POST/PATCH handler echoes the principal the middleware
    attached (serialized with ``to_session_dict``), or ``None``.
    """
    resolvers = list(principal_resolvers) if principal_resolvers else []
    app = FastAPI()
    app.state.auth = AuthState(auth_provider=provider, principal_resolvers=resolvers)
    if public_routes is not None:
        app.state.public_routes = public_routes

    async def _handler(request: Request, path: str = ""):
        principal = getattr(request.state, "user", None)
        payload = principal.to_session_dict() if principal else None
        return JSONResponse({"user": payload})

    app.add_api_route("/{path:path}", _handler, methods=["GET", "POST", "PATCH"])
    # The middleware added last is outermost, so SessionMiddleware runs first and
    # AuthMiddleware can store the "next" URL in scope["session"].
    app.add_middleware(AuthMiddleware)
    app.add_middleware(SessionMiddleware, secret_key=SECRET)
    return app
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.fixture
def authenticated_app():
    """App whose provider always resolves ``_TEST_USER``."""
    provider = _StubProvider(user=_TEST_USER)
    return _build_app(provider)


@pytest.fixture
def unauthenticated_app():
    """App whose provider never resolves a principal."""
    provider = _StubProvider(user=None)
    return _build_app(provider)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
async def test_authenticated_request_sets_user(authenticated_app):
    """A principal from the provider reaches the handler as request.state.user."""
    transport = httpx.ASGITransport(app=authenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/some/page")
    assert response.status_code == 200
    assert response.json()["user"]["email"] == "test@example.com"


async def test_unauthenticated_browser_redirects_to_login(unauthenticated_app):
    """Anonymous non-API requests are redirected to the provider's login URL."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", follow_redirects=False
    ) as client:
        response = await client.get("/protected/page")
    assert response.status_code == 302
    assert response.headers["location"] == "/stub/login"


async def test_unauthenticated_api_returns_401(unauthenticated_app):
    """Anonymous requests under /api/ get a JSON 401 instead of a redirect."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/api/protected")
    assert response.status_code == 401
    assert response.json()["detail"] == "Not authenticated"


async def test_unauthenticated_bearer_returns_401(unauthenticated_app):
    """A bearer request that resolves to no principal gets 401 even off /api/."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/some/page", headers={"Authorization": "Bearer bad"})
    assert response.status_code == 401


async def test_public_paths_skip_auth(unauthenticated_app):
    """Provider-declared public paths are served to anonymous callers."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/stub/login")
    assert response.status_code == 200


async def test_framework_public_paths_skip_auth(unauthenticated_app):
    """Framework-owned prefixes such as /health never require a principal."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/health")
    assert response.status_code == 200


async def test_root_is_public(unauthenticated_app):
    """The exact path "/" is public."""
    transport = httpx.ASGITransport(app=unauthenticated_app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/")
    assert response.status_code == 200
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def test_registry_public_route_skips_auth():
    """A module-contributed public route lets an unauthenticated GET through."""
    from simple_module_core.public_routes import PublicRouteRegistry

    registry = PublicRouteRegistry()
    registry.add_prefix("/api/gis/stac")
    app = _build_app(_StubProvider(user=None), public_routes=registry)

    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url="http://test"
    ) as client:
        response = await client.get("/api/gis/stac/collections")
    assert response.status_code == 200
    assert response.json()["user"] is None


async def test_registry_method_scoping_gates_other_verbs():
    """A GET-scoped public rule exempts GET but still gates PATCH on the same path."""
    from simple_module_core.public_routes import PublicRouteRegistry

    registry = PublicRouteRegistry()
    registry.add_regex(r"/api/gis/datasets/[^/]+/tilejson$", methods={"GET"})
    app = _build_app(_StubProvider(user=None), public_routes=registry)

    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url="http://test"
    ) as client:
        allowed = await client.get("/api/gis/datasets/42/tilejson")
        blocked = await client.patch("/api/gis/datasets/42/tilejson")
    assert allowed.status_code == 200
    assert blocked.status_code == 401


async def test_no_registry_falls_back_to_provider_paths(unauthenticated_app):
    """Apps built without a public-routes registry still honor provider paths."""
    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=unauthenticated_app), base_url="http://test"
    ) as client:
        public = await client.get("/stub/public/data")
        protected = await client.get("/api/protected")
    assert public.status_code == 200
    assert protected.status_code == 401
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
async def test_resolver_chain_fallback():
    """When the provider returns None, the principal-resolver chain is consulted."""

    async def token_resolver(request):
        # Accept exactly one hard-coded bearer token.
        if request.headers.get("authorization", "") == "Bearer good-token":
            return _TEST_USER
        return None

    app = _build_app(_StubProvider(user=None), principal_resolvers=[token_resolver])
    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url="http://test"
    ) as client:
        response = await client.get(
            "/protected", headers={"Authorization": "Bearer good-token"}
        )
    assert response.status_code == 200
    assert response.json()["user"]["email"] == "test@example.com"
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
async def test_resolver_exception_is_logged_and_skipped(caplog):
    """A resolver that raises is logged (with traceback) and skipped; the request is still rejected."""

    async def bad_resolver(request):
        raise RuntimeError("boom")

    app = _build_app(_StubProvider(user=None), principal_resolvers=[bad_resolver])
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test", follow_redirects=False
    ) as c:
        resp = await c.get("/protected/page")

    # Skipped: the middleware carried on and redirected the anonymous browser.
    assert resp.status_code == 302
    # Logged: the failure must be reported with its exception info, not swallowed.
    assert any(
        "Principal resolver" in record.getMessage() and record.exc_info
        for record in caplog.records
    )
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Tests for the AuthProvider protocol."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from auth.contracts.provider import AuthProvider
|
|
6
|
+
from auth.contracts.schemas import UserContext
|
|
7
|
+
from starlette.requests import Request
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _FakeProvider:
    """Smallest object that structurally satisfies ``AuthProvider``; every hook is inert."""

    name = "fake"

    async def resolve_user(self, request: Request) -> UserContext | None:
        # Never authenticates anyone.
        return None

    def get_login_url(self, request: Request, next_url: str | None = None) -> str:
        return "/fake/login"

    def get_logout_url(self, request: Request) -> str:
        return "/fake/logout"

    def get_public_paths(self) -> tuple[tuple[str, ...], tuple[str, ...]]:
        prefixes = ("/fake/login",)
        return prefixes, ()

    def is_bearer_request(self, request: Request) -> bool:
        return False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def test_fake_provider_satisfies_protocol():
    """A class exposing every AuthProvider member passes the runtime isinstance check."""
    assert isinstance(_FakeProvider(), AuthProvider)


def test_protocol_rejects_incomplete_implementation():
    """Having only ``name`` is not enough; the method members are required too."""

    class _Incomplete:
        name = "broken"

    assert not isinstance(_Incomplete(), AuthProvider)


def test_auth_package_reexports_auth_provider():
    """``auth.AuthProvider`` is listed in ``__all__`` and is the canonical class."""
    import auth
    from auth.contracts.provider import AuthProvider as Canonical

    assert hasattr(auth, "AuthProvider")
    assert "AuthProvider" in auth.__all__
    assert auth.AuthProvider is Canonical
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_contracts_package_reexports_auth_provider():
    """``auth.contracts`` re-exports the canonical ``AuthProvider`` class."""
    from auth.contracts import AuthProvider as FromContracts
    from auth.contracts.provider import AuthProvider as Canonical

    # Check identity: ``is not None`` only proved that the name exists, not that
    # it is the same protocol class the rest of the package uses.
    assert FromContracts is Canonical
|
|
@@ -79,3 +79,66 @@ def test_auth_package_reexports_public_surface():
|
|
|
79
79
|
|
|
80
80
|
assert auth.PrincipalResolver is PrincipalResolver
|
|
81
81
|
assert auth.UserContext is UserContext
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_auth_module_registers_middleware():
    """AuthModule.register_middleware should add AuthMiddleware."""
    from auth.module import AuthModule
    from fastapi import FastAPI

    app = FastAPI()
    AuthModule().register_middleware(app)
    registered = {entry.cls.__name__ for entry in app.user_middleware}
    assert "AuthMiddleware" in registered


def test_auth_module_registers_principal_serializer():
    """AuthModule.register_settings should set principal_serializer on app.state."""
    from auth.contracts.schemas import UserContext
    from auth.module import AuthModule
    from fastapi import FastAPI

    app = FastAPI()
    AuthModule().register_settings(app)
    serializer = getattr(app.state, "principal_serializer", None)
    assert serializer is not None

    principal = UserContext(id="123", email="a@b.com", name="Test", roles=["admin"])
    expected = {"id": "123", "name": "Test", "email": "a@b.com", "roles": ["admin"]}
    assert serializer(principal) == expected


def test_auth_state_has_auth_provider_field():
    """A fresh AuthState starts with no provider registered."""
    from auth.state import AuthState

    assert AuthState().auth_provider is None


def test_auth_state_accepts_auth_provider():
    """AuthState stores a structurally valid provider as-is."""
    from auth.contracts.provider import AuthProvider
    from auth.state import AuthState

    class FakeProvider:
        name = "fake"

        async def resolve_user(self, request):
            return None

        def get_login_url(self, request, next_url=None):
            return "/login"

        def get_logout_url(self, request):
            return "/logout"

        def get_public_paths(self):
            return (), ()

        def is_bearer_request(self, request):
            return False

    provider = FakeProvider()
    state = AuthState(auth_provider=provider)
    assert state.auth_provider is provider
    assert isinstance(state.auth_provider, AuthProvider)
|
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
"""Auth module — shared contracts (UserContext, deps).
|
|
2
|
-
|
|
3
|
-
Intentionally minimal: this module owns the PUBLIC interface (UserContext,
|
|
4
|
-
PrincipalResolver, get_current_user, CurrentUser, require_permission) that
|
|
5
|
-
every other module imports. Keeping it stable prevents churn when auth
|
|
6
|
-
internals change.
|
|
7
|
-
|
|
8
|
-
All authentication logic (middleware, login, signup, OAuth) lives in the
|
|
9
|
-
users module. The ``principal_resolvers`` registry on ``app.state.auth`` is
|
|
10
|
-
the extension point downstream modules use to plug in additional credential
|
|
11
|
-
sources (PAT bearer tokens, API keys, etc.) — see
|
|
12
|
-
``docs/framework/principal-resolvers.md`` for the worked example.
|
|
13
|
-
"""
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import importlib.resources
|
|
18
|
-
from pathlib import Path
|
|
19
|
-
from typing import TYPE_CHECKING
|
|
20
|
-
|
|
21
|
-
from simple_module_core.module import ModuleBase, ModuleMeta
|
|
22
|
-
|
|
23
|
-
if TYPE_CHECKING:
|
|
24
|
-
from fastapi import FastAPI
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class AuthModule(ModuleBase):
    """Framework module that attaches an empty ``AuthState`` to ``app.state.auth``."""

    meta = ModuleMeta(name="Auth", route_prefix="/auth")

    def register_settings(self, app: FastAPI) -> None:
        """Install a fresh, empty auth registry on the app."""
        from auth.state import AuthState

        app.state.auth = AuthState()

    def locale_dirs(self) -> dict[str, Path]:
        """Map the ``auth`` key to this package's bundled ``locales`` directory."""
        locales = importlib.resources.files(__package__).joinpath("locales")
        return {"auth": Path(str(locales))}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|