byoi 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
byoi/providers.py ADDED
@@ -0,0 +1,434 @@
1
+ """OIDC Provider management and discovery.
2
+
3
+ Handles OIDC discovery, JWKS fetching, and provider configuration.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import TYPE_CHECKING, Any, Self
12
+ from urllib.parse import urljoin
13
+
14
+ import httpx
15
+
16
+ from byoi.cache import CacheProtocol, NullCache
17
+ from byoi.errors import JWKSFetchError, ProviderDiscoveryError, ProviderNotFoundError
18
+ from byoi.models import OIDCProviderConfig, OIDCProviderInfo
19
+
20
+ if TYPE_CHECKING:
21
+ from types import TracebackType
22
+
23
# Module-level logger. The name is spelled out explicitly instead of being
# derived from ``__name__``.
logger = logging.getLogger("byoi.providers")

# Names exported by ``from byoi.providers import *``.
__all__ = ("OIDCDiscoveryDocument", "ProviderManager")
29
+
30
+
31
+ @dataclass(slots=True)
32
+ class OIDCDiscoveryDocument:
33
+ """OIDC Discovery document (OpenID Connect Discovery 1.0)."""
34
+
35
+ issuer: str
36
+ authorization_endpoint: str
37
+ token_endpoint: str
38
+ userinfo_endpoint: str | None
39
+ jwks_uri: str
40
+ scopes_supported: list[str]
41
+ response_types_supported: list[str]
42
+ subject_types_supported: list[str]
43
+ id_token_signing_alg_values_supported: list[str]
44
+ claims_supported: list[str]
45
+ raw: dict[str, Any]
46
+
47
+ @classmethod
48
+ def from_dict(cls, data: dict[str, Any]) -> OIDCDiscoveryDocument:
49
+ """Create a discovery document from a dictionary."""
50
+ return cls(
51
+ issuer=data["issuer"],
52
+ authorization_endpoint=data["authorization_endpoint"],
53
+ token_endpoint=data["token_endpoint"],
54
+ userinfo_endpoint=data.get("userinfo_endpoint"),
55
+ jwks_uri=data["jwks_uri"],
56
+ scopes_supported=data.get("scopes_supported", ["openid"]),
57
+ response_types_supported=data.get("response_types_supported", ["code"]),
58
+ subject_types_supported=data.get("subject_types_supported", ["public"]),
59
+ id_token_signing_alg_values_supported=data.get(
60
+ "id_token_signing_alg_values_supported", ["RS256"]
61
+ ),
62
+ claims_supported=data.get("claims_supported", []),
63
+ raw=data,
64
+ )
65
+
66
+
67
class ProviderManager:
    """Manages OIDC providers, discovery, and JWKS caching.

    Providers are registered by name. Discovery documents and JWKS are
    fetched over HTTP on demand and stored in the configured cache under
    per-provider keys. A per-provider lock is held around each fetch so that
    concurrent callers for the same provider do not fetch in parallel.
    """

    # Cache key prefixes
    DISCOVERY_CACHE_PREFIX = "discovery:"
    JWKS_CACHE_PREFIX = "jwks:"

    # Default TTLs
    DEFAULT_DISCOVERY_TTL = 3600  # 1 hour
    DEFAULT_JWKS_TTL = 3600  # 1 hour

    # Metadata entries that must be present in a fetched discovery document.
    _REQUIRED_DISCOVERY_FIELDS = (
        "issuer",
        "authorization_endpoint",
        "token_endpoint",
        "jwks_uri",
    )

    def __init__(
        self,
        cache: CacheProtocol | None = None,
        http_client: httpx.AsyncClient | None = None,
        discovery_ttl_seconds: int = DEFAULT_DISCOVERY_TTL,
        jwks_ttl_seconds: int = DEFAULT_JWKS_TTL,
        http_timeout_seconds: float = 30.0,
    ) -> None:
        """Initialize the provider manager.

        Args:
            cache: Cache implementation for discovery docs and JWKS. When
                None, a ``NullCache`` is used.
            http_client: HTTP client for making requests. When None, a client
                is created lazily and closed by :meth:`close`.
            discovery_ttl_seconds: TTL for cached discovery documents.
            jwks_ttl_seconds: TTL for cached JWKS.
            http_timeout_seconds: Timeout for HTTP requests in seconds.
        """
        self._providers: dict[str, OIDCProviderConfig] = {}
        # Compare against None explicitly: a cache object may legitimately be
        # falsy (e.g. an empty container defining __len__) and must not be
        # silently replaced by NullCache.
        self._cache = cache if cache is not None else NullCache()
        self._http_client = http_client
        # Only a client created by this manager is closed by close().
        self._owns_http_client = http_client is None
        self._discovery_ttl = discovery_ttl_seconds
        self._jwks_ttl = jwks_ttl_seconds
        self._http_timeout_seconds = http_timeout_seconds
        # Locks for request deduplication to prevent concurrent fetches for the same provider
        self._discovery_locks: dict[str, asyncio.Lock] = {}
        self._jwks_locks: dict[str, asyncio.Lock] = {}
        self._locks_lock = asyncio.Lock()

    async def __aenter__(self) -> Self:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager and close the provider manager."""
        await self.close()

    async def _get_lock(self, locks_dict: dict[str, asyncio.Lock], key: str) -> asyncio.Lock:
        """Get or create the lock stored under ``key`` in ``locks_dict``."""
        async with self._locks_lock:
            if key not in locks_dict:
                locks_dict[key] = asyncio.Lock()
            return locks_dict[key]

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get the HTTP client, creating one on first use if none was supplied."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._http_timeout_seconds),
                follow_redirects=True,
            )
        return self._http_client

    @staticmethod
    def _to_info(config: OIDCProviderConfig) -> OIDCProviderInfo:
        """Project a provider configuration onto its public information."""
        return OIDCProviderInfo(
            name=config.name,
            display_name=config.display_name,
            issuer=config.issuer,
        )

    def register_provider(self, config: OIDCProviderConfig) -> None:
        """Register an OIDC provider.

        Args:
            config: The provider configuration.

        Raises:
            ValueError: If a provider with the same name is already registered.
        """
        if config.name in self._providers:
            raise ValueError(f"Provider '{config.name}' is already registered")
        self._providers[config.name] = config

    def unregister_provider(self, name: str) -> bool:
        """Unregister an OIDC provider.

        Only the registration is removed. Cache entries are keyed by provider
        name and are left untouched, so call :meth:`invalidate_cache` first
        if the name may later be re-registered with a different
        configuration.

        Args:
            name: The provider name.

        Returns:
            True if the provider was unregistered, False if not found.
        """
        if name in self._providers:
            del self._providers[name]
            return True
        return False

    def get_provider(self, name: str) -> OIDCProviderConfig:
        """Get a provider configuration by name.

        Args:
            name: The provider name.

        Returns:
            The provider configuration.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
        """
        if name not in self._providers:
            raise ProviderNotFoundError(name)
        return self._providers[name]

    def get_provider_info(self, name: str) -> OIDCProviderInfo:
        """Get public information about a provider.

        Args:
            name: The provider name.

        Returns:
            Public provider information.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
        """
        return self._to_info(self.get_provider(name))

    def list_providers(self) -> list[OIDCProviderInfo]:
        """List all registered providers.

        Returns:
            List of public provider information.
        """
        return [self._to_info(config) for config in self._providers.values()]

    async def _fetch_discovery_data(self, provider_name: str, issuer: str) -> dict[str, Any]:
        """Download and validate the discovery document published under ``issuer``.

        Raises:
            ProviderDiscoveryError: If the request fails, the body is not a
                JSON object, or a required field is missing.
        """
        discovery_url = urljoin(
            issuer.rstrip("/") + "/",
            ".well-known/openid-configuration",
        )

        logger.debug("Fetching discovery document from %s", discovery_url)

        try:
            client = await self._get_http_client()
            response = await client.get(discovery_url)
            response.raise_for_status()
            data = response.json()
        except httpx.HTTPError as e:
            logger.error(
                "Failed to fetch discovery document for provider=%s: %s",
                provider_name,
                str(e),
            )
            raise ProviderDiscoveryError(provider_name, str(e)) from e
        except Exception as e:
            # Anything else raised while fetching or decoding (e.g. a body
            # that is not valid JSON) is reported as a discovery failure too.
            logger.error(
                "Unexpected error fetching discovery document for provider=%s: %s",
                provider_name,
                str(e),
            )
            raise ProviderDiscoveryError(provider_name, str(e)) from e

        # A valid JSON body can still be null, an array or a string; the
        # membership checks below are only meaningful on an object.
        if not isinstance(data, dict):
            logger.error(
                "Discovery document for provider=%s is not a JSON object",
                provider_name,
            )
            raise ProviderDiscoveryError(
                provider_name, "Discovery document is not a JSON object"
            )

        # Validate required fields
        for field in self._REQUIRED_DISCOVERY_FIELDS:
            if field not in data:
                logger.error(
                    "Discovery document missing required field '%s' for provider=%s",
                    field,
                    provider_name,
                )
                raise ProviderDiscoveryError(
                    provider_name, f"Missing required field: {field}"
                )

        return data

    async def discover(self, provider_name: str) -> OIDCDiscoveryDocument:
        """Discover OIDC endpoints for a provider.

        Args:
            provider_name: The provider name.

        Returns:
            The OIDC discovery document.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If discovery fails.
        """
        config = self.get_provider(provider_name)

        # Check cache first (before acquiring lock for better performance)
        cache_key = f"{self.DISCOVERY_CACHE_PREFIX}{provider_name}"
        cached = await self._cache.get(cache_key)
        if cached is not None:
            logger.debug("Using cached discovery document for provider=%s", provider_name)
            return OIDCDiscoveryDocument.from_dict(cached)

        # Acquire per-provider lock to deduplicate concurrent requests
        lock = await self._get_lock(self._discovery_locks, provider_name)
        async with lock:
            # Check cache again after acquiring lock (another request may have populated it)
            cached = await self._cache.get(cache_key)
            if cached is not None:
                logger.debug("Using cached discovery document for provider=%s", provider_name)
                return OIDCDiscoveryDocument.from_dict(cached)

            data = await self._fetch_discovery_data(provider_name, config.issuer)

            # Cache the discovery document
            await self._cache.set(cache_key, data, self._discovery_ttl)

            logger.info("Discovery document fetched and cached for provider=%s", provider_name)

            return OIDCDiscoveryDocument.from_dict(data)

    async def _fetch_jwks(self, provider_name: str, jwks_uri: str) -> dict[str, Any]:
        """Download the JWKS at ``jwks_uri`` and check its basic structure.

        Raises:
            JWKSFetchError: If the request fails or the body is not a JSON
                object containing a ``keys`` member.
        """
        try:
            client = await self._get_http_client()
            response = await client.get(jwks_uri)
            response.raise_for_status()
            jwks = response.json()
        except httpx.HTTPError as e:
            logger.error("Failed to fetch JWKS for provider=%s: %s", provider_name, str(e))
            raise JWKSFetchError(provider_name, str(e)) from e
        except Exception as e:
            logger.error("Unexpected error fetching JWKS for provider=%s: %s", provider_name, str(e))
            raise JWKSFetchError(provider_name, str(e)) from e

        # Guard the membership test: on a non-object body it would either
        # raise TypeError (null, number) or do a substring/element check.
        if not isinstance(jwks, dict):
            logger.error("Invalid JWKS structure for provider=%s: not a JSON object", provider_name)
            raise JWKSFetchError(provider_name, "Invalid JWKS: not a JSON object")

        # Validate JWKS structure
        if "keys" not in jwks:
            logger.error("Invalid JWKS structure for provider=%s: missing 'keys' field", provider_name)
            raise JWKSFetchError(provider_name, "Invalid JWKS: missing 'keys' field")

        return jwks

    async def get_jwks(self, provider_name: str, *, force_refresh: bool = False) -> dict[str, Any]:
        """Get the JWKS for a provider.

        Args:
            provider_name: The provider name.
            force_refresh: If True, bypass the cache and fetch fresh JWKS.
                Useful for handling key rotation when a kid is not found.

        Returns:
            The JWKS as a dictionary.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If the JWKS URI has to come from
                discovery and discovery fails.
            JWKSFetchError: If fetching JWKS fails.
        """
        config = self.get_provider(provider_name)

        cache_key = f"{self.JWKS_CACHE_PREFIX}{provider_name}"

        # Check cache first (unless force_refresh is requested)
        if not force_refresh:
            cached = await self._cache.get(cache_key)
            if cached is not None:
                logger.debug("Using cached JWKS for provider=%s", provider_name)
                return cached

        # Acquire per-provider lock to deduplicate concurrent requests
        lock = await self._get_lock(self._jwks_locks, provider_name)
        async with lock:
            # Check cache again after acquiring lock (unless force_refresh)
            if not force_refresh:
                cached = await self._cache.get(cache_key)
                if cached is not None:
                    logger.debug("Using cached JWKS for provider=%s", provider_name)
                    return cached

            # Get JWKS URI from config override or discovery. discover() uses
            # the separate discovery locks, so holding the JWKS lock here
            # does not block it.
            if config.jwks_uri_override:
                jwks_uri = config.jwks_uri_override
            else:
                discovery = await self.discover(provider_name)
                jwks_uri = discovery.jwks_uri

            logger.debug("Fetching JWKS from %s for provider=%s", jwks_uri, provider_name)

            jwks = await self._fetch_jwks(provider_name, jwks_uri)

            # Cache the JWKS
            await self._cache.set(cache_key, jwks, self._jwks_ttl)

            logger.info("JWKS fetched and cached for provider=%s", provider_name)

            return jwks

    async def get_authorization_endpoint(self, provider_name: str) -> str:
        """Get the authorization endpoint for a provider.

        A configured override takes precedence over the discovered value.

        Args:
            provider_name: The provider name.

        Returns:
            The authorization endpoint URL.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If discovery is needed and fails.
        """
        config = self.get_provider(provider_name)
        if config.authorization_endpoint_override:
            return config.authorization_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.authorization_endpoint

    async def get_token_endpoint(self, provider_name: str) -> str:
        """Get the token endpoint for a provider.

        A configured override takes precedence over the discovered value.

        Args:
            provider_name: The provider name.

        Returns:
            The token endpoint URL.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If discovery is needed and fails.
        """
        config = self.get_provider(provider_name)
        if config.token_endpoint_override:
            return config.token_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.token_endpoint

    async def get_userinfo_endpoint(self, provider_name: str) -> str | None:
        """Get the userinfo endpoint for a provider.

        A configured override takes precedence over the discovered value.

        Args:
            provider_name: The provider name.

        Returns:
            The userinfo endpoint URL, or None if not available.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If discovery is needed and fails.
        """
        config = self.get_provider(provider_name)
        if config.userinfo_endpoint_override:
            return config.userinfo_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.userinfo_endpoint

    async def invalidate_cache(self, provider_name: str | None = None) -> None:
        """Invalidate cached data for a provider or all providers.

        Args:
            provider_name: The provider name, or None to invalidate all
                currently registered providers.
        """
        # Only None means "all"; an empty string is treated as a (possibly
        # unknown) provider name rather than wiping every provider's cache.
        if provider_name is not None:
            names = [provider_name]
        else:
            names = list(self._providers)

        for name in names:
            await self._cache.delete(f"{self.DISCOVERY_CACHE_PREFIX}{name}")
            await self._cache.delete(f"{self.JWKS_CACHE_PREFIX}{name}")

    async def close(self) -> None:
        """Close the provider manager and release resources.

        Only an HTTP client created by this manager is closed; a client
        passed to the constructor is left open for its owner.
        """
        if self._owns_http_client and self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None
byoi/py.typed ADDED
File without changes
byoi/repositories.py ADDED
@@ -0,0 +1,252 @@
1
+ """Repository protocols for the BYOI library.
2
+
3
+ These protocols define the interfaces that implementing applications must satisfy
4
+ for data persistence. The implementing application is responsible for providing
5
+ concrete implementations backed by their database of choice.
6
+ """
7
+
8
+ from datetime import datetime
9
+ from typing import Any, Protocol, runtime_checkable
10
+
11
+ from byoi.types import AuthStateProtocol, LinkedIdentityProtocol, UserProtocol
12
+
13
# Repository protocols exported by ``from byoi.repositories import *``.
__all__ = (
    "UserRepositoryProtocol",
    "LinkedIdentityRepositoryProtocol",
    "AuthStateRepositoryProtocol",
)
18
+
19
+
20
@runtime_checkable
class UserRepositoryProtocol(Protocol):
    """Structural interface for storing and loading users.

    The application supplies the concrete repository; any object exposing
    these coroutine methods satisfies the protocol.
    """

    async def get_by_id(self, user_id: str) -> UserProtocol | None:
        """Look up a user by unique identifier.

        Args:
            user_id: Identifier of the user to load.

        Returns:
            The matching user, or None when there is no such user.
        """
        ...

    async def get_by_email(self, email: str) -> UserProtocol | None:
        """Look up a user by email address.

        Args:
            email: Email address to search for.

        Returns:
            The matching user, or None when there is no such user.
        """
        ...

    async def create(
        self,
        email: str | None = None,
        is_active: bool = True,
        **extra_data: Any,
    ) -> UserProtocol:
        """Persist a new user.

        Args:
            email: Email address of the user, if known.
            is_active: Whether the new account is active.
            **extra_data: Further values to store with the user.

        Returns:
            The user that was created.
        """
        ...

    async def update(self, user_id: str, **data: Any) -> UserProtocol | None:
        """Change stored data of an existing user.

        Args:
            user_id: Identifier of the user to modify.
            **data: The values to update.

        Returns:
            The updated user, or None when there is no such user.
        """
        ...
79
+
80
+
81
@runtime_checkable
class LinkedIdentityRepositoryProtocol(Protocol):
    """Structural interface for storing and loading linked identities.

    The application supplies the concrete repository; any object exposing
    these coroutine methods satisfies the protocol.
    """

    async def get_by_id(self, identity_id: str) -> LinkedIdentityProtocol | None:
        """Look up a linked identity by unique identifier.

        Args:
            identity_id: Identifier of the linked identity to load.

        Returns:
            The matching linked identity, or None when there is none.
        """
        ...

    async def get_by_provider(
        self,
        provider_name: str,
        provider_subject: str,
    ) -> LinkedIdentityProtocol | None:
        """Look up a linked identity by provider name and subject.

        Args:
            provider_name: Name of the identity provider.
            provider_subject: Subject identifier issued by that provider.

        Returns:
            The matching linked identity, or None when there is none.
        """
        ...

    async def get_by_user_id(self, user_id: str) -> list[LinkedIdentityProtocol]:
        """Load every linked identity belonging to a user.

        Args:
            user_id: Identifier of the user.

        Returns:
            The linked identities of that user.
        """
        ...

    async def create(
        self,
        user_id: str,
        provider_name: str,
        provider_subject: str,
        email: str | None = None,
        **extra_data: Any,
    ) -> LinkedIdentityProtocol:
        """Persist a new linked identity.

        Args:
            user_id: Identifier of the user the identity is linked to.
            provider_name: Name of the identity provider.
            provider_subject: Subject identifier issued by that provider.
            email: Email address associated with this identity, if any.
            **extra_data: Further values to store with the identity.

        Returns:
            The linked identity that was created.
        """
        ...

    async def update_last_used(
        self,
        identity_id: str,
        last_used_at: datetime,
    ) -> LinkedIdentityProtocol | None:
        """Set the last-used timestamp of a linked identity.

        Args:
            identity_id: Identifier of the linked identity.
            last_used_at: Timestamp to store.

        Returns:
            The updated linked identity, or None when there is none.
        """
        ...

    async def delete(self, identity_id: str) -> bool:
        """Remove a single linked identity.

        Args:
            identity_id: Identifier of the linked identity to remove.

        Returns:
            True when the identity was deleted, False when it was not found.
        """
        ...

    async def delete_by_user_id(self, user_id: str) -> int:
        """Remove every linked identity belonging to a user.

        Args:
            user_id: Identifier of the user.

        Returns:
            How many identities were deleted.
        """
        ...
187
+
188
@runtime_checkable
class AuthStateRepositoryProtocol(Protocol):
    """Structural interface for storing OAuth flow state (PKCE, nonce, etc.).

    The application supplies the concrete repository; any object exposing
    these coroutine methods satisfies the protocol.
    """

    async def create(
        self,
        state: str,
        code_verifier: str,
        nonce: str,
        provider_name: str,
        redirect_uri: str,
        expires_at: datetime,
        client_type: str = "web",
        extra_data: dict[str, Any] | None = None,
    ) -> AuthStateProtocol:
        """Persist the state of a newly started auth flow.

        Args:
            state: OAuth ``state`` parameter identifying the flow.
            code_verifier: PKCE code verifier.
            nonce: Nonce used for ID token validation.
            provider_name: Name of the provider.
            redirect_uri: Redirect URI of this auth flow.
            expires_at: Moment at which this auth state expires.
            client_type: Type of client (web, desktop, mobile).
            extra_data: Further values to store, if any.

        Returns:
            The auth state that was created.
        """
        ...

    async def get_by_state(self, state: str) -> AuthStateProtocol | None:
        """Look up an auth state by its ``state`` parameter.

        Args:
            state: OAuth ``state`` parameter.

        Returns:
            The matching auth state, or None when there is none.
        """
        ...

    async def delete(self, state: str) -> bool:
        """Remove a single auth state.

        Args:
            state: OAuth ``state`` parameter.

        Returns:
            True when the state was deleted, False when it was not found.
        """
        ...

    async def delete_expired(self) -> int:
        """Remove every auth state that has expired.

        Returns:
            How many states were deleted.
        """
        ...