byoi-0.1.0a1-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in the supported public registries. It is provided for informational purposes only and reflects the changes between package versions.
- byoi/__init__.py +233 -0
- byoi/__main__.py +228 -0
- byoi/cache.py +346 -0
- byoi/config.py +349 -0
- byoi/dependencies.py +144 -0
- byoi/errors.py +360 -0
- byoi/models.py +451 -0
- byoi/pkce.py +144 -0
- byoi/providers.py +434 -0
- byoi/py.typed +0 -0
- byoi/repositories.py +252 -0
- byoi/service.py +723 -0
- byoi/telemetry.py +352 -0
- byoi/tokens.py +340 -0
- byoi/types.py +130 -0
- byoi-0.1.0a1.dist-info/METADATA +504 -0
- byoi-0.1.0a1.dist-info/RECORD +20 -0
- byoi-0.1.0a1.dist-info/WHEEL +4 -0
- byoi-0.1.0a1.dist-info/entry_points.txt +3 -0
- byoi-0.1.0a1.dist-info/licenses/LICENSE +21 -0
byoi/providers.py
ADDED
@@ -0,0 +1,434 @@
"""OIDC Provider management and discovery.

Handles OIDC discovery, JWKS fetching, and provider configuration.
"""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Self
from urllib.parse import urljoin

import httpx

from byoi.cache import CacheProtocol, NullCache
from byoi.errors import JWKSFetchError, ProviderDiscoveryError, ProviderNotFoundError
from byoi.models import OIDCProviderConfig, OIDCProviderInfo

if TYPE_CHECKING:
    from types import TracebackType

logger = logging.getLogger("byoi.providers")

__all__ = (
    "OIDCDiscoveryDocument",
    "ProviderManager",
)


@dataclass(slots=True)
class OIDCDiscoveryDocument:
    """OIDC Discovery document (OpenID Connect Discovery 1.0)."""

    issuer: str
    authorization_endpoint: str
    token_endpoint: str
    userinfo_endpoint: str | None
    jwks_uri: str
    scopes_supported: list[str]
    response_types_supported: list[str]
    subject_types_supported: list[str]
    id_token_signing_alg_values_supported: list[str]
    claims_supported: list[str]
    raw: dict[str, Any]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> OIDCDiscoveryDocument:
        """Create a discovery document from a dictionary."""
        return cls(
            issuer=data["issuer"],
            authorization_endpoint=data["authorization_endpoint"],
            token_endpoint=data["token_endpoint"],
            userinfo_endpoint=data.get("userinfo_endpoint"),
            jwks_uri=data["jwks_uri"],
            scopes_supported=data.get("scopes_supported", ["openid"]),
            response_types_supported=data.get("response_types_supported", ["code"]),
            subject_types_supported=data.get("subject_types_supported", ["public"]),
            id_token_signing_alg_values_supported=data.get(
                "id_token_signing_alg_values_supported", ["RS256"]
            ),
            claims_supported=data.get("claims_supported", []),
            raw=data,
        )


class ProviderManager:
    """Manages OIDC providers, discovery, and JWKS caching."""

    # Cache key prefixes
    DISCOVERY_CACHE_PREFIX = "discovery:"
    JWKS_CACHE_PREFIX = "jwks:"

    # Default TTLs
    DEFAULT_DISCOVERY_TTL = 3600  # 1 hour
    DEFAULT_JWKS_TTL = 3600  # 1 hour

    def __init__(
        self,
        cache: CacheProtocol | None = None,
        http_client: httpx.AsyncClient | None = None,
        discovery_ttl_seconds: int = DEFAULT_DISCOVERY_TTL,
        jwks_ttl_seconds: int = DEFAULT_JWKS_TTL,
        http_timeout_seconds: float = 30.0,
    ) -> None:
        """Initialize the provider manager.

        Args:
            cache: Cache implementation for discovery docs and JWKS.
            http_client: HTTP client for making requests.
            discovery_ttl_seconds: TTL for cached discovery documents.
            jwks_ttl_seconds: TTL for cached JWKS.
            http_timeout_seconds: Timeout for HTTP requests in seconds.
        """
        self._providers: dict[str, OIDCProviderConfig] = {}
        self._cache = cache or NullCache()
        self._http_client = http_client
        self._owns_http_client = http_client is None
        self._discovery_ttl = discovery_ttl_seconds
        self._jwks_ttl = jwks_ttl_seconds
        self._http_timeout_seconds = http_timeout_seconds
        # Locks for request deduplication to prevent concurrent fetches for the same provider
        self._discovery_locks: dict[str, asyncio.Lock] = {}
        self._jwks_locks: dict[str, asyncio.Lock] = {}
        self._locks_lock = asyncio.Lock()

    async def __aenter__(self) -> Self:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager and close the provider manager."""
        await self.close()

    async def _get_lock(self, locks_dict: dict[str, asyncio.Lock], key: str) -> asyncio.Lock:
        """Get or create a lock for request deduplication."""
        async with self._locks_lock:
            if key not in locks_dict:
                locks_dict[key] = asyncio.Lock()
            return locks_dict[key]

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._http_timeout_seconds),
                follow_redirects=True,
            )
        return self._http_client

    def register_provider(self, config: OIDCProviderConfig) -> None:
        """Register an OIDC provider.

        Args:
            config: The provider configuration.

        Raises:
            ValueError: If a provider with the same name is already registered.
        """
        if config.name in self._providers:
            raise ValueError(f"Provider '{config.name}' is already registered")
        self._providers[config.name] = config

    def unregister_provider(self, name: str) -> bool:
        """Unregister an OIDC provider.

        Args:
            name: The provider name.

        Returns:
            True if the provider was unregistered, False if not found.
        """
        if name in self._providers:
            del self._providers[name]
            return True
        return False

    def get_provider(self, name: str) -> OIDCProviderConfig:
        """Get a provider configuration by name.

        Args:
            name: The provider name.

        Returns:
            The provider configuration.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
        """
        if name not in self._providers:
            raise ProviderNotFoundError(name)
        return self._providers[name]

    def get_provider_info(self, name: str) -> OIDCProviderInfo:
        """Get public information about a provider.

        Args:
            name: The provider name.

        Returns:
            Public provider information.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
        """
        config = self.get_provider(name)
        return OIDCProviderInfo(
            name=config.name,
            display_name=config.display_name,
            issuer=config.issuer,
        )

    def list_providers(self) -> list[OIDCProviderInfo]:
        """List all registered providers.

        Returns:
            List of public provider information.
        """
        return [
            OIDCProviderInfo(
                name=config.name,
                display_name=config.display_name,
                issuer=config.issuer,
            )
            for config in self._providers.values()
        ]

    async def discover(self, provider_name: str) -> OIDCDiscoveryDocument:
        """Discover OIDC endpoints for a provider.

        Args:
            provider_name: The provider name.

        Returns:
            The OIDC discovery document.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            ProviderDiscoveryError: If discovery fails.
        """
        config = self.get_provider(provider_name)

        # Check cache first (before acquiring lock for better performance)
        cache_key = f"{self.DISCOVERY_CACHE_PREFIX}{provider_name}"
        cached = await self._cache.get(cache_key)
        if cached is not None:
            logger.debug("Using cached discovery document for provider=%s", provider_name)
            return OIDCDiscoveryDocument.from_dict(cached)

        # Acquire per-provider lock to deduplicate concurrent requests
        lock = await self._get_lock(self._discovery_locks, provider_name)
        async with lock:
            # Check cache again after acquiring lock (another request may have populated it)
            cached = await self._cache.get(cache_key)
            if cached is not None:
                logger.debug("Using cached discovery document for provider=%s", provider_name)
                return OIDCDiscoveryDocument.from_dict(cached)

            # Fetch discovery document
            discovery_url = urljoin(
                config.issuer.rstrip("/") + "/",
                ".well-known/openid-configuration",
            )

            logger.debug("Fetching discovery document from %s", discovery_url)

            try:
                client = await self._get_http_client()
                response = await client.get(discovery_url)
                response.raise_for_status()
                data = response.json()
            except httpx.HTTPError as e:
                logger.error(
                    "Failed to fetch discovery document for provider=%s: %s",
                    provider_name,
                    str(e),
                )
                raise ProviderDiscoveryError(provider_name, str(e)) from e
            except Exception as e:
                logger.error(
                    "Unexpected error fetching discovery document for provider=%s: %s",
                    provider_name,
                    str(e),
                )
                raise ProviderDiscoveryError(provider_name, str(e)) from e

            # Validate required fields
            required_fields = [
                "issuer",
                "authorization_endpoint",
                "token_endpoint",
                "jwks_uri",
            ]
            for field in required_fields:
                if field not in data:
                    logger.error(
                        "Discovery document missing required field '%s' for provider=%s",
                        field,
                        provider_name,
                    )
                    raise ProviderDiscoveryError(
                        provider_name, f"Missing required field: {field}"
                    )

            # Cache the discovery document
            await self._cache.set(cache_key, data, self._discovery_ttl)

            logger.info("Discovery document fetched and cached for provider=%s", provider_name)

            return OIDCDiscoveryDocument.from_dict(data)

    async def get_jwks(self, provider_name: str, *, force_refresh: bool = False) -> dict[str, Any]:
        """Get the JWKS for a provider.

        Args:
            provider_name: The provider name.
            force_refresh: If True, bypass the cache and fetch fresh JWKS.
                Useful for handling key rotation when a kid is not found.

        Returns:
            The JWKS as a dictionary.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
            JWKSFetchError: If fetching JWKS fails.
        """
        config = self.get_provider(provider_name)

        cache_key = f"{self.JWKS_CACHE_PREFIX}{provider_name}"

        # Check cache first (unless force_refresh is requested)
        if not force_refresh:
            cached = await self._cache.get(cache_key)
            if cached is not None:
                logger.debug("Using cached JWKS for provider=%s", provider_name)
                return cached

        # Acquire per-provider lock to deduplicate concurrent requests
        lock = await self._get_lock(self._jwks_locks, provider_name)
        async with lock:
            # Check cache again after acquiring lock (unless force_refresh)
            if not force_refresh:
                cached = await self._cache.get(cache_key)
                if cached is not None:
                    logger.debug("Using cached JWKS for provider=%s", provider_name)
                    return cached

            # Get JWKS URI from config override or discovery
            if config.jwks_uri_override:
                jwks_uri = config.jwks_uri_override
            else:
                discovery = await self.discover(provider_name)
                jwks_uri = discovery.jwks_uri

            logger.debug("Fetching JWKS from %s for provider=%s", jwks_uri, provider_name)

            # Fetch JWKS
            try:
                client = await self._get_http_client()
                response = await client.get(jwks_uri)
                response.raise_for_status()
                jwks = response.json()
            except httpx.HTTPError as e:
                logger.error("Failed to fetch JWKS for provider=%s: %s", provider_name, str(e))
                raise JWKSFetchError(provider_name, str(e)) from e
            except Exception as e:
                logger.error("Unexpected error fetching JWKS for provider=%s: %s", provider_name, str(e))
                raise JWKSFetchError(provider_name, str(e)) from e

            # Validate JWKS structure
            if "keys" not in jwks:
                logger.error("Invalid JWKS structure for provider=%s: missing 'keys' field", provider_name)
                raise JWKSFetchError(provider_name, "Invalid JWKS: missing 'keys' field")

            # Cache the JWKS
            await self._cache.set(cache_key, jwks, self._jwks_ttl)

            logger.info("JWKS fetched and cached for provider=%s", provider_name)

            return jwks

    async def get_authorization_endpoint(self, provider_name: str) -> str:
        """Get the authorization endpoint for a provider.

        Args:
            provider_name: The provider name.

        Returns:
            The authorization endpoint URL.
        """
        config = self.get_provider(provider_name)
        if config.authorization_endpoint_override:
            return config.authorization_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.authorization_endpoint

    async def get_token_endpoint(self, provider_name: str) -> str:
        """Get the token endpoint for a provider.

        Args:
            provider_name: The provider name.

        Returns:
            The token endpoint URL.
        """
        config = self.get_provider(provider_name)
        if config.token_endpoint_override:
            return config.token_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.token_endpoint

    async def get_userinfo_endpoint(self, provider_name: str) -> str | None:
        """Get the userinfo endpoint for a provider.

        Args:
            provider_name: The provider name.

        Returns:
            The userinfo endpoint URL, or None if not available.
        """
        config = self.get_provider(provider_name)
        if config.userinfo_endpoint_override:
            return config.userinfo_endpoint_override

        discovery = await self.discover(provider_name)
        return discovery.userinfo_endpoint

    async def invalidate_cache(self, provider_name: str | None = None) -> None:
        """Invalidate cached data for a provider or all providers.

        Args:
            provider_name: The provider name, or None to invalidate all.
        """
        if provider_name:
            await self._cache.delete(f"{self.DISCOVERY_CACHE_PREFIX}{provider_name}")
            await self._cache.delete(f"{self.JWKS_CACHE_PREFIX}{provider_name}")
        else:
            # Invalidate all providers
            for name in self._providers:
                await self._cache.delete(f"{self.DISCOVERY_CACHE_PREFIX}{name}")
                await self._cache.delete(f"{self.JWKS_CACHE_PREFIX}{name}")

    async def close(self) -> None:
        """Close the provider manager and release resources."""
        if self._owns_http_client and self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None
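Usage note (illustrative, not part of the package): the sketch below shows one way ProviderManager might be driven. The OIDCProviderConfig constructor arguments (name, display_name, issuer) are inferred from how ProviderManager reads those attributes above, and the issuer URL is hypothetical; treat this as a hedged sketch rather than documented API.

# Hedged usage sketch; assumes OIDCProviderConfig accepts these keyword arguments.
import asyncio

from byoi.models import OIDCProviderConfig
from byoi.providers import ProviderManager


async def main() -> None:
    async with ProviderManager() as manager:
        # Hypothetical provider registration; the real config model may require more fields.
        manager.register_provider(
            OIDCProviderConfig(
                name="example",
                display_name="Example IdP",
                issuer="https://idp.example.com",
            )
        )
        # Discovery and JWKS calls hit the network and are cached per the TTLs above.
        discovery = await manager.discover("example")
        jwks = await manager.get_jwks("example")
        print(discovery.token_endpoint, len(jwks["keys"]))


asyncio.run(main())

Because the manager owns the httpx.AsyncClient it creates lazily, leaving the async context manager closes that client via close(); a client passed in by the caller is left open.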
byoi/py.typed
ADDED
File without changes
byoi/repositories.py
ADDED
@@ -0,0 +1,252 @@
"""Repository protocols for the BYOI library.

These protocols define the interfaces that implementing applications must satisfy
for data persistence. The implementing application is responsible for providing
concrete implementations backed by their database of choice.
"""

from datetime import datetime
from typing import Any, Protocol, runtime_checkable

from byoi.types import AuthStateProtocol, LinkedIdentityProtocol, UserProtocol

__all__ = (
    "UserRepositoryProtocol",
    "LinkedIdentityRepositoryProtocol",
    "AuthStateRepositoryProtocol",
)


@runtime_checkable
class UserRepositoryProtocol(Protocol):
    """Protocol for user data persistence.

    Implementing applications must provide a repository that satisfies this protocol
    for managing user entities.
    """

    async def get_by_id(self, user_id: str) -> UserProtocol | None:
        """Retrieve a user by their unique identifier.

        Args:
            user_id: The unique identifier of the user.

        Returns:
            The user if found, None otherwise.
        """
        ...

    async def get_by_email(self, email: str) -> UserProtocol | None:
        """Retrieve a user by their email address.

        Args:
            email: The email address to search for.

        Returns:
            The user if found, None otherwise.
        """
        ...

    async def create(
        self,
        email: str | None = None,
        is_active: bool = True,
        **extra_data: Any,
    ) -> UserProtocol:
        """Create a new user.

        Args:
            email: The user's email address.
            is_active: Whether the user account is active.
            **extra_data: Additional data to store with the user.

        Returns:
            The newly created user.
        """
        ...

    async def update(self, user_id: str, **data: Any) -> UserProtocol | None:
        """Update a user's data.

        Args:
            user_id: The unique identifier of the user to update.
            **data: The data to update.

        Returns:
            The updated user if found, None otherwise.
        """
        ...


@runtime_checkable
class LinkedIdentityRepositoryProtocol(Protocol):
    """Protocol for linked identity data persistence.

    Implementing applications must provide a repository that satisfies this protocol
    for managing linked identity entities.
    """

    async def get_by_id(self, identity_id: str) -> LinkedIdentityProtocol | None:
        """Retrieve a linked identity by its unique identifier.

        Args:
            identity_id: The unique identifier of the linked identity.

        Returns:
            The linked identity if found, None otherwise.
        """
        ...

    async def get_by_provider(
        self,
        provider_name: str,
        provider_subject: str,
    ) -> LinkedIdentityProtocol | None:
        """Retrieve a linked identity by provider name and subject.

        Args:
            provider_name: The name of the identity provider.
            provider_subject: The subject identifier from the provider.

        Returns:
            The linked identity if found, None otherwise.
        """
        ...

    async def get_by_user_id(self, user_id: str) -> list[LinkedIdentityProtocol]:
        """Retrieve all linked identities for a user.

        Args:
            user_id: The unique identifier of the user.

        Returns:
            A list of linked identities for the user.
        """
        ...

    async def create(
        self,
        user_id: str,
        provider_name: str,
        provider_subject: str,
        email: str | None = None,
        **extra_data: Any,
    ) -> LinkedIdentityProtocol:
        """Create a new linked identity.

        Args:
            user_id: The ID of the user to link.
            provider_name: The name of the identity provider.
            provider_subject: The subject identifier from the provider.
            email: The email associated with this identity.
            **extra_data: Additional data to store with the identity.

        Returns:
            The newly created linked identity.
        """
        ...

    async def update_last_used(
        self,
        identity_id: str,
        last_used_at: datetime,
    ) -> LinkedIdentityProtocol | None:
        """Update the last used timestamp for a linked identity.

        Args:
            identity_id: The unique identifier of the linked identity.
            last_used_at: The timestamp to set.

        Returns:
            The updated linked identity if found, None otherwise.
        """
        ...

    async def delete(self, identity_id: str) -> bool:
        """Delete a linked identity.

        Args:
            identity_id: The unique identifier of the linked identity to delete.

        Returns:
            True if the identity was deleted, False if not found.
        """
        ...

    async def delete_by_user_id(self, user_id: str) -> int:
        """Delete all linked identities for a user.

        Args:
            user_id: The unique identifier of the user.

        Returns:
            The number of identities deleted.
        """
        ...


@runtime_checkable
class AuthStateRepositoryProtocol(Protocol):
    """Protocol for auth state data persistence.

    Implementing applications must provide a repository that satisfies this protocol
    for managing OAuth flow state (PKCE, nonce, etc.).
    """

    async def create(
        self,
        state: str,
        code_verifier: str,
        nonce: str,
        provider_name: str,
        redirect_uri: str,
        expires_at: datetime,
        client_type: str = "web",
        extra_data: dict[str, Any] | None = None,
    ) -> AuthStateProtocol:
        """Create a new auth state.

        Args:
            state: The OAuth state parameter.
            code_verifier: The PKCE code verifier.
            nonce: The nonce for ID token validation.
            provider_name: The name of the provider.
            redirect_uri: The redirect URI for this auth flow.
            expires_at: When this auth state expires.
            client_type: Type of client (web, desktop, mobile).
            extra_data: Additional data to store.

        Returns:
            The newly created auth state.
        """
        ...

    async def get_by_state(self, state: str) -> AuthStateProtocol | None:
        """Retrieve an auth state by its state parameter.

        Args:
            state: The OAuth state parameter.

        Returns:
            The auth state if found, None otherwise.
        """
        ...

    async def delete(self, state: str) -> bool:
        """Delete an auth state.

        Args:
            state: The OAuth state parameter.

        Returns:
            True if the state was deleted, False if not found.
        """
        ...

    async def delete_expired(self) -> int:
        """Delete all expired auth states.

        Returns:
            The number of states deleted.
        """
        ...
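Usage note (illustrative, not part of the package): because these protocols are runtime_checkable and purely structural, an application can satisfy them with any object exposing the same async methods. The sketch below is a minimal in-memory AuthStateRepositoryProtocol implementation suitable for tests; the _AuthState dataclass is an assumed stand-in for whatever actually satisfies AuthStateProtocol in byoi.types, so its field set is an assumption.

# Hedged in-memory sketch of an AuthStateRepositoryProtocol implementation (tests only).
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


@dataclass
class _AuthState:
    # Assumed shape; the real AuthStateProtocol in byoi.types may differ.
    state: str
    code_verifier: str
    nonce: str
    provider_name: str
    redirect_uri: str
    expires_at: datetime
    client_type: str = "web"
    extra_data: dict[str, Any] = field(default_factory=dict)


class InMemoryAuthStateRepository:
    """Keeps auth states in a dict keyed by the OAuth state parameter."""

    def __init__(self) -> None:
        self._states: dict[str, _AuthState] = {}

    async def create(
        self,
        state: str,
        code_verifier: str,
        nonce: str,
        provider_name: str,
        redirect_uri: str,
        expires_at: datetime,
        client_type: str = "web",
        extra_data: dict[str, Any] | None = None,
    ) -> _AuthState:
        auth_state = _AuthState(
            state=state,
            code_verifier=code_verifier,
            nonce=nonce,
            provider_name=provider_name,
            redirect_uri=redirect_uri,
            expires_at=expires_at,
            client_type=client_type,
            extra_data=extra_data or {},
        )
        self._states[state] = auth_state
        return auth_state

    async def get_by_state(self, state: str) -> _AuthState | None:
        return self._states.get(state)

    async def delete(self, state: str) -> bool:
        return self._states.pop(state, None) is not None

    async def delete_expired(self) -> int:
        # Assumes expires_at values are timezone-aware UTC datetimes.
        now = datetime.now(timezone.utc)
        expired = [key for key, value in self._states.items() if value.expires_at <= now]
        for key in expired:
            del self._states[key]
        return len(expired)

Note that isinstance(InMemoryAuthStateRepository(), AuthStateRepositoryProtocol) only verifies that the named methods exist; runtime protocol checks do not compare signatures, so matching the parameter names above still matters for callers inside the library.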