conduit-auth 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ __pycache__/
2
+ *.pyc
3
+ .env
4
+ .vscode/settings.json
5
+ Resources/BCP/*.xlsx
6
+ scratchpad/*.sql
7
+ .env.*
8
+ .vscode/settings.json
9
+
10
+ # Dependencies and build output
11
+ node_modules/
12
+ dist/
13
+
14
+ .claude/settings.local.json
@@ -0,0 +1,22 @@
1
+ Metadata-Version: 2.4
2
+ Name: conduit-auth
3
+ Version: 0.1.0
4
+ Summary: Shared Entra ID authentication for the Conduit suite
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: cryptography>=41.0
7
+ Requires-Dist: httpx>=0.25
8
+ Requires-Dist: pydantic-settings>=2.0
9
+ Requires-Dist: pyjwt[crypto]>=2.8
10
+ Provides-Extra: dev
11
+ Requires-Dist: fastapi>=0.100; extra == 'dev'
12
+ Requires-Dist: flask>=3.0; extra == 'dev'
13
+ Requires-Dist: httpx>=0.25; extra == 'dev'
14
+ Requires-Dist: pytest-asyncio>=0.21; extra == 'dev'
15
+ Requires-Dist: pytest>=7.0; extra == 'dev'
16
+ Requires-Dist: respx>=0.20; extra == 'dev'
17
+ Requires-Dist: starlette[full]>=0.27; extra == 'dev'
18
+ Requires-Dist: uvicorn>=0.20; extra == 'dev'
19
+ Provides-Extra: fastapi
20
+ Requires-Dist: fastapi>=0.100; extra == 'fastapi'
21
+ Provides-Extra: flask
22
+ Requires-Dist: flask>=3.0; extra == 'flask'
@@ -0,0 +1,36 @@
1
+ [project]
2
+ name = "conduit-auth"
3
+ version = "0.1.0"
4
+ description = "Shared Entra ID authentication for the Conduit suite"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "PyJWT[crypto]>=2.8",
8
+ "cryptography>=41.0",
9
+ "httpx>=0.25",
10
+ "pydantic-settings>=2.0",
11
+ ]
12
+
13
+ [project.optional-dependencies]
14
+ fastapi = ["fastapi>=0.100"]
15
+ flask = ["flask>=3.0"]
16
+ dev = [
17
+ "pytest>=7.0",
18
+ "pytest-asyncio>=0.21",
19
+ "respx>=0.20",
20
+ "httpx>=0.25",
21
+ "fastapi>=0.100",
22
+ "flask>=3.0",
23
+ "uvicorn>=0.20",
24
+ "starlette[full]>=0.27",
25
+ ]
26
+
27
+ [build-system]
28
+ requires = ["hatchling"]
29
+ build-backend = "hatchling.build"
30
+
31
+ [tool.hatch.build.targets.wheel]
32
+ packages = ["src/conduit_auth"]
33
+
34
+ [tool.pytest.ini_options]
35
+ testpaths = ["tests"]
36
+ asyncio_mode = "auto"
@@ -0,0 +1,27 @@
1
+ """Conduit Auth — Shared Entra ID authentication for the Conduit suite.
2
+
3
+ Usage (FastAPI):
4
+ from conduit_auth import AuthSettings
5
+ from conduit_auth.fastapi import AuthMiddleware, get_current_user
6
+
7
+ settings = AuthSettings()
8
+ app.add_middleware(AuthMiddleware, settings=settings)
9
+
10
+ Usage (Flask):
11
+ from conduit_auth import AuthSettings
12
+ from conduit_auth.flask import init_auth, get_current_user
13
+
14
+ settings = AuthSettings()
15
+ init_auth(app, settings)
16
+ """
17
+
18
+ from .config import AuthSettings
19
+ from .exceptions import AuthenticationError, TokenExpiredError
20
+ from .models import AuthenticatedUser
21
+
22
+ __all__ = [
23
+ "AuthSettings",
24
+ "AuthenticatedUser",
25
+ "AuthenticationError",
26
+ "TokenExpiredError",
27
+ ]
@@ -0,0 +1,32 @@
1
+ """Authentication configuration via Pydantic settings."""
2
+
3
+ from pydantic_settings import BaseSettings
4
+
5
+
6
class AuthSettings(BaseSettings):
    """Settings for Conduit authentication, populated from the environment.

    Values are looked up without any variable-name prefix and may also be
    supplied through a ``.env`` file.

    Required:
        ENTRA_CLIENT_ID: Client ID of the app registration in the Azure portal.

    Optional:
        ENTRA_TENANT_ID: Tenant identifier. The default "common" means
            multi-tenant; a specific tenant GUID is intended for
            single-tenant deployments.
        ENTRA_AUDIENCE: Audience expected in tokens. ENTRA_CLIENT_ID is the
            intended fallback when this is left unset.
        JWKS_URI: Endpoint serving Microsoft's JWKS signing keys. Override
            only for testing.
        JWKS_CACHE_TTL: Lifetime of cached JWKS keys in seconds (default
            3600, i.e. one hour).
        AUTH_DISABLED: When True, validation is bypassed entirely (local
            development only).
        AUTH_EXCLUDE_PATHS: Request paths that skip authentication, such as
            health checks.
    """

    ENTRA_CLIENT_ID: str
    ENTRA_TENANT_ID: str = "common"
    ENTRA_AUDIENCE: str | None = None
    JWKS_URI: str = "https://login.microsoftonline.com/common/discovery/v2.0/keys"
    JWKS_CACHE_TTL: int = 3600
    AUTH_DISABLED: bool = False
    AUTH_EXCLUDE_PATHS: list[str] = ["/health", "/api/health"]

    model_config = dict(env_prefix="", env_file=".env")
@@ -0,0 +1,9 @@
1
+ """Authentication exceptions for the Conduit auth module."""
2
+
3
+
4
class AuthenticationError(Exception):
    """Base error signalling that a token could not be validated."""
6
+
7
+
8
class TokenExpiredError(AuthenticationError):
    """Authentication failure caused specifically by an expired JWT."""
@@ -0,0 +1,6 @@
1
+ """FastAPI integration for Conduit auth."""
2
+
3
+ from .middleware import AuthMiddleware
4
+ from .dependencies import get_current_user
5
+
6
+ __all__ = ["AuthMiddleware", "get_current_user"]
@@ -0,0 +1,23 @@
1
+ """FastAPI dependency injection for authenticated user."""
2
+
3
+ from fastapi import HTTPException, Request
4
+
5
+ from ..models import AuthenticatedUser
6
+
7
+
8
async def get_current_user(request: Request) -> AuthenticatedUser:
    """Return the user previously stored on ``request.state.user``.

    Intended to be used as a FastAPI dependency.

    Usage:
        from conduit_auth.fastapi import get_current_user
        from conduit_auth import AuthenticatedUser

        @app.get("/api/projects")
        async def list_projects(user: AuthenticatedUser = Depends(get_current_user)):
            # user.tenant_id, user.email, etc.
            ...

    Raises:
        HTTPException: 401 when no user is attached to the request state.
    """
    current = getattr(request.state, "user", None)
    if current is not None:
        return current
    raise HTTPException(status_code=401, detail="Not authenticated")
@@ -0,0 +1,69 @@
1
+ """FastAPI/Starlette middleware for JWT authentication."""
2
+
3
+ from starlette.middleware.base import BaseHTTPMiddleware
4
+ from starlette.requests import Request
5
+ from starlette.responses import JSONResponse
6
+
7
+ from ..config import AuthSettings
8
+ from ..exceptions import AuthenticationError
9
+ from ..models import AuthenticatedUser
10
+ from ..token import get_key_cache, validate_token
11
+
12
+
13
class AuthMiddleware(BaseHTTPMiddleware):
    """Validates Bearer JWT tokens on /api/* routes.

    On success the authenticated user is stored on ``request.state.user``.
    On failure a JSON 401 response with a ``WWW-Authenticate: Bearer``
    challenge is returned and the downstream app is not called.

    Usage:
        from conduit_auth.fastapi import AuthMiddleware
        from conduit_auth import AuthSettings

        settings = AuthSettings()
        app.add_middleware(AuthMiddleware, settings=settings)
    """

    def __init__(self, app, settings: AuthSettings) -> None:
        """Wrap ``app`` and remember the settings and the shared key cache."""
        super().__init__(app)
        self.settings = settings
        self._key_cache = get_key_cache()

    @staticmethod
    def _unauthorized(message: str) -> JSONResponse:
        """Build the 401 response, including the Bearer challenge header."""
        return JSONResponse(
            {"error": message},
            status_code=401,
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request, or pass it through when it is exempt."""
        path = request.url.path

        # Skip excluded paths (health checks, etc.)
        if path in self.settings.AUTH_EXCLUDE_PATHS:
            return await call_next(request)

        # Only /api/* routes are protected; everything else passes through.
        if not path.startswith("/api/"):
            return await call_next(request)

        # Dev bypass — injects a mock user
        if self.settings.AUTH_DISABLED:
            request.state.user = AuthenticatedUser(
                tenant_id="dev-tenant-00000000",
                user_id="dev-user-00000000",
                email="dev@localhost",
                display_name="Dev User",
                roles=[],
                raw_claims={},
            )
            return await call_next(request)

        # Extract the Bearer token. The auth scheme is case-insensitive, and
        # a header with no credentials is rejected before any key lookup.
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return self._unauthorized("Missing or invalid Authorization header")

        # Validate
        try:
            jwks_keys = await self._key_cache.get_keys_async(self.settings)
            user = validate_token(token, jwks_keys, self.settings)
        except AuthenticationError as e:
            return self._unauthorized(str(e))

        request.state.user = user
        return await call_next(request)
@@ -0,0 +1,6 @@
1
+ """Flask integration for Conduit auth."""
2
+
3
+ from .middleware import init_auth, get_current_user
4
+ from .decorators import require_auth
5
+
6
+ __all__ = ["init_auth", "get_current_user", "require_auth"]
@@ -0,0 +1,29 @@
1
+ """Flask route decorators for authentication."""
2
+
3
+ from functools import wraps
4
+
5
+ from flask import g, jsonify
6
+
7
+
8
def require_auth(f):
    """Wrap a Flask view so it only runs when ``g.user`` is set.

    Useful when init_auth() is not applied globally, or as an explicit
    marker on routes that need auth. Without a user, the wrapped view is
    skipped and a JSON 401 response is returned.

    Usage:
        @app.route("/api/projects")
        @require_auth
        def list_projects():
            user = get_current_user()
            ...
    """

    @wraps(f)
    def _guarded(*args, **kwargs):
        if getattr(g, "user", None) is not None:
            return f(*args, **kwargs)
        return jsonify({"error": "Not authenticated"}), 401

    return _guarded
@@ -0,0 +1,86 @@
1
+ """Flask middleware for JWT authentication."""
2
+
3
+ from flask import Flask, g, jsonify, request
4
+
5
+ from ..config import AuthSettings
6
+ from ..exceptions import AuthenticationError
7
+ from ..models import AuthenticatedUser
8
+ from ..token import get_key_cache, validate_token
9
+
10
+
11
def init_auth(app: Flask, settings: AuthSettings) -> None:
    """Register JWT auth middleware on a Flask app.

    Installs a ``before_request`` hook that validates Bearer tokens on all
    /api/* routes and stores the authenticated user on Flask's ``g.user``.
    Rejected requests receive a JSON 401 response with a
    ``WWW-Authenticate: Bearer`` challenge.

    Usage:
        from conduit_auth.flask import init_auth
        from conduit_auth import AuthSettings

        settings = AuthSettings()
        init_auth(app, settings)

        # In routes:
        from conduit_auth.flask import get_current_user
        user = get_current_user()
    """

    def _unauthorized(message: str):
        """Build the (body, status, headers) tuple for a 401 response."""
        return jsonify({"error": message}), 401, {"WWW-Authenticate": "Bearer"}

    @app.before_request
    def _validate_auth():
        path = request.path

        # Skip excluded paths
        if path in settings.AUTH_EXCLUDE_PATHS:
            return None

        # Only protect /api/* routes
        if not path.startswith("/api/"):
            return None

        # Dev bypass
        if settings.AUTH_DISABLED:
            g.user = AuthenticatedUser(
                tenant_id="dev-tenant-00000000",
                user_id="dev-user-00000000",
                email="dev@localhost",
                display_name="Dev User",
                roles=[],
                raw_claims={},
            )
            return None

        # Extract the Bearer token. The auth scheme is case-insensitive, and
        # a header with no credentials is rejected before any key lookup.
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return _unauthorized("Missing or invalid Authorization header")

        # Validate. The cache is looked up per request rather than captured
        # once at registration time.
        try:
            jwks_keys = get_key_cache().get_keys(settings)
            user = validate_token(token, jwks_keys, settings)
        except AuthenticationError as e:
            return _unauthorized(str(e))

        g.user = user
        return None
70
+
71
+
72
def get_current_user() -> AuthenticatedUser:
    """Fetch the authenticated user stored on Flask's ``g`` object.

    Returns:
        The value of ``g.user`` when one has been set for this request.

    Raises:
        401 abort if ``g.user`` is missing or None (request wasn't
        authenticated).
    """
    current = getattr(g, "user", None)
    if current is not None:
        return current

    # Imported lazily, only on the failure path.
    from flask import abort

    abort(401)
@@ -0,0 +1,24 @@
1
+ """Core data models for the Conduit auth module."""
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+
6
@dataclass
class AuthenticatedUser:
    """Identity of a caller authenticated through Entra ID.

    Attributes correspond to Entra ID token claims:
        tenant_id: ``tid`` claim (the Azure AD tenant the user belongs to).
        user_id: ``oid`` claim (the user's object ID within the tenant).
        email: ``preferred_username`` claim.
        display_name: ``name`` claim.
        roles: app roles assigned in Entra ID, if configured; empty by default.
        raw_claims: the full decoded JWT payload for app-specific use;
            empty by default.
    """

    tenant_id: str
    user_id: str
    email: str
    display_name: str
    # default_factory gives every instance its own list / dict.
    roles: list[str] = field(default_factory=list)
    raw_claims: dict = field(default_factory=dict)
@@ -0,0 +1,155 @@
1
+ """JWT token validation against Microsoft Entra ID JWKS keys."""
2
+
3
+ import re
4
+ import time
5
+ import threading
6
+
7
+ import httpx
8
+ import jwt
9
+ from jwt.algorithms import RSAAlgorithm
10
+
11
+ from .config import AuthSettings
12
+ from .exceptions import AuthenticationError, TokenExpiredError
13
+ from .models import AuthenticatedUser
14
+
15
# Multi-tenant issuer patterns (v2.0 and v1.0 endpoints). Each embeds a
# lowercase-hex GUID path segment identifying the tenant:
#   v2.0: https://login.microsoftonline.com/<guid>/v2.0
#   v1.0: https://sts.windows.net/<guid>/
# NOTE(review): neither pattern has an end anchor, so a caller using
# .match() also accepts issuers with trailing characters; .fullmatch()
# is needed for an exact comparison.
_ISSUER_V2 = re.compile(
    r"https://login\.microsoftonline\.com/"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/v2\.0"
)
_ISSUER_V1 = re.compile(
    r"https://sts\.windows\.net/"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/"
)
24
+
25
+
26
class JWKSKeyCache:
    """In-memory cache for the JWKS signing keys served at ``JWKS_URI``.

    Keys are re-fetched once they are older than ``settings.JWKS_CACHE_TTL``
    seconds. The synchronous path serializes refreshes with a lock; the
    asynchronous path does not take that lock. If a refresh fails while
    older keys are still held, those keys are returned, so a transient
    outage of the key endpoint does not fail every request.
    """

    # Failures a refresh can produce: transport/HTTP errors, a payload that
    # is not valid JSON, or a payload without a usable "keys" list.
    _REFRESH_ERRORS = (httpx.HTTPError, KeyError, ValueError)

    def __init__(self) -> None:
        self._keys: list[dict] = []
        self._fetched_at: float = 0
        self._lock = threading.Lock()

    def _is_fresh(self, settings: AuthSettings) -> bool:
        """Return True when keys are cached and still within the TTL."""
        age = time.time() - self._fetched_at
        return bool(self._keys) and age < settings.JWKS_CACHE_TTL

    def _store(self, payload: dict) -> list[dict]:
        """Record the keys from a decoded JWKS document and timestamp them."""
        keys = payload["keys"]
        if not isinstance(keys, list):
            raise ValueError("JWKS document 'keys' entry is not a list")
        self._keys = keys
        self._fetched_at = time.time()
        return keys

    def get_keys(self, settings: AuthSettings) -> list[dict]:
        """Get JWKS keys, fetching from the endpoint if the cache has expired.

        Raises:
            httpx.HTTPError, KeyError, ValueError: if the refresh fails and
                no previously fetched keys are available.
        """
        if self._is_fresh(settings):
            return self._keys

        with self._lock:
            # Re-check with a current timestamp: another thread may have
            # refreshed the cache while this one waited for the lock.
            if self._is_fresh(settings):
                return self._keys
            try:
                return self._fetch_sync(settings)
            except self._REFRESH_ERRORS:
                if self._keys:
                    return self._keys  # serve stale keys over failing
                raise

    async def get_keys_async(self, settings: AuthSettings) -> list[dict]:
        """Get JWKS keys asynchronously, with the same fallback as get_keys."""
        if self._is_fresh(settings):
            return self._keys

        try:
            return await self._fetch_async(settings)
        except self._REFRESH_ERRORS:
            if self._keys:
                return self._keys  # serve stale keys over failing
            raise

    def _fetch_sync(self, settings: AuthSettings) -> list[dict]:
        """Download the JWKS document with a blocking request and cache it."""
        resp = httpx.get(settings.JWKS_URI, timeout=10)
        resp.raise_for_status()
        return self._store(resp.json())

    async def _fetch_async(self, settings: AuthSettings) -> list[dict]:
        """Download the JWKS document with an async client and cache it."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(settings.JWKS_URI, timeout=10)
            resp.raise_for_status()
            return self._store(resp.json())

    def clear(self) -> None:
        """Clear the cache (useful for testing)."""
        self._keys = []
        self._fetched_at = 0
73
+
74
+
75
# Module-level cache instance, created once at import time and handed out
# by get_key_cache().
_key_cache = JWKSKeyCache()


def get_key_cache() -> JWKSKeyCache:
    """Get the module-level JWKS key cache.

    Returns:
        The single JWKSKeyCache instance created when this module was imported.
    """
    return _key_cache
82
+
83
+
84
def validate_token(
    token: str,
    jwks_keys: list[dict],
    settings: AuthSettings,
) -> AuthenticatedUser:
    """Validate a JWT and return an AuthenticatedUser.

    Checks, in order: the header carries a ``kid`` that exists in
    ``jwks_keys``; the RS256 signature, expiry and audience; the issuer is
    exactly a v1.0 or v2.0 Entra ID issuer URL; the ``tid`` claim (when
    present) names the same tenant as the issuer; and, when
    ``settings.ENTRA_TENANT_ID`` is a specific tenant and not a
    multi-tenant alias, that the token comes from that tenant.

    Args:
        token: The raw JWT string (without "Bearer " prefix).
        jwks_keys: List of JWKS key dicts from Microsoft's endpoint.
        settings: Auth configuration. ``ENTRA_CLIENT_ID`` and
            ``ENTRA_AUDIENCE`` are required attributes; ``ENTRA_TENANT_ID``
            is optional and treated as "common" when absent.

    Returns:
        AuthenticatedUser with claims extracted from the token.

    Raises:
        AuthenticationError: If the token is invalid.
        TokenExpiredError: If the token has expired.
    """
    # Tenant values that mean "accept any tenant", not a single one.
    multi_tenant_aliases = ("common", "organizations", "consumers")

    audience = settings.ENTRA_AUDIENCE or settings.ENTRA_CLIENT_ID

    # Decode header to find the signing key
    try:
        unverified_header = jwt.get_unverified_header(token)
    except jwt.exceptions.DecodeError as e:
        raise AuthenticationError(f"Malformed token header: {e}") from e

    kid = unverified_header.get("kid")
    if not kid:
        raise AuthenticationError("Token header missing 'kid' field")

    # .get() so a JWKS entry without a "kid" is skipped, not a KeyError.
    matching_key = next((k for k in jwks_keys if k.get("kid") == kid), None)
    if not matching_key:
        raise AuthenticationError("Token signing key not found in JWKS")

    # Build the public key from the JWK
    try:
        public_key = RSAAlgorithm.from_jwk(matching_key)
    except Exception as e:
        raise AuthenticationError(f"Failed to construct public key: {e}") from e

    # Decode and validate the token. "exp" is required so that a token
    # without an expiry cannot be accepted indefinitely.
    try:
        claims = jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=audience,
            options={
                "verify_iss": False,  # We validate issuer manually below
                "require": ["exp"],
            },
        )
    except jwt.ExpiredSignatureError as e:
        raise TokenExpiredError("Token has expired") from e
    except jwt.InvalidAudienceError as e:
        raise AuthenticationError(
            f"Token audience mismatch: expected '{audience}'"
        ) from e
    except jwt.InvalidTokenError as e:
        raise AuthenticationError(f"Invalid token: {e}") from e

    # Issuer validation: the whole string must match, so fullmatch() is used
    # to reject issuers that merely start with a valid prefix.
    issuer = claims.get("iss", "")
    if not isinstance(issuer, str) or not (
        _ISSUER_V2.fullmatch(issuer) or _ISSUER_V1.fullmatch(issuer)
    ):
        raise AuthenticationError(f"Unexpected token issuer: {issuer}")

    # Both accepted issuer shapes are "https://<host>/<tenant-guid>/...",
    # so the tenant GUID is always the fourth "/"-separated component.
    issuer_tenant = issuer.split("/")[3]

    token_tenant = claims.get("tid", "")
    if token_tenant and str(token_tenant).lower() != issuer_tenant:
        raise AuthenticationError("Token tenant does not match token issuer")

    # Enforce the single-tenant restriction when a specific tenant is set.
    configured_tenant = str(
        getattr(settings, "ENTRA_TENANT_ID", None) or "common"
    ).lower()
    if (
        configured_tenant not in multi_tenant_aliases
        and issuer_tenant != configured_tenant
    ):
        raise AuthenticationError("Token tenant is not permitted")

    return AuthenticatedUser(
        # Fall back to the issuer's tenant so tenant_id is never empty.
        tenant_id=token_tenant or issuer_tenant,
        user_id=claims.get("oid", ""),
        email=claims.get("preferred_username", ""),
        display_name=claims.get("name", ""),
        roles=claims.get("roles", []),
        raw_claims=claims,
    )
File without changes
@@ -0,0 +1,139 @@
1
+ """Shared test fixtures: RSA keys, test JWTs, mock JWKS endpoint."""
2
+
3
+ import json
4
+ import time
5
+ import uuid
6
+
7
+ import jwt as pyjwt
8
+ import pytest
9
+ from cryptography.hazmat.primitives import serialization
10
+ from cryptography.hazmat.primitives.asymmetric import rsa
11
+
12
+ from conduit_auth import AuthSettings
13
+ from conduit_auth.token import get_key_cache
14
+
15
+
16
@pytest.fixture(autouse=True)
def _clear_key_cache():
    """Clear the JWKS key cache between tests.

    Autouse, so it applies to every test; there is no yield, so the cache
    is emptied during setup only.
    """
    get_key_cache().clear()
20
+
21
+
22
@pytest.fixture
def rsa_keypair():
    """Return a freshly generated 2048-bit RSA private key (e = 65537)."""
    return rsa.generate_private_key(public_exponent=65537, key_size=2048)
30
+
31
+
32
@pytest.fixture
def kid():
    """A fixed key ID for test tokens.

    Used both as the ``kid`` of the test JWKS entry and as the default
    ``kid`` header of tokens built by ``make_token``.
    """
    return "test-kid-12345"
36
+
37
+
38
@pytest.fixture
def jwks_keys(rsa_keypair, kid):
    """Return a one-entry JWKS key list describing the test key's public half."""
    import base64

    numbers = rsa_keypair.public_key().public_numbers()

    def _encode(value: int, size: int) -> str:
        # JWKS fields are unpadded base64url of the big-endian integer bytes.
        raw = value.to_bytes(size, "big")
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

    # Smallest whole number of bytes that holds the modulus.
    modulus_size = (numbers.n.bit_length() + 7) // 8

    jwk = {
        "kty": "RSA",
        "use": "sig",
        "kid": kid,
        "n": _encode(numbers.n, modulus_size),
        "e": _encode(numbers.e, 3),
        "alg": "RS256",
    }
    return [jwk]
66
+
67
+
68
@pytest.fixture
def test_tenant_id():
    """Fixed GUID-shaped tenant ID used as the default tenant for test tokens."""
    return "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
71
+
72
+
73
@pytest.fixture
def test_user_id():
    """Fixed GUID-shaped user object ID used as the default ``oid`` claim."""
    return "11111111-2222-3333-4444-555555555555"
76
+
77
+
78
@pytest.fixture
def test_client_id():
    """Fixed GUID-shaped client ID used for settings and the default ``aud`` claim."""
    return "cccccccc-dddd-eeee-ffff-000000000000"
81
+
82
+
83
@pytest.fixture
def auth_settings(test_client_id):
    """AuthSettings for testing.

    Uses the test client ID and keeps authentication enabled; all other
    fields are left to AuthSettings itself.
    """
    return AuthSettings(
        ENTRA_CLIENT_ID=test_client_id,
        AUTH_DISABLED=False,
    )
90
+
91
+
92
@pytest.fixture
def make_token(rsa_keypair, kid, test_tenant_id, test_user_id, test_client_id):
    """Return a factory that builds RS256-signed JWTs for tests.

    Every factory argument is optional; omitted values fall back to the
    test fixtures (tenant, user, client ID, key ID).
    """

    def _make(
        tenant_id: str | None = None,
        user_id: str | None = None,
        email: str = "user@example.com",
        name: str = "Test User",
        audience: str | None = None,
        issuer: str | None = None,
        expired: bool = False,
        extra_claims: dict | None = None,
        use_kid: str | None = None,
    ) -> str:
        resolved_tenant = tenant_id or test_tenant_id
        issued = int(time.time())
        # Expired tokens ended five minutes ago; valid ones last an hour.
        expiry = issued - 300 if expired else issued + 3600

        payload = {
            "aud": audience or test_client_id,
            "iss": issuer
            or f"https://login.microsoftonline.com/{resolved_tenant}/v2.0",
            "iat": issued - 60,
            "nbf": issued - 60,
            "exp": expiry,
            "tid": resolved_tenant,
            "oid": user_id or test_user_id,
            "preferred_username": email,
            "name": name,
            "sub": str(uuid.uuid4()),
        }
        # Extra claims are applied last so they can override the defaults.
        payload.update(extra_claims or {})

        signing_key = rsa_keypair.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        header = {"kid": use_kid or kid}
        return pyjwt.encode(
            payload, signing_key, algorithm="RS256", headers=header
        )

    return _make
@@ -0,0 +1,117 @@
1
+ """Tests for FastAPI middleware and dependencies."""
2
+
3
+ from unittest.mock import patch
4
+
5
+ import pytest
6
+ from fastapi import Depends, FastAPI
7
+ from fastapi.testclient import TestClient
8
+
9
+ from conduit_auth import AuthSettings, AuthenticatedUser
10
+ from conduit_auth.fastapi import AuthMiddleware, get_current_user
11
+
12
+
13
def _create_app(settings: AuthSettings) -> FastAPI:
    """Build a FastAPI app with AuthMiddleware and three routes.

    Routes: ``/api/protected`` (returns fields of the current user),
    ``/health`` and ``/not-api`` (both fixed JSON payloads).
    """
    app = FastAPI()
    app.add_middleware(AuthMiddleware, settings=settings)

    @app.get("/api/protected")
    async def protected(user: AuthenticatedUser = Depends(get_current_user)):
        exposed = ("tenant_id", "user_id", "email", "display_name")
        return {attr: getattr(user, attr) for attr in exposed}

    @app.get("/health")
    async def health():
        return dict(status="ok")

    @app.get("/not-api")
    async def not_api():
        return dict(message="public")

    return app
36
+
37
+
38
class TestAuthMiddleware:
    """Exercises AuthMiddleware through a FastAPI TestClient."""

    @staticmethod
    def _get_protected_with_keys(settings, token, keys):
        """GET /api/protected with a Bearer token while the key cache is patched."""
        client = TestClient(_create_app(settings))
        with patch(
            "conduit_auth.fastapi.middleware.get_key_cache"
        ) as mock_cache:
            mock_cache.return_value.get_keys_async = (
                lambda *a, **kw: _async_return(keys)
            )
            return client.get(
                "/api/protected",
                headers={"Authorization": f"Bearer {token}"},
            )

    def test_missing_auth_header(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/api/protected")

        assert response.status_code == 401
        assert "Authorization" in response.json()["error"]

    def test_invalid_auth_header(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get(
            "/api/protected", headers={"Authorization": "Basic abc"}
        )

        assert response.status_code == 401

    def test_valid_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_protected_with_keys(
            auth_settings, make_token(), jwks_keys
        )

        assert response.status_code == 200
        body = response.json()
        assert body["email"] == "user@example.com"
        assert body["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_expired_token(self, auth_settings, make_token, jwks_keys):
        response = self._get_protected_with_keys(
            auth_settings, make_token(expired=True), jwks_keys
        )

        assert response.status_code == 401
        assert "expired" in response.json()["error"]

    def test_health_excluded(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/health")

        assert response.status_code == 200

    def test_non_api_path_excluded(self, auth_settings):
        response = TestClient(_create_app(auth_settings)).get("/not-api")

        assert response.status_code == 200

    def test_auth_disabled(self, test_client_id):
        bypass_settings = AuthSettings(
            ENTRA_CLIENT_ID=test_client_id, AUTH_DISABLED=True
        )
        response = TestClient(_create_app(bypass_settings)).get("/api/protected")

        assert response.status_code == 200
        assert response.json()["email"] == "dev@localhost"
        assert response.json()["tenant_id"] == "dev-tenant-00000000"
114
+
115
+
116
async def _async_return(value):
    """Coroutine that returns ``value``; stands in for an awaitable call in mocks."""
    return value
@@ -0,0 +1,123 @@
1
+ """Tests for Flask middleware and decorators."""
2
+
3
+ from unittest.mock import patch
4
+
5
+ import pytest
6
+ from flask import Flask
7
+
8
+ from conduit_auth import AuthSettings
9
+ from conduit_auth.flask import get_current_user, init_auth, require_auth
10
+
11
+
12
def _create_app(settings: AuthSettings) -> Flask:
    """Build a Flask test app with auth installed and four routes.

    Routes: ``/api/protected`` and ``/api/decorated`` (the latter also
    wrapped in ``require_auth``), plus ``/health`` and ``/not-api``.
    """
    app = Flask(__name__)
    app.config.update(TESTING=True)
    init_auth(app, settings)

    @app.route("/api/protected")
    def protected():
        current = get_current_user()
        exposed = ("tenant_id", "user_id", "email", "display_name")
        return {attr: getattr(current, attr) for attr in exposed}

    @app.route("/api/decorated")
    @require_auth
    def decorated():
        return {"email": get_current_user().email}

    @app.route("/health")
    def health():
        return dict(status="ok")

    @app.route("/not-api")
    def not_api():
        return dict(message="public")

    return app
43
+
44
+
45
class TestFlaskMiddleware:
    """Exercises init_auth, get_current_user and require_auth via Flask's test client."""

    @staticmethod
    def _plain_get(settings, path):
        """GET ``path`` on a fresh app without any Authorization header."""
        with _create_app(settings).test_client() as client:
            return client.get(path)

    @staticmethod
    def _bearer_get(settings, path, token, keys):
        """GET ``path`` with a Bearer token while the key cache is patched."""
        app = _create_app(settings)
        with patch("conduit_auth.flask.middleware.get_key_cache") as mock_cache:
            mock_cache.return_value.get_keys.return_value = keys
            with app.test_client() as client:
                return client.get(
                    path,
                    headers={"Authorization": f"Bearer {token}"},
                )

    def test_missing_auth_header(self, auth_settings):
        response = self._plain_get(auth_settings, "/api/protected")

        assert response.status_code == 401
        assert "Authorization" in response.json["error"]

    def test_valid_token(self, auth_settings, make_token, jwks_keys):
        response = self._bearer_get(
            auth_settings, "/api/protected", make_token(), jwks_keys
        )

        assert response.status_code == 200
        assert response.json["email"] == "user@example.com"
        assert response.json["tenant_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    def test_expired_token(self, auth_settings, make_token, jwks_keys):
        response = self._bearer_get(
            auth_settings, "/api/protected", make_token(expired=True), jwks_keys
        )

        assert response.status_code == 401
        assert "expired" in response.json["error"]

    def test_health_excluded(self, auth_settings):
        response = self._plain_get(auth_settings, "/health")

        assert response.status_code == 200

    def test_non_api_path_excluded(self, auth_settings):
        response = self._plain_get(auth_settings, "/not-api")

        assert response.status_code == 200

    def test_auth_disabled(self, test_client_id):
        bypass_settings = AuthSettings(
            ENTRA_CLIENT_ID=test_client_id, AUTH_DISABLED=True
        )
        response = self._plain_get(bypass_settings, "/api/protected")

        assert response.status_code == 200
        assert response.json["email"] == "dev@localhost"

    def test_require_auth_decorator(self, auth_settings, make_token, jwks_keys):
        response = self._bearer_get(
            auth_settings, "/api/decorated", make_token(), jwks_keys
        )

        assert response.status_code == 200
        assert response.json["email"] == "user@example.com"
@@ -0,0 +1,92 @@
1
+ """Tests for JWT token validation."""
2
+
3
+ import pytest
4
+
5
+ from conduit_auth.exceptions import AuthenticationError, TokenExpiredError
6
+ from conduit_auth.token import validate_token
7
+
8
+
9
class TestValidateToken:
    """Unit tests calling validate_token() directly with locally signed tokens."""

    def test_valid_token(self, make_token, jwks_keys, auth_settings):
        result = validate_token(make_token(), jwks_keys, auth_settings)

        assert result.tenant_id == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
        assert result.user_id == "11111111-2222-3333-4444-555555555555"
        assert result.email == "user@example.com"
        assert result.display_name == "Test User"
        assert result.roles == []
        assert "tid" in result.raw_claims

    def test_valid_token_with_roles(self, make_token, jwks_keys, auth_settings):
        signed = make_token(extra_claims={"roles": ["Admin", "Reader"]})
        result = validate_token(signed, jwks_keys, auth_settings)

        assert result.roles == ["Admin", "Reader"]

    def test_expired_token(self, make_token, jwks_keys, auth_settings):
        signed = make_token(expired=True)

        with pytest.raises(TokenExpiredError, match="expired"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_wrong_audience(self, make_token, jwks_keys, auth_settings):
        signed = make_token(audience="wrong-audience")

        with pytest.raises(AuthenticationError, match="audience"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_invalid_issuer(self, make_token, jwks_keys, auth_settings):
        signed = make_token(issuer="https://evil.example.com/v2.0")

        with pytest.raises(AuthenticationError, match="issuer"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_v1_issuer_accepted(self, make_token, jwks_keys, auth_settings):
        tenant = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
        signed = make_token(issuer=f"https://sts.windows.net/{tenant}/")
        result = validate_token(signed, jwks_keys, auth_settings)

        assert result.tenant_id == tenant

    def test_multi_tenant_different_tenants(
        self, make_token, jwks_keys, auth_settings
    ):
        """Tokens from different tenants should all be accepted."""
        tenants = (
            "11111111-1111-1111-1111-111111111111",
            "22222222-2222-2222-2222-222222222222",
            "33333333-3333-3333-3333-333333333333",
        )
        for tenant in tenants:
            result = validate_token(
                make_token(tenant_id=tenant), jwks_keys, auth_settings
            )
            assert result.tenant_id == tenant

    def test_unknown_kid(self, make_token, jwks_keys, auth_settings):
        signed = make_token(use_kid="unknown-kid")

        with pytest.raises(AuthenticationError, match="not found"):
            validate_token(signed, jwks_keys, auth_settings)

    def test_malformed_token(self, jwks_keys, auth_settings):
        with pytest.raises(AuthenticationError, match="Malformed"):
            validate_token("not-a-jwt", jwks_keys, auth_settings)

    def test_custom_audience_setting(
        self, make_token, jwks_keys, test_client_id
    ):
        custom_audience = "api://my-custom-audience"

        # Minimal stand-in exposing only the two attributes set here.
        class Settings:
            ENTRA_CLIENT_ID = test_client_id
            ENTRA_AUDIENCE = custom_audience

        signed = make_token(audience=custom_audience)
        result = validate_token(signed, jwks_keys, Settings())

        assert result.email == "user@example.com"