byoi 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- byoi/__init__.py +233 -0
- byoi/__main__.py +228 -0
- byoi/cache.py +346 -0
- byoi/config.py +349 -0
- byoi/dependencies.py +144 -0
- byoi/errors.py +360 -0
- byoi/models.py +451 -0
- byoi/pkce.py +144 -0
- byoi/providers.py +434 -0
- byoi/py.typed +0 -0
- byoi/repositories.py +252 -0
- byoi/service.py +723 -0
- byoi/telemetry.py +352 -0
- byoi/tokens.py +340 -0
- byoi/types.py +130 -0
- byoi-0.1.0a1.dist-info/METADATA +504 -0
- byoi-0.1.0a1.dist-info/RECORD +20 -0
- byoi-0.1.0a1.dist-info/WHEEL +4 -0
- byoi-0.1.0a1.dist-info/entry_points.txt +3 -0
- byoi-0.1.0a1.dist-info/licenses/LICENSE +21 -0
byoi/telemetry.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""OpenTelemetry tracing support for BYOI.
|
|
2
|
+
|
|
3
|
+
This module provides optional OpenTelemetry instrumentation for BYOI operations.
|
|
4
|
+
Tracing can help with debugging, monitoring, and understanding the performance
|
|
5
|
+
of authentication flows.
|
|
6
|
+
|
|
7
|
+
Installation:
|
|
8
|
+
pip install byoi[telemetry]
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
from byoi.telemetry import configure_tracing, get_tracer
|
|
12
|
+
|
|
13
|
+
# Configure tracing (call once at startup)
|
|
14
|
+
configure_tracing(service_name="my-auth-service")
|
|
15
|
+
|
|
16
|
+
# Or use with custom TracerProvider
|
|
17
|
+
from opentelemetry.sdk.trace import TracerProvider
|
|
18
|
+
provider = TracerProvider()
|
|
19
|
+
configure_tracing(tracer_provider=provider)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from contextlib import contextmanager
|
|
23
|
+
from functools import wraps
|
|
24
|
+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
|
25
|
+
|
|
26
|
+
# Type variable for generic function decoration: lets ``traced`` and
# ``traced_async`` be annotated as returning the same callable type they wrap.
F = TypeVar("F", bound=Callable[..., Any])

# Flag to track if OpenTelemetry is available. Set to True only when every
# import in the ``try`` block below succeeds; every public helper in this
# module checks it before touching OpenTelemetry objects.
_OTEL_AVAILABLE = False
# Module-level tracer cache, assigned by ``configure_tracing`` and lazily by
# ``get_tracer``.
_tracer = None

try:
    # All of these imports must succeed for tracing to be considered available.
    # NOTE(review): besides the core ``opentelemetry`` API this also requires
    # ``opentelemetry.sdk`` and ``opentelemetry.semconv`` (presumably shipped as
    # separate distributions), so if either is missing tracing is disabled
    # entirely even though emitting spans only needs the API — confirm intended.
    from opentelemetry import trace
    from opentelemetry.trace import Span, SpanKind, Status, StatusCode, Tracer
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.semconv.resource import ResourceAttributes

    _OTEL_AVAILABLE = True
except ImportError:
    # OpenTelemetry not installed - bind placeholder names. Each name resolves
    # to None instead of raising NameError; the code paths that actually use
    # these objects are guarded by ``_OTEL_AVAILABLE``.
    trace = None  # type: ignore[assignment]
    Span = None  # type: ignore[assignment, misc]
    SpanKind = None  # type: ignore[assignment, misc]
    Status = None  # type: ignore[assignment, misc]
    StatusCode = None  # type: ignore[assignment, misc]
    Tracer = None  # type: ignore[assignment, misc]
    TracerProvider = None  # type: ignore[assignment, misc]
    Resource = None  # type: ignore[assignment, misc]
    ResourceAttributes = None  # type: ignore[assignment, misc]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
__all__ = (
|
|
55
|
+
"is_tracing_available",
|
|
56
|
+
"configure_tracing",
|
|
57
|
+
"get_tracer",
|
|
58
|
+
"traced",
|
|
59
|
+
"traced_async",
|
|
60
|
+
"span_context",
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# BYOI-specific span names
|
|
65
|
+
SPAN_CREATE_AUTH_URL = "byoi.create_authorization_url"
|
|
66
|
+
SPAN_EXCHANGE_CODE = "byoi.exchange_code"
|
|
67
|
+
SPAN_REFRESH_TOKEN = "byoi.refresh_token"
|
|
68
|
+
SPAN_AUTHENTICATE = "byoi.authenticate"
|
|
69
|
+
SPAN_LINK_IDENTITY = "byoi.link_identity"
|
|
70
|
+
SPAN_UNLINK_IDENTITY = "byoi.unlink_identity"
|
|
71
|
+
SPAN_VALIDATE_TOKEN = "byoi.validate_token"
|
|
72
|
+
SPAN_DISCOVER_PROVIDER = "byoi.discover_provider"
|
|
73
|
+
SPAN_FETCH_JWKS = "byoi.fetch_jwks"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def is_tracing_available() -> bool:
    """Report whether the optional OpenTelemetry packages could be imported.

    This only reflects whether the imports performed when this module was
    loaded succeeded. It does not indicate whether ``configure_tracing`` has
    been called or whether any exporter is set up.

    Returns:
        True if the OpenTelemetry imports succeeded, False otherwise.
    """
    return _OTEL_AVAILABLE
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def configure_tracing(
    service_name: str = "byoi",
    tracer_provider: "TracerProvider | None" = None,
    **resource_attributes: str,
) -> "Tracer | None":
    """Configure OpenTelemetry tracing for BYOI.

    This function should be called once during application startup to
    enable tracing for BYOI operations.

    Args:
        service_name: Value of the ``service.name`` resource attribute used
            when a new TracerProvider is created. Ignored when
            ``tracer_provider`` is supplied.
        tracer_provider: Optional custom TracerProvider. BYOI spans are
            emitted through this provider. If not provided, a new one is
            created and registered as the global TracerProvider.
        **resource_attributes: Additional resource attributes for a newly
            created provider. Ignored when ``tracer_provider`` is supplied.

    Returns:
        The configured Tracer, or None if OpenTelemetry is not available.

    Note:
        This function attaches no span processor or exporter to a provider
        it creates; add one (e.g. via ``add_span_processor``) for spans to be
        exported anywhere.

    Example:
        ```python
        from byoi.telemetry import configure_tracing

        # Simple configuration
        configure_tracing(service_name="my-auth-service")

        # With custom attributes
        configure_tracing(
            service_name="my-auth-service",
            environment="production",
            version="1.0.0",
        )
        ```
    """
    global _tracer

    if not _OTEL_AVAILABLE:
        return None

    if tracer_provider is None:
        # Create a new TracerProvider tagged with the service resource.
        attributes = {
            ResourceAttributes.SERVICE_NAME: service_name,
            **resource_attributes,
        }
        tracer_provider = TracerProvider(resource=Resource.create(attributes))
        # Register globally so other instrumentation shares this provider.
        trace.set_tracer_provider(tracer_provider)

    # Resolve the tracer from the chosen provider explicitly. Going through
    # the global provider would ignore a caller-supplied provider (it is never
    # registered globally). Because OpenTelemetry allows the global provider
    # to be set only once, it could also bypass the provider created above.
    _tracer = trace.get_tracer(
        instrumenting_module_name="byoi",
        instrumenting_library_version="0.1.0a1",
        tracer_provider=tracer_provider,
    )

    return _tracer
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_tracer() -> "Tracer | None":
    """Return the BYOI tracer, creating it on first use if necessary.

    If ``configure_tracing`` has not stored a tracer yet, one is requested
    from the globally registered TracerProvider and cached for later calls.

    Returns:
        The cached Tracer, or None when OpenTelemetry is not installed.
    """
    global _tracer

    if _OTEL_AVAILABLE and _tracer is None:
        # Nothing cached yet: fall back to the global TracerProvider.
        _tracer = trace.get_tracer(
            instrumenting_module_name="byoi",
            instrumenting_library_version="0.1.0a1",
        )

    return _tracer if _OTEL_AVAILABLE else None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@contextmanager
def span_context(
    name: str,
    kind: "SpanKind | None" = None,
    attributes: dict[str, Any] | None = None,
):
    """Run a ``with`` block inside a span, degrading to a no-op without OTel.

    When a BYOI tracer is available, a span is started and made current for
    the duration of the block. Otherwise a ``_NoOpSpan`` is yielded, so the
    calling code does not need to branch on tracing availability.

    Args:
        name: The name of the span.
        kind: The kind of span (CLIENT, SERVER, INTERNAL, etc.). INTERNAL is
            used when no kind is given.
        attributes: Optional attributes set on the span when it starts.

    Yields:
        The active Span, or a ``_NoOpSpan`` if tracing is not available.

    Example:
        ```python
        from byoi.telemetry import span_context

        with span_context("my_operation", attributes={"user_id": "123"}) as span:
            # Do work
            span.set_attribute("result", "success")
        ```
    """
    tracer = get_tracer() if _OTEL_AVAILABLE else None
    if tracer is None:
        # No tracing: hand back an object with the same surface area.
        yield _NoOpSpan()
        return

    with tracer.start_as_current_span(
        name,
        kind=kind if kind else SpanKind.INTERNAL,
        attributes=attributes,
    ) as active_span:
        yield active_span
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class _NoOpSpan:
|
|
210
|
+
"""A no-op span for when tracing is not available."""
|
|
211
|
+
|
|
212
|
+
def set_attribute(self, key: str, value: Any) -> None:
|
|
213
|
+
"""No-op set_attribute."""
|
|
214
|
+
pass
|
|
215
|
+
|
|
216
|
+
def set_attributes(self, attributes: dict[str, Any]) -> None:
|
|
217
|
+
"""No-op set_attributes."""
|
|
218
|
+
pass
|
|
219
|
+
|
|
220
|
+
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
|
221
|
+
"""No-op add_event."""
|
|
222
|
+
pass
|
|
223
|
+
|
|
224
|
+
def record_exception(self, exception: BaseException) -> None:
|
|
225
|
+
"""No-op record_exception."""
|
|
226
|
+
pass
|
|
227
|
+
|
|
228
|
+
def set_status(self, status: Any) -> None:
|
|
229
|
+
"""No-op set_status."""
|
|
230
|
+
pass
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def traced(
    span_name: str | None = None,
    kind: "SpanKind | None" = None,
    record_exception: bool = True,
) -> Callable[[F], F]:
    """Decorator to trace a synchronous function.

    Each call runs inside a span. If the call raises an ``Exception``:

    - the span status is set to ERROR;
    - the exception is recorded as a span event exactly once, but only when
      ``record_exception`` is true;
    - the exception is re-raised.

    When OpenTelemetry is not installed, the function is returned undecorated.

    Args:
        span_name: The name of the span. Defaults to the function name.
        kind: The kind of span. Defaults to INTERNAL.
        record_exception: Whether to record exceptions in the span.

    Returns:
        A decorator function.

    Example:
        ```python
        from byoi.telemetry import traced

        @traced("my_operation")
        def my_function():
            # Do work
            pass
        ```
    """

    def decorator(func: F) -> F:
        if not _OTEL_AVAILABLE:
            return func

        # Both values are fixed per decorated function; compute them once.
        name = span_name or func.__name__
        span_kind = kind or SpanKind.INTERNAL

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            tracer = get_tracer()
            if tracer is None:
                return func(*args, **kwargs)
            # By default OpenTelemetry's span context manager records the
            # exception and sets the status itself. That double-records it
            # and makes ``record_exception=False`` ineffective, so disable it
            # and handle the failure exactly once below.
            with tracer.start_as_current_span(
                name,
                kind=span_kind,
                record_exception=False,
                set_status_on_exception=False,
            ) as span:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if record_exception:
                        span.record_exception(e)
                    span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise

        return wrapper  # type: ignore[return-value]

    return decorator
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def traced_async(
    span_name: str | None = None,
    kind: "SpanKind | None" = None,
    record_exception: bool = True,
) -> Callable[[F], F]:
    """Decorator to trace an asynchronous function.

    Each awaited call runs inside a span. If it raises an ``Exception``:

    - the span status is set to ERROR;
    - the exception is recorded as a span event exactly once, but only when
      ``record_exception`` is true;
    - the exception is re-raised.

    When OpenTelemetry is not installed, the function is returned undecorated.

    Args:
        span_name: The name of the span. Defaults to the function name.
        kind: The kind of span. Defaults to INTERNAL.
        record_exception: Whether to record exceptions in the span.

    Returns:
        A decorator function.

    Example:
        ```python
        from byoi.telemetry import traced_async

        @traced_async("my_async_operation")
        async def my_async_function():
            # Do async work
            pass
        ```
    """

    def decorator(func: F) -> F:
        if not _OTEL_AVAILABLE:
            return func

        # Both values are fixed per decorated function; compute them once.
        name = span_name or func.__name__
        span_kind = kind or SpanKind.INTERNAL

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            tracer = get_tracer()
            if tracer is None:
                return await func(*args, **kwargs)
            # By default OpenTelemetry's span context manager records the
            # exception and sets the status itself. That double-records it
            # and makes ``record_exception=False`` ineffective, so disable it
            # and handle the failure exactly once below.
            with tracer.start_as_current_span(
                name,
                kind=span_kind,
                record_exception=False,
                set_status_on_exception=False,
            ) as span:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if record_exception:
                        span.record_exception(e)
                    span.set_status(Status(StatusCode.ERROR, str(e)))
                    raise

        return wrapper  # type: ignore[return-value]

    return decorator
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def add_auth_attributes(
|
|
330
|
+
span: "Span | _NoOpSpan",
|
|
331
|
+
provider_name: str | None = None,
|
|
332
|
+
user_id: str | None = None,
|
|
333
|
+
identity_id: str | None = None,
|
|
334
|
+
is_new_user: bool | None = None,
|
|
335
|
+
) -> None:
|
|
336
|
+
"""Add common authentication attributes to a span.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
span: The span to add attributes to.
|
|
340
|
+
provider_name: The name of the OIDC provider.
|
|
341
|
+
user_id: The user ID.
|
|
342
|
+
identity_id: The linked identity ID.
|
|
343
|
+
is_new_user: Whether this is a new user.
|
|
344
|
+
"""
|
|
345
|
+
if provider_name is not None:
|
|
346
|
+
span.set_attribute("byoi.provider_name", provider_name)
|
|
347
|
+
if user_id is not None:
|
|
348
|
+
span.set_attribute("byoi.user_id", user_id)
|
|
349
|
+
if identity_id is not None:
|
|
350
|
+
span.set_attribute("byoi.identity_id", identity_id)
|
|
351
|
+
if is_new_user is not None:
|
|
352
|
+
span.set_attribute("byoi.is_new_user", is_new_user)
|
byoi/tokens.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""ID Token validation utilities.
|
|
2
|
+
|
|
3
|
+
Handles JWT validation including signature verification, claims validation,
|
|
4
|
+
and nonce checking.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import hashlib
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from jose import JWTError, jwt
|
|
13
|
+
from jose.exceptions import ExpiredSignatureError, JWTClaimsError
|
|
14
|
+
|
|
15
|
+
from byoi.errors import (
|
|
16
|
+
InvalidAudienceError,
|
|
17
|
+
InvalidIssuerError,
|
|
18
|
+
InvalidNonceError,
|
|
19
|
+
TokenExpiredError,
|
|
20
|
+
TokenValidationError,
|
|
21
|
+
)
|
|
22
|
+
from byoi.models import IdentityInfo, OIDCProviderConfig
|
|
23
|
+
from byoi.providers import ProviderManager
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger("byoi.tokens")
|
|
26
|
+
|
|
27
|
+
__all__ = (
|
|
28
|
+
"TokenValidator",
|
|
29
|
+
"validate_id_token",
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TokenValidator:
    """Validates ID tokens from OIDC providers.

    The JWT signature is verified against the provider's JWKS, which is
    obtained through the ``ProviderManager``. The ``iss``, ``aud``, ``exp``
    and ``iat`` claims are always checked; ``nonce`` and ``at_hash`` are
    checked on request.
    """

    # Standard OIDC claims to extract; every other claim is passed through
    # in ``IdentityInfo.extra_claims``.
    STANDARD_CLAIMS = {
        "sub",
        "email",
        "email_verified",
        "name",
        "given_name",
        "family_name",
        "picture",
        "locale",
        "iss",
        "aud",
        "exp",
        "iat",
        "nonce",
        "at_hash",
        "c_hash",
        "auth_time",
        "acr",
        "amr",
        "azp",
    }

    # JWS algorithms accepted for ID token signatures. The ``alg`` value comes
    # from the *unverified* token header, so it is checked against a fixed
    # allowlist rather than handed straight to the JWT library. Otherwise the
    # token would choose how it gets verified. Only asymmetric algorithms are
    # listed because verification keys come from the provider's JWKS, so
    # "none" and HS* tokens are rejected up front.
    ALLOWED_ALGORITHMS = frozenset(
        {"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
    )

    # JWK key type ("kty") able to verify each allowed algorithm family.
    _KEY_TYPE_BY_ALG_PREFIX = {"RS": "RSA", "ES": "EC"}

    def __init__(self, provider_manager: ProviderManager) -> None:
        """Initialize the token validator.

        Args:
            provider_manager: The provider manager used to look up provider
                configuration and fetch JWKS.
        """
        self._provider_manager = provider_manager

    async def validate(
        self,
        id_token: str,
        provider_name: str,
        expected_nonce: str | None = None,
        verify_at_hash: bool = False,
        access_token: str | None = None,
    ) -> IdentityInfo:
        """Validate an ID token and extract identity information.

        Args:
            id_token: The raw ID token (JWT).
            provider_name: The provider name.
            expected_nonce: The expected nonce value (for replay protection).
                When None, the nonce claim is not checked.
            verify_at_hash: Whether to verify the at_hash claim.
            access_token: The access token (required if verify_at_hash is
                True).

        Returns:
            Extracted identity information.

        Raises:
            TokenValidationError: If validation fails, including when
                ``verify_at_hash`` is True but no ``access_token`` is given.
            ProviderNotFoundError: If the provider is not registered.
        """
        logger.debug("Validating ID token for provider=%s", provider_name)

        config = self._provider_manager.get_provider(provider_name)
        jwks = await self._provider_manager.get_jwks(provider_name)

        # Decode and validate the token, with JWKS refresh retry on key not found
        try:
            claims = await self._decode_and_validate(id_token, config, jwks)
        except TokenValidationError as e:
            # A key-related failure may mean the provider rotated its signing
            # keys after the JWKS was cached, so retry once with a fresh copy.
            error_msg = str(e).lower()
            if "key" in error_msg or "kid" in error_msg or "signature" in error_msg:
                logger.info(
                    "Token validation failed, retrying with fresh JWKS for provider=%s",
                    provider_name,
                )
                # Fetch fresh JWKS bypassing cache
                jwks = await self._provider_manager.get_jwks(provider_name, force_refresh=True)
                claims = await self._decode_and_validate(id_token, config, jwks)
            else:
                raise

        # Validate nonce if provided
        if expected_nonce is not None:
            token_nonce = claims.get("nonce")
            if token_nonce != expected_nonce:
                logger.warning(
                    "Nonce mismatch for provider=%s: expected=%s, got=%s",
                    provider_name,
                    expected_nonce,
                    token_nonce,
                )
                raise InvalidNonceError()

        if verify_at_hash:
            # The caller explicitly asked for at_hash verification; it cannot
            # be performed without the access token, so fail rather than
            # silently skip the check.
            if not access_token:
                logger.warning(
                    "at_hash verification requested without an access token for provider=%s",
                    provider_name,
                )
                raise TokenValidationError(
                    "access_token is required to verify the 'at_hash' claim"
                )
            await self._verify_at_hash(claims, access_token, id_token)

        identity = self._extract_identity(claims, provider_name)

        logger.debug(
            "ID token validated successfully: provider=%s, subject=%s",
            provider_name,
            identity.subject,
        )

        return identity

    async def _decode_and_validate(
        self,
        id_token: str,
        config: OIDCProviderConfig,
        jwks: dict[str, Any],
    ) -> dict[str, Any]:
        """Decode the JWT, verify its signature and check standard claims.

        Args:
            id_token: The raw ID token.
            config: The provider configuration.
            jwks: The provider's JWKS.

        Returns:
            The validated claims.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidAudienceError: If the audience does not match.
            InvalidIssuerError: If the issuer does not match.
            TokenValidationError: If any other validation step fails.
        """
        # Determine expected audience
        audience = config.audience or config.client_id

        try:
            header = jwt.get_unverified_header(id_token)
        except JWTError as e:
            logger.error("JWT validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token validation failed: {e}") from e

        algorithm = header.get("alg")
        # isinstance guard: an unhashable header value would make the
        # frozenset membership test raise TypeError.
        if not isinstance(algorithm, str) or algorithm not in self.ALLOWED_ALGORITHMS:
            logger.warning(
                "Rejected ID token with unsupported alg=%r for provider=%s",
                algorithm,
                config.name,
            )
            raise TokenValidationError(f"Unsupported ID token signing algorithm: {algorithm!r}")

        signing_keys = self._select_signing_keys(jwks, header.get("kid"), algorithm)

        try:
            claims = jwt.decode(
                id_token,
                signing_keys,
                algorithms=[algorithm],
                audience=audience,
                issuer=config.issuer,
                options={
                    "verify_aud": True,
                    "verify_iss": True,
                    "verify_exp": True,
                    "verify_iat": True,
                    "verify_nbf": False,  # Not all providers include nbf
                    "leeway": config.clock_skew_seconds,
                },
            )
        except ExpiredSignatureError as e:
            logger.warning("Token expired for provider=%s", config.name)
            raise TokenExpiredError() from e
        except JWTClaimsError as e:
            error_msg = str(e).lower()
            if "audience" in error_msg:
                # Best-effort: report the token's actual audience.
                try:
                    unverified_claims = jwt.get_unverified_claims(id_token)
                    actual_aud = unverified_claims.get("aud", "unknown")
                except Exception:
                    actual_aud = "unknown"
                logger.warning(
                    "Audience mismatch for provider=%s: expected=%s, got=%s",
                    config.name,
                    audience,
                    actual_aud,
                )
                raise InvalidAudienceError(audience, actual_aud) from e
            if "issuer" in error_msg:
                try:
                    unverified_claims = jwt.get_unverified_claims(id_token)
                    actual_iss = unverified_claims.get("iss", "unknown")
                except Exception:
                    actual_iss = "unknown"
                logger.warning(
                    "Issuer mismatch for provider=%s: expected=%s, got=%s",
                    config.name,
                    config.issuer,
                    actual_iss,
                )
                raise InvalidIssuerError(config.issuer, actual_iss) from e
            logger.warning("Token claims validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token claims validation failed: {e}") from e
        except JWTError as e:
            logger.error("JWT validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token validation failed: {e}") from e

        # Validate required claims
        if "sub" not in claims:
            logger.error("Token missing required 'sub' claim for provider=%s", config.name)
            raise TokenValidationError("Token missing required 'sub' claim")

        return claims

    @classmethod
    def _select_signing_keys(
        cls,
        jwks: dict[str, Any],
        kid: str | None,
        algorithm: str,
    ) -> dict[str, Any]:
        """Narrow a JWK Set to the keys that could have signed the token.

        A key is kept only if all of the following hold:

        - its ``kty`` fits ``algorithm``;
        - its ``alg``, if present, equals ``algorithm``;
        - its ``use``, if present, is ``sig``;
        - its ``kid`` matches the token's ``kid`` when both are present.

        Verification is therefore never attempted with keys of the wrong type
        or purpose.

        Args:
            jwks: The provider's JWKS document.
            kid: The ``kid`` from the token header, if any.
            algorithm: The (already allowlisted) JWS algorithm.

        Returns:
            A JWK Set containing only the candidate keys, or ``jwks``
            unchanged when it is not a ``{"keys": [...]}`` document.

        Raises:
            TokenValidationError: If no key matches. The message mentions the
                key, so ``validate`` retries with a freshly fetched JWKS; this
                covers provider key rotation.
        """
        keys = jwks.get("keys") if isinstance(jwks, dict) else None
        if not isinstance(keys, list):
            # Not a standard JWK Set; leave interpretation to the JWT library.
            return jwks

        expected_kty = cls._KEY_TYPE_BY_ALG_PREFIX.get(algorithm[:2])
        candidates = [
            key
            for key in keys
            if isinstance(key, dict)
            and key.get("kty") == expected_kty
            and key.get("alg") in (None, algorithm)
            and key.get("use") in (None, "sig")
            and (kid is None or key.get("kid") in (None, kid))
        ]
        if not candidates:
            raise TokenValidationError(f"No matching signing key found in JWKS for kid={kid!r}")
        return {"keys": candidates}

    async def _verify_at_hash(
        self,
        claims: dict[str, Any],
        access_token: str,
        id_token: str,
    ) -> None:
        """Verify the at_hash claim matches the access token.

        Args:
            claims: The ID token claims.
            access_token: The access token.
            id_token: The ID token (to determine algorithm).

        Raises:
            TokenValidationError: If the claim is missing, cannot be
                computed, or does not match.
        """
        at_hash = claims.get("at_hash")
        if at_hash is None:
            # The caller explicitly requested at_hash verification, so a
            # missing claim is a failure rather than something to skip.
            raise TokenValidationError("Token missing 'at_hash' claim")

        # Hash with the SHA-2 variant matching the JWS algorithm's bit size
        # (e.g. RS384 -> sha384). Take the left-most half of the digest and
        # base64url-encode it without padding. Suffixes that are not plain
        # 256/384/512 (e.g. "ES256K") or unknown names fall back to sha256.
        algorithm = jwt.get_unverified_header(id_token).get("alg", "RS256")
        bits = algorithm[2:]
        hash_alg = f"sha{bits}" if bits in ("256", "384", "512") else "sha256"

        try:
            digest = hashlib.new(hash_alg, access_token.encode("ascii")).digest()
            expected_at_hash = (
                base64.urlsafe_b64encode(digest[: len(digest) // 2])
                .rstrip(b"=")
                .decode("ascii")
            )
        except Exception as e:
            raise TokenValidationError(f"Failed to calculate at_hash: {e}") from e

        if at_hash != expected_at_hash:
            raise TokenValidationError("at_hash claim does not match access token")

    def _extract_identity(
        self,
        claims: dict[str, Any],
        provider_name: str,
    ) -> IdentityInfo:
        """Extract identity information from claims.

        Args:
            claims: The validated token claims.
            provider_name: The provider name.

        Returns:
            The extracted identity information; claims not listed in
            ``STANDARD_CLAIMS`` are placed in ``extra_claims``.
        """
        extra_claims = {
            k: v for k, v in claims.items() if k not in self.STANDARD_CLAIMS
        }

        return IdentityInfo(
            provider_name=provider_name,
            subject=claims["sub"],
            email=claims.get("email"),
            email_verified=claims.get("email_verified", False),
            name=claims.get("name"),
            given_name=claims.get("given_name"),
            family_name=claims.get("family_name"),
            picture=claims.get("picture"),
            locale=claims.get("locale"),
            extra_claims=extra_claims,
        )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
async def validate_id_token(
    id_token: str,
    provider_manager: ProviderManager,
    provider_name: str,
    expected_nonce: str | None = None,
    verify_at_hash: bool = False,
    access_token: str | None = None,
) -> IdentityInfo:
    """Validate an ID token without keeping a ``TokenValidator`` around.

    A throwaway ``TokenValidator`` is built around ``provider_manager`` and
    the work is delegated to ``TokenValidator.validate``.

    Args:
        id_token: The raw ID token (JWT).
        provider_manager: The provider manager.
        provider_name: The provider name.
        expected_nonce: The expected nonce value.
        verify_at_hash: Whether to verify the at_hash claim.
        access_token: The access token (required if verify_at_hash is True).

    Returns:
        Extracted identity information.
    """
    return await TokenValidator(provider_manager).validate(
        id_token,
        provider_name,
        expected_nonce=expected_nonce,
        verify_at_hash=verify_at_hash,
        access_token=access_token,
    )
|