byoi 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
byoi/telemetry.py ADDED
@@ -0,0 +1,352 @@
1
+ """OpenTelemetry tracing support for BYOI.
2
+
3
+ This module provides optional OpenTelemetry instrumentation for BYOI operations.
4
+ Tracing can help with debugging, monitoring, and understanding the performance
5
+ of authentication flows.
6
+
7
+ Installation:
8
+ pip install byoi[telemetry]
9
+
10
+ Usage:
11
+ from byoi.telemetry import configure_tracing, get_tracer
12
+
13
+ # Configure tracing (call once at startup)
14
+ configure_tracing(service_name="my-auth-service")
15
+
16
+ # Or use with custom TracerProvider
17
+ from opentelemetry.sdk.trace import TracerProvider
18
+ provider = TracerProvider()
19
+ configure_tracing(tracer_provider=provider)
20
+ """
21
+
22
+ from contextlib import contextmanager
23
+ from functools import wraps
24
+ from typing import TYPE_CHECKING, Any, Callable, TypeVar
25
+
26
# TypeVar used so the tracing decorators keep the decorated callable's type.
F = TypeVar("F", bound=Callable[..., Any])

# BYOI tracer cache; populated lazily by configure_tracing() / get_tracer().
_tracer = None

# OpenTelemetry is an optional extra (``pip install byoi[telemetry]``).
# Import it when present; otherwise bind every name to None so the rest of
# the module can branch on ``_OTEL_AVAILABLE`` instead of failing on import.
try:
    from opentelemetry import trace
    from opentelemetry.trace import Span, SpanKind, Status, StatusCode, Tracer
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.semconv.resource import ResourceAttributes
except ImportError:
    trace = None  # type: ignore[assignment]
    Span = SpanKind = Status = StatusCode = Tracer = None  # type: ignore[assignment, misc]
    TracerProvider = Resource = ResourceAttributes = None  # type: ignore[assignment, misc]
    _OTEL_AVAILABLE = False
else:
    _OTEL_AVAILABLE = True
52
+
53
+
54
# Public API of this module. ``add_auth_attributes`` is a public helper meant
# to be used together with ``span_context`` and is therefore exported too.
__all__ = (
    "is_tracing_available",
    "configure_tracing",
    "get_tracer",
    "traced",
    "traced_async",
    "span_context",
    "add_auth_attributes",
)
62
+
63
+
64
# Span names for BYOI operations; every name lives under the "byoi." namespace.
_SPAN_PREFIX = "byoi."

SPAN_CREATE_AUTH_URL = f"{_SPAN_PREFIX}create_authorization_url"
SPAN_EXCHANGE_CODE = f"{_SPAN_PREFIX}exchange_code"
SPAN_REFRESH_TOKEN = f"{_SPAN_PREFIX}refresh_token"
SPAN_AUTHENTICATE = f"{_SPAN_PREFIX}authenticate"
SPAN_LINK_IDENTITY = f"{_SPAN_PREFIX}link_identity"
SPAN_UNLINK_IDENTITY = f"{_SPAN_PREFIX}unlink_identity"
SPAN_VALIDATE_TOKEN = f"{_SPAN_PREFIX}validate_token"
SPAN_DISCOVER_PROVIDER = f"{_SPAN_PREFIX}discover_provider"
SPAN_FETCH_JWKS = f"{_SPAN_PREFIX}fetch_jwks"
74
+
75
+
76
def is_tracing_available() -> bool:
    """Report whether the optional OpenTelemetry packages could be imported.

    This reflects only whether the OpenTelemetry imports succeeded when this
    module was loaded. It does NOT indicate whether ``configure_tracing`` has
    been called or whether any exporter is set up.

    Returns:
        True if the OpenTelemetry imports succeeded, False otherwise.
    """
    return _OTEL_AVAILABLE
83
+
84
+
85
def configure_tracing(
    service_name: str = "byoi",
    tracer_provider: "TracerProvider | None" = None,
    **resource_attributes: str,
) -> "Tracer | None":
    """Configure OpenTelemetry tracing for BYOI.

    Call this once during application startup to enable tracing for BYOI
    operations.

    When ``tracer_provider`` is omitted, a new SDK ``TracerProvider`` is
    created with a resource built from ``service_name`` and
    ``resource_attributes`` and installed as the global tracer provider. No
    span processor is attached here, so spans are not exported until the
    application adds one to that provider.

    When ``tracer_provider`` is given, BYOI's tracer is created from that
    provider. ``service_name`` and ``resource_attributes`` are ignored in that
    case, and this function does not install the provider globally.

    Args:
        service_name: The name of the service for tracing (used only when a
            provider is created here).
        tracer_provider: Optional custom TracerProvider. If not provided,
            a new one will be created.
        **resource_attributes: Additional resource attributes to include
            (used only when a provider is created here).

    Returns:
        The configured Tracer, or None if OpenTelemetry is not available.

    Example:
        ```python
        from byoi.telemetry import configure_tracing

        # Simple configuration
        configure_tracing(service_name="my-auth-service")

        # With custom attributes
        configure_tracing(
            service_name="my-auth-service",
            environment="production",
            version="1.0.0",
        )
        ```
    """
    global _tracer

    if not _OTEL_AVAILABLE:
        return None

    if tracer_provider is None:
        # Create a new TracerProvider with resource attributes.
        attributes = {
            ResourceAttributes.SERVICE_NAME: service_name,
            **resource_attributes,
        }
        tracer_provider = TracerProvider(resource=Resource.create(attributes))
        trace.set_tracer_provider(tracer_provider)

    # Bind BYOI's tracer to the configured provider explicitly. Asking the
    # global provider instead would ignore a caller-supplied provider, and
    # would also ignore our own provider if another global provider had
    # already been installed.
    _tracer = trace.get_tracer(
        instrumenting_module_name="byoi",
        instrumenting_library_version="0.1.0a1",
        tracer_provider=tracer_provider,
    )

    return _tracer
140
+
141
+
142
def get_tracer() -> "Tracer | None":
    """Return BYOI's tracer, creating and caching it on first use.

    If ``configure_tracing`` has not populated the cache yet, the tracer is
    obtained from OpenTelemetry's global tracer provider and cached for
    subsequent calls.

    Returns:
        The cached BYOI Tracer, or None when OpenTelemetry is not installed.
    """
    global _tracer

    if _OTEL_AVAILABLE and _tracer is None:
        # Fall back to whichever global TracerProvider is installed.
        _tracer = trace.get_tracer(
            instrumenting_module_name="byoi",
            instrumenting_library_version="0.1.0a1",
        )

    return _tracer if _OTEL_AVAILABLE else None
161
+
162
+
163
@contextmanager
def span_context(
    name: str,
    kind: "SpanKind | None" = None,
    attributes: dict[str, Any] | None = None,
    *,
    record_exception: bool = True,
    set_status_on_exception: bool = True,
):
    """Context manager for creating spans.

    This is a convenience wrapper that handles the case where tracing
    is not available by yielding a no-op span.

    Args:
        name: The name of the span.
        kind: The kind of span (CLIENT, SERVER, INTERNAL, etc.). Defaults to
            INTERNAL.
        attributes: Optional attributes to add to the span.
        record_exception: Whether OpenTelemetry should record an exception
            that escapes the ``with`` block on the span (default True).
            Pass False if you record exceptions yourself.
        set_status_on_exception: Whether OpenTelemetry should set the span
            status to ERROR when an exception escapes the block (default True).

    Yields:
        The created Span, or a ``_NoOpSpan`` if tracing is not available.

    Example:
        ```python
        from byoi.telemetry import span_context

        with span_context("my_operation", attributes={"user_id": "123"}) as span:
            # Do work
            span.set_attribute("result", "success")
        ```
    """
    tracer = get_tracer() if _OTEL_AVAILABLE else None
    if tracer is None:
        yield _NoOpSpan()
        return

    with tracer.start_as_current_span(
        name,
        kind=kind or SpanKind.INTERNAL,
        attributes=attributes,
        record_exception=record_exception,
        set_status_on_exception=set_status_on_exception,
    ) as span:
        yield span
207
+
208
+
209
+ class _NoOpSpan:
210
+ """A no-op span for when tracing is not available."""
211
+
212
+ def set_attribute(self, key: str, value: Any) -> None:
213
+ """No-op set_attribute."""
214
+ pass
215
+
216
+ def set_attributes(self, attributes: dict[str, Any]) -> None:
217
+ """No-op set_attributes."""
218
+ pass
219
+
220
+ def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
221
+ """No-op add_event."""
222
+ pass
223
+
224
+ def record_exception(self, exception: BaseException) -> None:
225
+ """No-op record_exception."""
226
+ pass
227
+
228
+ def set_status(self, status: Any) -> None:
229
+ """No-op set_status."""
230
+ pass
231
+
232
+
233
@contextmanager
def _function_span(
    name: str,
    kind: "SpanKind | None",
    record_exception: bool,
):
    """Open the span used by ``traced`` / ``traced_async`` and own its error handling.

    OpenTelemetry's automatic exception handling is switched off on this span.
    This helper is therefore the only place an escaping exception is recorded,
    which makes the decorators' ``record_exception`` flag effective and avoids
    recording the same exception twice.

    Args:
        name: The span name.
        kind: The span kind; INTERNAL when None.
        record_exception: Whether to add an exception event when the wrapped
            call raises. The span status is set to ERROR either way.

    Yields:
        The active span, or a ``_NoOpSpan`` if no tracer is available.
    """
    tracer = get_tracer()
    if tracer is None:
        yield _NoOpSpan()
        return

    with tracer.start_as_current_span(
        name,
        kind=kind or SpanKind.INTERNAL,
        record_exception=False,
        set_status_on_exception=False,
    ) as span:
        try:
            yield span
        except BaseException as exc:
            if record_exception:
                span.record_exception(exc)
            span.set_status(Status(StatusCode.ERROR, str(exc)))
            raise


def traced(
    span_name: str | None = None,
    kind: "SpanKind | None" = None,
    record_exception: bool = True,
) -> Callable[[F], F]:
    """Decorator to trace a synchronous function.

    Each call runs inside a span. If the call raises, the span status is set
    to ERROR and, when ``record_exception`` is True, the exception is recorded
    on the span once. The exception is then re-raised unchanged. If
    OpenTelemetry is not installed, the function is returned undecorated.

    Args:
        span_name: The name of the span. Defaults to the function name.
        kind: The kind of span.
        record_exception: Whether to record exceptions in the span.

    Returns:
        A decorator function.

    Example:
        ```python
        from byoi.telemetry import traced

        @traced("my_operation")
        def my_function():
            # Do work
            pass
        ```
    """

    def decorator(func: F) -> F:
        if not _OTEL_AVAILABLE:
            return func

        name = span_name or func.__name__

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            with _function_span(name, kind, record_exception):
                return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator


def traced_async(
    span_name: str | None = None,
    kind: "SpanKind | None" = None,
    record_exception: bool = True,
) -> Callable[[F], F]:
    """Decorator to trace an asynchronous function.

    Each awaited call runs inside a span. If the coroutine raises, the span
    status is set to ERROR and, when ``record_exception`` is True, the
    exception is recorded on the span once. The exception is then re-raised
    unchanged. If OpenTelemetry is not installed, the function is returned
    undecorated.

    Args:
        span_name: The name of the span. Defaults to the function name.
        kind: The kind of span.
        record_exception: Whether to record exceptions in the span.

    Returns:
        A decorator function.

    Example:
        ```python
        from byoi.telemetry import traced_async

        @traced_async("my_async_operation")
        async def my_async_function():
            # Do async work
            pass
        ```
    """

    def decorator(func: F) -> F:
        if not _OTEL_AVAILABLE:
            return func

        name = span_name or func.__name__

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            with _function_span(name, kind, record_exception):
                return await func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
327
+
328
+
329
+ def add_auth_attributes(
330
+ span: "Span | _NoOpSpan",
331
+ provider_name: str | None = None,
332
+ user_id: str | None = None,
333
+ identity_id: str | None = None,
334
+ is_new_user: bool | None = None,
335
+ ) -> None:
336
+ """Add common authentication attributes to a span.
337
+
338
+ Args:
339
+ span: The span to add attributes to.
340
+ provider_name: The name of the OIDC provider.
341
+ user_id: The user ID.
342
+ identity_id: The linked identity ID.
343
+ is_new_user: Whether this is a new user.
344
+ """
345
+ if provider_name is not None:
346
+ span.set_attribute("byoi.provider_name", provider_name)
347
+ if user_id is not None:
348
+ span.set_attribute("byoi.user_id", user_id)
349
+ if identity_id is not None:
350
+ span.set_attribute("byoi.identity_id", identity_id)
351
+ if is_new_user is not None:
352
+ span.set_attribute("byoi.is_new_user", is_new_user)
byoi/tokens.py ADDED
@@ -0,0 +1,340 @@
1
+ """ID Token validation utilities.
2
+
3
+ Handles JWT validation including signature verification, claims validation,
4
+ and nonce checking.
5
+ """
6
+
7
import base64
import hashlib
import logging
from typing import Any

from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError, JOSEError, JWTClaimsError

from byoi.errors import (
    InvalidAudienceError,
    InvalidIssuerError,
    InvalidNonceError,
    TokenExpiredError,
    TokenValidationError,
)
from byoi.models import IdentityInfo, OIDCProviderConfig
from byoi.providers import ProviderManager
24
+
25
# Module logger, named explicitly rather than derived from __name__.
logger = logging.getLogger("byoi.tokens")

# Public API of this module.
__all__ = ("TokenValidator", "validate_id_token")
31
+
32
+
33
class TokenValidator:
    """Validates ID tokens from OIDC providers.

    Signature, issuer, audience and time-based claims are checked with
    python-jose against the provider's JWKS (obtained from the
    ``ProviderManager``). Required-claim presence, the nonce and, on request,
    ``at_hash`` are checked here.
    """

    # Claims mapped onto IdentityInfo fields or otherwise handled by the
    # validator; every other claim is passed through as ``extra_claims``.
    STANDARD_CLAIMS = {
        "sub",
        "email",
        "email_verified",
        "name",
        "given_name",
        "family_name",
        "picture",
        "locale",
        "iss",
        "aud",
        "exp",
        "iat",
        "nonce",
        "at_hash",
        "c_hash",
        "auth_time",
        "acr",
        "amr",
        "azp",
    }

    # Claims every accepted ID token must carry. python-jose only validates
    # ``aud`` and ``exp`` when they are present, so presence is enforced here.
    _REQUIRED_CLAIMS = ("sub", "aud", "exp")

    def __init__(self, provider_manager: ProviderManager) -> None:
        """Initialize the token validator.

        Args:
            provider_manager: The provider manager for fetching JWKS.
        """
        self._provider_manager = provider_manager

    async def validate(
        self,
        id_token: str,
        provider_name: str,
        expected_nonce: str | None = None,
        verify_at_hash: bool = False,
        access_token: str | None = None,
    ) -> IdentityInfo:
        """Validate an ID token and extract identity information.

        Args:
            id_token: The raw ID token (JWT).
            provider_name: The provider name.
            expected_nonce: The expected nonce value (for replay protection).
                When None, the nonce is not checked.
            verify_at_hash: Whether to verify the at_hash claim.
            access_token: The access token (required if verify_at_hash is True).

        Returns:
            Extracted identity information.

        Raises:
            TokenValidationError: If validation fails, including when
                ``verify_at_hash`` is True but no ``access_token`` was given.
            TokenExpiredError: If the token has expired.
            InvalidAudienceError: If the audience does not match.
            InvalidIssuerError: If the issuer does not match.
            InvalidNonceError: If the nonce does not match ``expected_nonce``.
            ProviderNotFoundError: If the provider is not registered.
        """
        logger.debug("Validating ID token for provider=%s", provider_name)

        config = self._provider_manager.get_provider(provider_name)

        # Without an access token at_hash cannot be checked; failing loudly is
        # safer than silently skipping a verification the caller asked for.
        if verify_at_hash and not access_token:
            raise TokenValidationError(
                "access_token is required when verify_at_hash is True"
            )

        jwks = await self._provider_manager.get_jwks(provider_name)

        # Decode and validate the token, retrying once with a fresh JWKS when
        # the failure looks key-related (e.g. the provider rotated its keys).
        try:
            claims = await self._decode_and_validate(id_token, config, jwks)
        except TokenValidationError as e:
            error_msg = str(e).lower()
            if "key" in error_msg or "kid" in error_msg or "signature" in error_msg:
                # NOTE(review): any signature failure, including a forged token,
                # forces a JWKS fetch here; consider rate-limiting this or first
                # checking whether the token's kid is absent from the cached JWKS.
                logger.info(
                    "Token validation failed, retrying with fresh JWKS for provider=%s",
                    provider_name,
                )
                jwks = await self._provider_manager.get_jwks(provider_name, force_refresh=True)
                claims = await self._decode_and_validate(id_token, config, jwks)
            else:
                raise

        # Validate nonce if provided
        if expected_nonce is not None:
            token_nonce = claims.get("nonce")
            if token_nonce != expected_nonce:
                logger.warning(
                    "Nonce mismatch for provider=%s: expected=%s, got=%s",
                    provider_name,
                    expected_nonce,
                    token_nonce,
                )
                raise InvalidNonceError()

        if verify_at_hash:
            await self._verify_at_hash(claims, access_token, id_token)

        identity = self._extract_identity(claims, provider_name)

        logger.debug(
            "ID token validated successfully: provider=%s, subject=%s",
            provider_name,
            identity.subject,
        )

        return identity

    async def _decode_and_validate(
        self,
        id_token: str,
        config: OIDCProviderConfig,
        jwks: dict[str, Any],
    ) -> dict[str, Any]:
        """Decode the JWT, verify its signature and validate its claims.

        Args:
            id_token: The raw ID token.
            config: The provider configuration.
            jwks: The provider's JWKS.

        Returns:
            The validated claims.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidAudienceError: If the audience does not match.
            InvalidIssuerError: If the issuer does not match.
            TokenValidationError: For any other validation failure.
        """
        # Determine expected audience
        audience = config.audience or config.client_id

        try:
            header = jwt.get_unverified_header(id_token)
        except JWTError as e:
            logger.error("JWT validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token validation failed: {e}") from e
        algorithm = header.get("alg", "RS256")

        # The header is attacker-controlled. Only asymmetric algorithms can be
        # verified against a public JWKS, so reject "none", HMAC (HS*) and
        # non-string values before the token reaches the decoder.
        if (
            not isinstance(algorithm, str)
            or algorithm.lower() == "none"
            or algorithm.upper().startswith("HS")
        ):
            logger.warning(
                "Rejected ID token with disallowed alg=%r for provider=%s",
                algorithm,
                config.name,
            )
            raise TokenValidationError(f"Disallowed ID token algorithm: {algorithm!r}")

        try:
            claims = jwt.decode(
                id_token,
                jwks,
                algorithms=[algorithm],
                audience=audience,
                issuer=config.issuer,
                options={
                    "verify_aud": True,
                    "verify_iss": True,
                    "verify_exp": True,
                    "verify_iat": True,
                    # python-jose skips the nbf check when the claim is absent,
                    # so enabling it only rejects tokens that are not yet valid.
                    "verify_nbf": True,
                    # at_hash is checked by _verify_at_hash() when requested.
                    # python-jose's own check (on by default) would reject every
                    # token carrying at_hash because no access_token is passed.
                    "verify_at_hash": False,
                    "leeway": config.clock_skew_seconds,
                },
            )
        except ExpiredSignatureError as e:
            logger.warning("Token expired for provider=%s", config.name)
            raise TokenExpiredError() from e
        except JWTClaimsError as e:
            error_msg = str(e).lower()
            if "audience" in error_msg:
                try:
                    actual_aud = jwt.get_unverified_claims(id_token).get("aud", "unknown")
                except Exception:
                    actual_aud = "unknown"
                logger.warning(
                    "Audience mismatch for provider=%s: expected=%s, got=%s",
                    config.name,
                    audience,
                    actual_aud,
                )
                raise InvalidAudienceError(audience, actual_aud) from e
            if "issuer" in error_msg:
                try:
                    actual_iss = jwt.get_unverified_claims(id_token).get("iss", "unknown")
                except Exception:
                    actual_iss = "unknown"
                logger.warning(
                    "Issuer mismatch for provider=%s: expected=%s, got=%s",
                    config.name,
                    config.issuer,
                    actual_iss,
                )
                raise InvalidIssuerError(config.issuer, actual_iss) from e
            logger.warning("Token claims validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token claims validation failed: {e}") from e
        except JOSEError as e:
            # JWTError and any other python-jose error (e.g. a JWK that cannot
            # be constructed for this algorithm) become TokenValidationError.
            logger.error("JWT validation failed for provider=%s: %s", config.name, str(e))
            raise TokenValidationError(f"Token validation failed: {e}") from e

        for claim in self._REQUIRED_CLAIMS:
            if claim not in claims:
                logger.error(
                    "Token missing required '%s' claim for provider=%s", claim, config.name
                )
                raise TokenValidationError(f"Token missing required '{claim}' claim")

        return claims

    async def _verify_at_hash(
        self,
        claims: dict[str, Any],
        access_token: str,
        id_token: str,
    ) -> None:
        """Verify the at_hash claim matches the access token.

        Args:
            claims: The ID token claims.
            access_token: The access token.
            id_token: The ID token (to determine algorithm).

        Raises:
            TokenValidationError: If the claim is missing, cannot be computed,
                or does not match.
        """
        at_hash = claims.get("at_hash")
        if at_hash is None:
            # The caller explicitly requested at_hash verification, so a
            # missing claim is a failure rather than something to skip.
            raise TokenValidationError("Token missing 'at_hash' claim")

        algorithm = jwt.get_unverified_header(id_token).get("alg", "RS256")

        # Use the SHA-2 variant matching the alg's bit length (RS256/PS256/
        # ES256/HS256 -> sha256, ...*384 -> sha384, ...*512 -> sha512);
        # anything else falls back to sha256.
        if algorithm[:2] in ("RS", "PS", "ES", "HS"):
            hash_alg = f"sha{algorithm[2:]}"
        else:
            hash_alg = "sha256"

        # at_hash = base64url(left half of hash(ASCII access token)), unpadded.
        try:
            digest = hashlib.new(hash_alg, access_token.encode("ascii")).digest()
            expected_at_hash = base64.urlsafe_b64encode(digest[: len(digest) // 2]).rstrip(b"=")
            expected_at_hash_str = expected_at_hash.decode("ascii")
        except Exception as e:
            raise TokenValidationError(f"Failed to calculate at_hash: {e}") from e

        if at_hash != expected_at_hash_str:
            raise TokenValidationError("at_hash claim does not match access token")

    def _extract_identity(
        self,
        claims: dict[str, Any],
        provider_name: str,
    ) -> IdentityInfo:
        """Build an IdentityInfo from validated claims.

        Args:
            claims: The validated token claims (must contain ``sub``).
            provider_name: The provider name.

        Returns:
            The extracted identity information; claims outside
            ``STANDARD_CLAIMS`` are placed in ``extra_claims``.
        """
        extra_claims = {
            k: v for k, v in claims.items() if k not in self.STANDARD_CLAIMS
        }

        return IdentityInfo(
            provider_name=provider_name,
            subject=claims["sub"],
            email=claims.get("email"),
            email_verified=claims.get("email_verified", False),
            name=claims.get("name"),
            given_name=claims.get("given_name"),
            family_name=claims.get("family_name"),
            picture=claims.get("picture"),
            locale=claims.get("locale"),
            extra_claims=extra_claims,
        )
310
+
311
+
312
async def validate_id_token(
    id_token: str,
    provider_manager: ProviderManager,
    provider_name: str,
    expected_nonce: str | None = None,
    verify_at_hash: bool = False,
    access_token: str | None = None,
) -> IdentityInfo:
    """Validate an ID token with a one-off ``TokenValidator``.

    This is shorthand for
    ``await TokenValidator(provider_manager).validate(...)`` with the same
    arguments.

    Args:
        id_token: The raw ID token (JWT).
        provider_manager: The provider manager used to look up the provider
            and its JWKS.
        provider_name: The provider name.
        expected_nonce: The expected nonce value.
        verify_at_hash: Whether to verify the at_hash claim.
        access_token: The access token (required if verify_at_hash is True).

    Returns:
        Extracted identity information.
    """
    return await TokenValidator(provider_manager).validate(
        id_token=id_token,
        provider_name=provider_name,
        expected_nonce=expected_nonce,
        verify_at_hash=verify_at_hash,
        access_token=access_token,
    )
+ )