conduit-auth 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- conduit_auth-0.1.0/.gitignore +14 -0
- conduit_auth-0.1.0/PKG-INFO +22 -0
- conduit_auth-0.1.0/pyproject.toml +36 -0
- conduit_auth-0.1.0/src/conduit_auth/__init__.py +27 -0
- conduit_auth-0.1.0/src/conduit_auth/config.py +32 -0
- conduit_auth-0.1.0/src/conduit_auth/exceptions.py +9 -0
- conduit_auth-0.1.0/src/conduit_auth/fastapi/__init__.py +6 -0
- conduit_auth-0.1.0/src/conduit_auth/fastapi/dependencies.py +23 -0
- conduit_auth-0.1.0/src/conduit_auth/fastapi/middleware.py +69 -0
- conduit_auth-0.1.0/src/conduit_auth/flask/__init__.py +6 -0
- conduit_auth-0.1.0/src/conduit_auth/flask/decorators.py +29 -0
- conduit_auth-0.1.0/src/conduit_auth/flask/middleware.py +86 -0
- conduit_auth-0.1.0/src/conduit_auth/models.py +24 -0
- conduit_auth-0.1.0/src/conduit_auth/token.py +155 -0
- conduit_auth-0.1.0/tests/__init__.py +0 -0
- conduit_auth-0.1.0/tests/conftest.py +139 -0
- conduit_auth-0.1.0/tests/test_fastapi.py +117 -0
- conduit_auth-0.1.0/tests/test_flask.py +123 -0
- conduit_auth-0.1.0/tests/test_token.py +92 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: conduit-auth
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared Entra ID authentication for the Conduit suite
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: cryptography>=41.0
|
|
7
|
+
Requires-Dist: httpx>=0.25
|
|
8
|
+
Requires-Dist: pydantic-settings>=2.0
|
|
9
|
+
Requires-Dist: pyjwt[crypto]>=2.8
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: fastapi>=0.100; extra == 'dev'
|
|
12
|
+
Requires-Dist: flask>=3.0; extra == 'dev'
|
|
13
|
+
Requires-Dist: httpx>=0.25; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest-asyncio>=0.21; extra == 'dev'
|
|
15
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
16
|
+
Requires-Dist: respx>=0.20; extra == 'dev'
|
|
17
|
+
Requires-Dist: starlette[full]>=0.27; extra == 'dev'
|
|
18
|
+
Requires-Dist: uvicorn>=0.20; extra == 'dev'
|
|
19
|
+
Provides-Extra: fastapi
|
|
20
|
+
Requires-Dist: fastapi>=0.100; extra == 'fastapi'
|
|
21
|
+
Provides-Extra: flask
|
|
22
|
+
Requires-Dist: flask>=3.0; extra == 'flask'
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "conduit-auth"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Shared Entra ID authentication for the Conduit suite"
|
|
5
|
+
requires-python = ">=3.11"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"PyJWT[crypto]>=2.8",
|
|
8
|
+
"cryptography>=41.0",
|
|
9
|
+
"httpx>=0.25",
|
|
10
|
+
"pydantic-settings>=2.0",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[project.optional-dependencies]
|
|
14
|
+
fastapi = ["fastapi>=0.100"]
|
|
15
|
+
flask = ["flask>=3.0"]
|
|
16
|
+
dev = [
|
|
17
|
+
"pytest>=7.0",
|
|
18
|
+
"pytest-asyncio>=0.21",
|
|
19
|
+
"respx>=0.20",
|
|
20
|
+
"httpx>=0.25",
|
|
21
|
+
"fastapi>=0.100",
|
|
22
|
+
"flask>=3.0",
|
|
23
|
+
"uvicorn>=0.20",
|
|
24
|
+
"starlette[full]>=0.27",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[build-system]
|
|
28
|
+
requires = ["hatchling"]
|
|
29
|
+
build-backend = "hatchling.build"
|
|
30
|
+
|
|
31
|
+
[tool.hatch.build.targets.wheel]
|
|
32
|
+
packages = ["src/conduit_auth"]
|
|
33
|
+
|
|
34
|
+
[tool.pytest.ini_options]
|
|
35
|
+
testpaths = ["tests"]
|
|
36
|
+
asyncio_mode = "auto"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Conduit Auth — Shared Entra ID authentication for the Conduit suite.
|
|
2
|
+
|
|
3
|
+
Usage (FastAPI):
|
|
4
|
+
from conduit_auth import AuthSettings
|
|
5
|
+
from conduit_auth.fastapi import AuthMiddleware, get_current_user
|
|
6
|
+
|
|
7
|
+
settings = AuthSettings()
|
|
8
|
+
app.add_middleware(AuthMiddleware, settings=settings)
|
|
9
|
+
|
|
10
|
+
Usage (Flask):
|
|
11
|
+
from conduit_auth import AuthSettings
|
|
12
|
+
from conduit_auth.flask import init_auth, get_current_user
|
|
13
|
+
|
|
14
|
+
settings = AuthSettings()
|
|
15
|
+
init_auth(app, settings)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from .config import AuthSettings
|
|
19
|
+
from .exceptions import AuthenticationError, TokenExpiredError
|
|
20
|
+
from .models import AuthenticatedUser
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"AuthSettings",
|
|
24
|
+
"AuthenticatedUser",
|
|
25
|
+
"AuthenticationError",
|
|
26
|
+
"TokenExpiredError",
|
|
27
|
+
]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Authentication configuration via Pydantic settings."""
|
|
2
|
+
|
|
3
|
+
from pydantic_settings import BaseSettings
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AuthSettings(BaseSettings):
    """Conduit auth configuration, loaded from environment variables.

    Field names are used verbatim as environment variable names (the
    prefix is empty) and a local ``.env`` file is read as well.

    Required:
        ENTRA_CLIENT_ID: App registration client ID from Azure portal.

    Optional:
        ENTRA_TENANT_ID: Set to "common" for multi-tenant (default).
            A specific tenant GUID is intended to restrict sign-in to that
            tenant; enforcement is the token validator's job, not this
            class's.
        ENTRA_AUDIENCE: Token audience to validate. Consumers fall back to
            ENTRA_CLIENT_ID when this is unset.
        JWKS_URI: Microsoft's JWKS key endpoint. Override only for testing.
        JWKS_CACHE_TTL: How long to cache JWKS keys (seconds). Default 1 hour.
        AUTH_DISABLED: Set True for local dev to bypass validation entirely.
        AUTH_EXCLUDE_PATHS: Paths that skip auth (health checks, etc).
    """

    ENTRA_CLIENT_ID: str
    ENTRA_TENANT_ID: str = "common"
    ENTRA_AUDIENCE: str | None = None
    JWKS_URI: str = (
        "https://login.microsoftonline.com/common/discovery/v2.0/keys"
    )
    JWKS_CACHE_TTL: int = 3600
    AUTH_DISABLED: bool = False
    AUTH_EXCLUDE_PATHS: list[str] = ["/health", "/api/health"]

    # extra="ignore": with an empty env_prefix every entry of the host
    # application's .env file is offered to this model. The pydantic-settings
    # default (extra="forbid") would make AuthSettings() raise as soon as the
    # file contains any variable that is not one of the fields above.
    model_config = {"env_prefix": "", "env_file": ".env", "extra": "ignore"}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""FastAPI dependency injection for authenticated user."""
|
|
2
|
+
|
|
3
|
+
from fastapi import HTTPException, Request
|
|
4
|
+
|
|
5
|
+
from ..models import AuthenticatedUser
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
async def get_current_user(request: Request) -> AuthenticatedUser:
    """Return the user stored on ``request.state`` or reject the request.

    Intended to be used as a FastAPI dependency on routes where something
    earlier in the request pipeline has set ``request.state.user``.

    Usage:
        from conduit_auth.fastapi import get_current_user
        from conduit_auth import AuthenticatedUser

        @app.get("/api/projects")
        async def list_projects(user: AuthenticatedUser = Depends(get_current_user)):
            # user.tenant_id, user.email, etc.
            ...

    Raises:
        HTTPException: 401 when ``request.state`` has no ``user`` attribute
            or it is ``None``.
    """
    current = getattr(request.state, "user", None)
    if current is not None:
        return current
    raise HTTPException(status_code=401, detail="Not authenticated")
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""FastAPI/Starlette middleware for JWT authentication."""
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
from starlette.requests import Request
|
|
5
|
+
from starlette.responses import JSONResponse
|
|
6
|
+
|
|
7
|
+
from ..config import AuthSettings
|
|
8
|
+
from ..exceptions import AuthenticationError
|
|
9
|
+
from ..models import AuthenticatedUser
|
|
10
|
+
from ..token import get_key_cache, validate_token
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AuthMiddleware(BaseHTTPMiddleware):
    """Validates Bearer JWT tokens on /api/* routes.

    Requests whose path is listed in ``settings.AUTH_EXCLUDE_PATHS`` or does
    not start with ``/api/`` are passed through untouched. For every other
    request the authenticated user is stored on ``request.state.user``;
    failures are answered with a 401 JSON body of the form
    ``{"error": "..."}``.

    Usage:
        from conduit_auth.fastapi import AuthMiddleware
        from conduit_auth import AuthSettings

        settings = AuthSettings()
        app.add_middleware(AuthMiddleware, settings=settings)
    """

    def __init__(self, app, settings: AuthSettings) -> None:
        """Wrap ``app`` and remember the settings and the shared key cache."""
        super().__init__(app)
        self.settings = settings
        self._key_cache = get_key_cache()

    async def dispatch(self, request: Request, call_next):
        """Authenticate one request, then hand it to the next handler."""
        path = request.url.path

        # Skip excluded paths (health checks, etc.) -- exact match only.
        if path in self.settings.AUTH_EXCLUDE_PATHS:
            return await call_next(request)

        # Only /api/* routes are protected; everything else passes through.
        if not path.startswith("/api/"):
            return await call_next(request)

        # Dev bypass -- injects a fixed mock user instead of validating.
        if self.settings.AUTH_DISABLED:
            request.state.user = AuthenticatedUser(
                tenant_id="dev-tenant-00000000",
                user_id="dev-user-00000000",
                email="dev@localhost",
                display_name="Dev User",
                roles=[],
                raw_claims={},
            )
            return await call_next(request)

        # Extract the Bearer token. The auth scheme is compared
        # case-insensitively ("bearer" and "Bearer" are the same scheme), and
        # a header with no credential is rejected here rather than being
        # passed to the validator as an empty string.
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return JSONResponse(
                {"error": "Missing or invalid Authorization header"},
                status_code=401,
            )

        # Validate against the (possibly cached) signing keys.
        try:
            jwks_keys = await self._key_cache.get_keys_async(self.settings)
            user = validate_token(token, jwks_keys, self.settings)
        except AuthenticationError as e:
            return JSONResponse({"error": str(e)}, status_code=401)

        request.state.user = user
        return await call_next(request)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Flask route decorators for authentication."""
|
|
2
|
+
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
5
|
+
from flask import g, jsonify
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def require_auth(f):
    """Wrap a Flask view so it only runs when ``g.user`` is set.

    Use this when init_auth() is not applied globally, or as an explicit
    marker on routes that need auth. When ``g`` has no ``user`` (or it is
    ``None``) the wrapped view is not called and a JSON error with status
    401 is returned instead.

    Usage:
        @app.route("/api/projects")
        @require_auth
        def list_projects():
            user = get_current_user()
            ...
    """

    @wraps(f)
    def _guarded(*args, **kwargs):
        if getattr(g, "user", None) is None:
            return jsonify({"error": "Not authenticated"}), 401
        return f(*args, **kwargs)

    return _guarded
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Flask middleware for JWT authentication."""
|
|
2
|
+
|
|
3
|
+
from flask import Flask, g, jsonify, request
|
|
4
|
+
|
|
5
|
+
from ..config import AuthSettings
|
|
6
|
+
from ..exceptions import AuthenticationError
|
|
7
|
+
from ..models import AuthenticatedUser
|
|
8
|
+
from ..token import get_key_cache, validate_token
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def init_auth(app: Flask, settings: AuthSettings) -> None:
    """Register JWT auth middleware on a Flask app.

    Installs a ``before_request`` hook that validates Bearer tokens on all
    /api/* routes (except paths listed in ``settings.AUTH_EXCLUDE_PATHS``)
    and stores the authenticated user on Flask's ``g.user`` for access in
    route handlers. Failures short-circuit the request with a 401 JSON body
    of the form ``{"error": "..."}``.

    Usage:
        from conduit_auth.flask import init_auth
        from conduit_auth import AuthSettings

        settings = AuthSettings()
        init_auth(app, settings)

        # In routes:
        from conduit_auth.flask import get_current_user
        user = get_current_user()
    """

    @app.before_request
    def _validate_auth():
        path = request.path

        # Skip excluded paths -- exact match only.
        if path in settings.AUTH_EXCLUDE_PATHS:
            return None

        # Only protect /api/* routes.
        if not path.startswith("/api/"):
            return None

        # Dev bypass -- injects a fixed mock user instead of validating.
        if settings.AUTH_DISABLED:
            g.user = AuthenticatedUser(
                tenant_id="dev-tenant-00000000",
                user_id="dev-user-00000000",
                email="dev@localhost",
                display_name="Dev User",
                roles=[],
                raw_claims={},
            )
            return None

        # Extract the Bearer token. The auth scheme is compared
        # case-insensitively ("bearer" and "Bearer" are the same scheme), and
        # a header with no credential is rejected here rather than being
        # passed to the validator as an empty string.
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return (
                jsonify({"error": "Missing or invalid Authorization header"}),
                401,
            )

        # Validate against the (possibly cached) signing keys. The cache is
        # looked up per request rather than captured once at registration.
        try:
            jwks_keys = get_key_cache().get_keys(settings)
            user = validate_token(token, jwks_keys, settings)
        except AuthenticationError as e:
            return jsonify({"error": str(e)}), 401

        g.user = user
        return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_current_user() -> AuthenticatedUser:
    """Return the user stored on Flask's ``g`` for the current request.

    Returns:
        The object found at ``g.user``.

    Raises:
        401 abort if ``g`` has no ``user`` or it is ``None`` (the request
        wasn't authenticated).
    """
    current = getattr(g, "user", None)
    if current is not None:
        return current

    # Imported lazily, only on the failure path.
    from flask import abort

    abort(401)
    return current
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Core data models for the Conduit auth module."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class AuthenticatedUser:
    """An Entra ID user as seen by the application.

    Instances are plain mutable dataclass records; ``roles`` and
    ``raw_claims`` default to a fresh empty list / dict per instance.
    """

    # tid claim: the Azure AD tenant the user belongs to.
    tenant_id: str
    # oid claim: the user's object ID within their tenant.
    user_id: str
    # preferred_username claim.
    email: str
    # name claim.
    display_name: str
    # App roles assigned in Entra ID (if configured).
    roles: list[str] = field(default_factory=list)
    # Full decoded JWT, kept for app-specific use.
    raw_claims: dict = field(default_factory=dict)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""JWT token validation against Microsoft Entra ID JWKS keys."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import time
|
|
5
|
+
import threading
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
import jwt
|
|
9
|
+
from jwt.algorithms import RSAAlgorithm
|
|
10
|
+
|
|
11
|
+
from .config import AuthSettings
|
|
12
|
+
from .exceptions import AuthenticationError, TokenExpiredError
|
|
13
|
+
from .models import AuthenticatedUser
|
|
14
|
+
|
|
15
|
+
# Multi-tenant issuer patterns (v2.0 and v1.0 endpoints)
|
|
16
|
+
# A lowercase-hex GUID, as it appears in the tenant segment of an issuer URL.
_TENANT_GUID = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

# v2.0 endpoint: https://login.microsoftonline.com/<tenant-guid>/v2.0
_ISSUER_V2 = re.compile(
    r"https://login\.microsoftonline\.com/" + _TENANT_GUID + r"/v2\.0"
)

# v1.0 endpoint: https://sts.windows.net/<tenant-guid>/
_ISSUER_V1 = re.compile(r"https://sts\.windows\.net/" + _TENANT_GUID + r"/")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class JWKSKeyCache:
    """In-memory cache for the JWKS signing keys fetched from ``JWKS_URI``.

    ``get_keys`` serialises refreshes with a ``threading.Lock`` so that
    concurrent threads trigger a single fetch. ``get_keys_async`` performs
    no such de-duplication: concurrent coroutines that all see an expired
    cache may each fetch.

    Any failure to obtain a usable key set (network error, non-2xx status,
    invalid JSON, missing ``keys`` list) is raised as ``AuthenticationError``
    so callers that already handle authentication failures do not crash on
    an unexpected transport exception.
    """

    def __init__(self) -> None:
        self._keys: list[dict] = []
        self._fetched_at: float = 0
        self._lock = threading.Lock()

    def get_keys(self, settings: AuthSettings) -> list[dict]:
        """Get JWKS keys, fetching synchronously if the cache has expired.

        Raises:
            AuthenticationError: If the keys had to be fetched and could not be.
        """
        if self._is_fresh(settings):
            return self._keys

        with self._lock:
            # Re-check with a fresh clock reading: another thread may have
            # refreshed the cache while this one was waiting for the lock.
            if self._is_fresh(settings):
                return self._keys
            return self._fetch_sync(settings)

    async def get_keys_async(self, settings: AuthSettings) -> list[dict]:
        """Get JWKS keys, fetching asynchronously if the cache has expired.

        Raises:
            AuthenticationError: If the keys had to be fetched and could not be.
        """
        if self._is_fresh(settings):
            return self._keys

        return await self._fetch_async(settings)

    def _is_fresh(self, settings: AuthSettings) -> bool:
        """Return True when keys are cached and younger than JWKS_CACHE_TTL."""
        age = time.time() - self._fetched_at
        return bool(self._keys) and age < settings.JWKS_CACHE_TTL

    def _store(self, payload: object) -> list[dict]:
        """Validate a decoded JWKS document and make its keys the cache content."""
        keys = payload.get("keys") if isinstance(payload, dict) else None
        if not isinstance(keys, list):
            raise AuthenticationError("JWKS response is missing the 'keys' list")
        self._keys = keys
        self._fetched_at = time.time()
        return self._keys

    def _fetch_sync(self, settings: AuthSettings) -> list[dict]:
        """Download the key set with a blocking HTTP GET (10 s timeout)."""
        try:
            resp = httpx.get(settings.JWKS_URI, timeout=10)
            resp.raise_for_status()
            payload = resp.json()
        except (httpx.HTTPError, ValueError) as exc:
            # The message is deliberately generic: middleware echoes it to
            # the client. The cause stays attached for logging.
            raise AuthenticationError(
                "Unable to fetch JWKS signing keys"
            ) from exc
        return self._store(payload)

    async def _fetch_async(self, settings: AuthSettings) -> list[dict]:
        """Download the key set with an async HTTP GET (10 s timeout)."""
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.get(settings.JWKS_URI, timeout=10)
                resp.raise_for_status()
                payload = resp.json()
        except (httpx.HTTPError, ValueError) as exc:
            raise AuthenticationError(
                "Unable to fetch JWKS signing keys"
            ) from exc
        return self._store(payload)

    def clear(self) -> None:
        """Clear the cache (useful for testing)."""
        self._keys = []
        self._fetched_at = 0
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Module-level cache instance
|
|
76
|
+
_key_cache = JWKSKeyCache()


def get_key_cache() -> JWKSKeyCache:
    """Return the single ``JWKSKeyCache`` instance created at import time.

    Every call returns the same object, so all callers share one cache.
    """
    shared_cache = _key_cache
    return shared_cache
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def validate_token(
    token: str,
    jwks_keys: list[dict],
    settings: AuthSettings,
) -> AuthenticatedUser:
    """Validate a JWT and return an AuthenticatedUser.

    Checks performed, in order: the header carries a ``kid`` that exists in
    ``jwks_keys``; the RS256 signature, expiry and audience verify; the
    issuer is exactly an Entra ID v2.0 or v1.0 issuer URL for some tenant
    GUID; and, when ``settings.ENTRA_TENANT_ID`` is itself a tenant GUID,
    the token's ``tid`` claim equals it.

    Args:
        token: The raw JWT string (without "Bearer " prefix).
        jwks_keys: List of JWKS key dicts from Microsoft's endpoint.
        settings: Auth configuration.

    Returns:
        AuthenticatedUser with claims extracted from the token. Missing
        claims become empty strings (or an empty roles list).

    Raises:
        AuthenticationError: If the token is invalid.
        TokenExpiredError: If the token has expired.
    """
    audience = settings.ENTRA_AUDIENCE or settings.ENTRA_CLIENT_ID

    # Decode header to find the signing key
    try:
        unverified_header = jwt.get_unverified_header(token)
    except jwt.exceptions.DecodeError as e:
        raise AuthenticationError(f"Malformed token header: {e}") from e

    kid = unverified_header.get("kid")
    if not kid:
        raise AuthenticationError("Token header missing 'kid' field")

    # .get(): a JWKS entry without a "kid" must not abort the lookup.
    matching_key = next((k for k in jwks_keys if k.get("kid") == kid), None)
    if not matching_key:
        raise AuthenticationError("Token signing key not found in JWKS")

    # Build the public key from the JWK. Any failure here means the key
    # material is unusable, so it is reported as an authentication failure.
    try:
        public_key = RSAAlgorithm.from_jwk(matching_key)
    except Exception as e:
        raise AuthenticationError(f"Failed to construct public key: {e}") from e

    # Decode and validate the token
    try:
        claims = jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=audience,
            options={"verify_iss": False},  # We validate issuer manually below
        )
    except jwt.ExpiredSignatureError as e:
        raise TokenExpiredError("Token has expired") from e
    except jwt.InvalidAudienceError as e:
        raise AuthenticationError(
            f"Token audience mismatch: expected '{audience}'"
        ) from e
    except jwt.InvalidTokenError as e:
        raise AuthenticationError(f"Invalid token: {e}") from e

    # Multi-tenant issuer validation: accept any valid Azure AD tenant.
    # fullmatch() so that trailing text after a valid-looking prefix is
    # rejected; the isinstance guard keeps a non-string claim from raising.
    issuer = claims.get("iss", "")
    if not isinstance(issuer, str) or not (
        _ISSUER_V2.fullmatch(issuer) or _ISSUER_V1.fullmatch(issuer)
    ):
        raise AuthenticationError(f"Unexpected token issuer: {issuer}")

    # Single-tenant restriction. Only enforced when the configured value is
    # a tenant GUID: aliases such as "common" mean "any tenant", and other
    # non-GUID values cannot be compared against the GUID-valued tid claim.
    configured_tenant = settings.ENTRA_TENANT_ID.strip().lower()
    if re.fullmatch(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
        configured_tenant,
    ):
        if str(claims.get("tid", "")).lower() != configured_tenant:
            raise AuthenticationError(
                "Token tenant is not permitted for this application"
            )

    return AuthenticatedUser(
        tenant_id=claims.get("tid", ""),
        user_id=claims.get("oid", ""),
        email=claims.get("preferred_username", ""),
        display_name=claims.get("name", ""),
        roles=claims.get("roles", []),
        raw_claims=claims,
    )
|
|
File without changes
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""Shared test fixtures: RSA keys, test JWTs, mock JWKS endpoint."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
import jwt as pyjwt
|
|
8
|
+
import pytest
|
|
9
|
+
from cryptography.hazmat.primitives import serialization
|
|
10
|
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
11
|
+
|
|
12
|
+
from conduit_auth import AuthSettings
|
|
13
|
+
from conduit_auth.token import get_key_cache
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture(autouse=True)
def _clear_key_cache():
    """Empty the shared JWKS key cache before each test (autouse)."""
    shared_cache = get_key_cache()
    shared_cache.clear()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.fixture
def rsa_keypair():
    """Return a freshly generated 2048-bit RSA private key (e = 65537)."""
    return rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def kid():
    """Return the constant key ID shared by test tokens and the test JWKS."""
    fixed_kid = "test-kid-12345"
    return fixed_kid
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@pytest.fixture
def jwks_keys(rsa_keypair, kid):
    """Return a one-element JWKS list describing the test key's public half."""
    import base64

    numbers = rsa_keypair.public_key().public_numbers()

    def _encode(value: int, width: int) -> str:
        # Big-endian bytes -> unpadded base64url text, as JWKS expects.
        raw = value.to_bytes(width, "big")
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

    # Smallest whole number of bytes that holds the modulus.
    modulus_width = (numbers.n.bit_length() + 7) // 8

    jwk = {
        "kty": "RSA",
        "use": "sig",
        "kid": kid,
        "n": _encode(numbers.n, modulus_width),
        "e": _encode(numbers.e, 3),
        "alg": "RS256",
    }
    return [jwk]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@pytest.fixture
def test_tenant_id():
    """Return the GUID-shaped tenant ID used as the default in test tokens."""
    tenant = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
    return tenant
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture
def test_user_id():
    """Return the GUID-shaped user object ID used as the default in test tokens."""
    user = "11111111-2222-3333-4444-555555555555"
    return user
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest.fixture
def test_client_id():
    """Return the GUID-shaped client ID used as audience and ENTRA_CLIENT_ID."""
    client = "cccccccc-dddd-eeee-ffff-000000000000"
    return client
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@pytest.fixture
def auth_settings(test_client_id):
    """Return AuthSettings for the test client ID with auth enabled."""
    overrides = {
        "ENTRA_CLIENT_ID": test_client_id,
        "AUTH_DISABLED": False,
    }
    return AuthSettings(**overrides)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.fixture
def make_token(rsa_keypair, kid, test_tenant_id, test_user_id, test_client_id):
    """Return a factory that builds RS256-signed JWTs with the test key.

    Every argument of the factory is optional; omitted values fall back to
    the corresponding fixture or to a fixed default. ``extra_claims`` is
    merged last and may therefore override any standard claim.
    """

    def _make(
        tenant_id: str | None = None,
        user_id: str | None = None,
        email: str = "user@example.com",
        name: str = "Test User",
        audience: str | None = None,
        issuer: str | None = None,
        expired: bool = False,
        extra_claims: dict | None = None,
        use_kid: str | None = None,
    ) -> str:
        tid = tenant_id or test_tenant_id
        now = time.time()

        # Issued a minute ago; expiry is 5 minutes in the past for
        # "expired" tokens, otherwise an hour ahead.
        issued_at = int(now) - 60
        if expired:
            expires_at = int(now) - 300
        else:
            expires_at = int(now) + 3600

        claims = {
            "aud": audience or test_client_id,
            "iss": issuer
            or f"https://login.microsoftonline.com/{tid}/v2.0",
            "iat": issued_at,
            "nbf": issued_at,
            "exp": expires_at,
            "tid": tid,
            "oid": user_id or test_user_id,
            "preferred_username": email,
            "name": name,
            "sub": str(uuid.uuid4()),
        }
        claims.update(extra_claims or {})

        signing_key = rsa_keypair.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        header = {"kid": use_kid or kid}

        return pyjwt.encode(
            claims,
            signing_key,
            algorithm="RS256",
            headers=header,
        )

    return _make
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Tests for FastAPI middleware and dependencies."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from fastapi import Depends, FastAPI
|
|
7
|
+
from fastapi.testclient import TestClient
|
|
8
|
+
|
|
9
|
+
from conduit_auth import AuthSettings, AuthenticatedUser
|
|
10
|
+
from conduit_auth.fastapi import AuthMiddleware, get_current_user
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _create_app(settings: AuthSettings) -> FastAPI:
    """Build a FastAPI app with AuthMiddleware and three probe routes.

    Routes: ``/api/protected`` (echoes the authenticated user),
    ``/health`` and ``/not-api`` (no user required).
    """
    app = FastAPI()
    app.add_middleware(AuthMiddleware, settings=settings)

    @app.get("/api/protected")
    async def protected(user: AuthenticatedUser = Depends(get_current_user)):
        payload = {
            "tenant_id": user.tenant_id,
            "user_id": user.user_id,
            "email": user.email,
            "display_name": user.display_name,
        }
        return payload

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.get("/not-api")
    async def not_api():
        return {"message": "public"}

    return app
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TestAuthMiddleware:
    """HTTP-level checks of AuthMiddleware through a FastAPI TestClient."""

    @staticmethod
    def _get_protected(settings, token, jwks_keys):
        """GET /api/protected with ``token`` while the key cache is mocked."""
        client = TestClient(_create_app(settings))
        with patch(
            "conduit_auth.fastapi.middleware.get_key_cache"
        ) as mock_cache:
            mock_cache.return_value.get_keys_async = (
                lambda *a, **kw: _async_return(jwks_keys)
            )
            return client.get(
                "/api/protected",
                headers={"Authorization": f"Bearer {token}"},
            )

    def test_missing_auth_header(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/api/protected")

        assert response.status_code == 401
        assert "Authorization" in response.json()["error"]

    def test_invalid_auth_header(self, auth_settings):
        basic_header = {"Authorization": "Basic abc"}
        response = TestClient(_create_app(auth_settings)).get(
            "/api/protected", headers=basic_header
        )

        assert response.status_code == 401

    def test_valid_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_protected(auth_settings, make_token(), jwks_keys)

        assert response.status_code == 200
        body = response.json()
        assert body["email"] == "user@example.com"
        assert body["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_expired_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_protected(
            auth_settings, make_token(expired=True), jwks_keys
        )

        assert response.status_code == 401
        assert "expired" in response.json()["error"]

    def test_health_excluded(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/health")

        assert response.status_code == 200

    def test_non_api_path_excluded(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/not-api")

        assert response.status_code == 200

    def test_auth_disabled(self, test_client_id):
        bypass_settings = AuthSettings(
            ENTRA_CLIENT_ID=test_client_id, AUTH_DISABLED=True
        )
        response = TestClient(_create_app(bypass_settings)).get("/api/protected")

        assert response.status_code == 200
        assert response.json()["email"] == "dev@localhost"
        assert response.json()["tenant_id"] == "dev-tenant-00000000"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
async def _async_return(value):
    """Coroutine that resolves to ``value``; stands in for an async fetch."""
    result = value
    return result
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Tests for Flask middleware and decorators."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from flask import Flask
|
|
7
|
+
|
|
8
|
+
from conduit_auth import AuthSettings
|
|
9
|
+
from conduit_auth.flask import get_current_user, init_auth, require_auth
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _create_app(settings: AuthSettings) -> Flask:
    """Build a Flask app in testing mode with init_auth and four probe routes.

    Routes: ``/api/protected`` and ``/api/decorated`` (echo the
    authenticated user, the latter behind ``require_auth``), ``/health``
    and ``/not-api`` (no user required).
    """
    app = Flask(__name__)
    app.config["TESTING"] = True
    init_auth(app, settings)

    @app.route("/api/protected")
    def protected():
        current = get_current_user()
        payload = {
            "tenant_id": current.tenant_id,
            "user_id": current.user_id,
            "email": current.email,
            "display_name": current.display_name,
        }
        return payload

    @app.route("/api/decorated")
    @require_auth
    def decorated():
        current = get_current_user()
        return {"email": current.email}

    @app.route("/health")
    def health():
        return {"status": "ok"}

    @app.route("/not-api")
    def not_api():
        return {"message": "public"}

    return app
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TestFlaskMiddleware:
    """HTTP-level checks of the Flask auth hook through Flask's test client."""

    @staticmethod
    def _get(settings, path):
        """GET ``path`` on a fresh app without any Authorization header."""
        app = _create_app(settings)
        with app.test_client() as client:
            return client.get(path)

    @staticmethod
    def _get_with_token(settings, path, token, jwks_keys):
        """GET ``path`` with a Bearer token while the key cache is mocked."""
        app = _create_app(settings)
        with patch("conduit_auth.flask.middleware.get_key_cache") as mock_cache:
            mock_cache.return_value.get_keys.return_value = jwks_keys
            with app.test_client() as client:
                return client.get(
                    path,
                    headers={"Authorization": f"Bearer {token}"},
                )

    def test_missing_auth_header(self, auth_settings):
        response = self._get(auth_settings, "/api/protected")

        assert response.status_code == 401
        assert "Authorization" in response.json["error"]

    def test_valid_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_with_token(
            auth_settings, "/api/protected", make_token(), jwks_keys
        )

        assert response.status_code == 200
        assert response.json["email"] == "user@example.com"
        assert response.json["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_expired_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_with_token(
            auth_settings, "/api/protected", make_token(expired=True), jwks_keys
        )

        assert response.status_code == 401
        assert "expired" in response.json["error"]

    def test_health_excluded(self, auth_settings):
        response = self._get(auth_settings, "/health")

        assert response.status_code == 200

    def test_non_api_path_excluded(self, auth_settings):
        response = self._get(auth_settings, "/not-api")

        assert response.status_code == 200

    def test_auth_disabled(self, test_client_id):
        bypass_settings = AuthSettings(
            ENTRA_CLIENT_ID=test_client_id, AUTH_DISABLED=True
        )
        response = self._get(bypass_settings, "/api/protected")

        assert response.status_code == 200
        assert response.json["email"] == "dev@localhost"

    def test_require_auth_decorator(self, auth_settings, make_token, jwks_keys):
        response = self._get_with_token(
            auth_settings, "/api/decorated", make_token(), jwks_keys
        )

        assert response.status_code == 200
        assert response.json["email"] == "user@example.com"
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Tests for JWT token validation."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from conduit_auth.exceptions import AuthenticationError, TokenExpiredError
|
|
6
|
+
from conduit_auth.token import validate_token
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestValidateToken:
    """Behavioural checks for ``validate_token``."""

    def test_valid_token(self, make_token, jwks_keys, auth_settings):
        """A default fixture token is mapped onto the expected user fields."""
        principal = validate_token(make_token(), jwks_keys, auth_settings)

        assert principal.tenant_id == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
        assert principal.user_id == "11111111-2222-3333-4444-555555555555"
        assert principal.email == "user@example.com"
        assert principal.display_name == "Test User"
        assert principal.roles == []
        assert "tid" in principal.raw_claims

    def test_valid_token_with_roles(self, make_token, jwks_keys, auth_settings):
        """A ``roles`` claim in the token is surfaced on the returned user."""
        role_claims = {"roles": ["Admin", "Reader"]}
        signed = make_token(extra_claims=role_claims)
        principal = validate_token(signed, jwks_keys, auth_settings)

        assert principal.roles == ["Admin", "Reader"]

    def test_expired_token(self, make_token, jwks_keys, auth_settings):
        """A token minted with ``expired=True`` raises ``TokenExpiredError``."""
        signed = make_token(expired=True)

        with pytest.raises(TokenExpiredError, match="expired"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_wrong_audience(self, make_token, jwks_keys, auth_settings):
        """A token minted for another audience raises ``AuthenticationError``."""
        signed = make_token(audience="wrong-audience")

        with pytest.raises(AuthenticationError, match="audience"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_invalid_issuer(self, make_token, jwks_keys, auth_settings):
        """A token carrying a foreign issuer raises ``AuthenticationError``."""
        signed = make_token(issuer="https://evil.example.com/v2.0")

        with pytest.raises(AuthenticationError, match="issuer"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_v1_issuer_accepted(self, make_token, jwks_keys, auth_settings):
        """An ``sts.windows.net`` issuer for the tenant is accepted."""
        tenant = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
        v1_issuer = f"https://sts.windows.net/{tenant}/"
        principal = validate_token(
            make_token(issuer=v1_issuer), jwks_keys, auth_settings
        )

        assert principal.tenant_id == tenant

    def test_multi_tenant_different_tenants(
        self, make_token, jwks_keys, auth_settings
    ):
        """Tokens minted for several distinct tenants are each accepted."""
        tenant_ids = (
            "11111111-1111-1111-1111-111111111111",
            "22222222-2222-2222-2222-222222222222",
            "33333333-3333-3333-3333-333333333333",
        )
        for tenant in tenant_ids:
            signed = make_token(tenant_id=tenant)
            principal = validate_token(signed, jwks_keys, auth_settings)
            assert principal.tenant_id == tenant

    def test_unknown_kid(self, make_token, jwks_keys, auth_settings):
        """A token signed under an unrecognised ``kid`` raises ``AuthenticationError``."""
        signed = make_token(use_kid="unknown-kid")

        with pytest.raises(AuthenticationError, match="not found"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_malformed_token(self, jwks_keys, auth_settings):
        """A string that is not a JWT raises ``AuthenticationError``."""
        with pytest.raises(AuthenticationError, match="Malformed"):
            validate_token("not-a-jwt", jwks_keys, auth_settings)

    def test_custom_audience_setting(
        self, make_token, jwks_keys, test_client_id
    ):
        """A token minted for the configured ``ENTRA_AUDIENCE`` validates."""
        custom_audience = "api://my-custom-audience"
        # Minimal stand-in settings object exposing only the two attributes
        # assigned here.
        stub_settings_cls = type(
            "Settings",
            (),
            {
                "ENTRA_CLIENT_ID": test_client_id,
                "ENTRA_AUDIENCE": custom_audience,
            },
        )
        settings = stub_settings_cls()
        signed = make_token(audience=custom_audience)
        principal = validate_token(signed, jwks_keys, settings)

        assert principal.email == "user@example.com"
|