byoi 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- byoi/__init__.py +233 -0
- byoi/__main__.py +228 -0
- byoi/cache.py +346 -0
- byoi/config.py +349 -0
- byoi/dependencies.py +144 -0
- byoi/errors.py +360 -0
- byoi/models.py +451 -0
- byoi/pkce.py +144 -0
- byoi/providers.py +434 -0
- byoi/py.typed +0 -0
- byoi/repositories.py +252 -0
- byoi/service.py +723 -0
- byoi/telemetry.py +352 -0
- byoi/tokens.py +340 -0
- byoi/types.py +130 -0
- byoi-0.1.0a1.dist-info/METADATA +504 -0
- byoi-0.1.0a1.dist-info/RECORD +20 -0
- byoi-0.1.0a1.dist-info/WHEEL +4 -0
- byoi-0.1.0a1.dist-info/entry_points.txt +3 -0
- byoi-0.1.0a1.dist-info/licenses/LICENSE +21 -0
byoi/cache.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""Cache implementations for the BYOI library.
|
|
2
|
+
|
|
3
|
+
Provides caching for JWKS and OIDC discovery documents.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import time
|
|
12
|
+
import weakref
|
|
13
|
+
from abc import ABC, abstractmethod
|
|
14
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from types import TracebackType
|
|
18
|
+
|
|
19
|
+
__all__ = ("CacheProtocol", "InMemoryCache", "RedisCache", "NullCache")

# Module-level logger shared by every cache backend in this module.
logger = logging.getLogger("byoi.cache")

# Weak registry of live InMemoryCache objects: being a WeakSet, membership never
# keeps an instance alive. Nothing in this module reads the set; the intended
# shutdown-time cleanup presumably lives elsewhere in the package
# (NOTE(review): confirm where it is consumed).
_active_caches: weakref.WeakSet[InMemoryCache] = weakref.WeakSet()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CacheProtocol(ABC):
|
|
33
|
+
"""Abstract base class for cache implementations."""
|
|
34
|
+
|
|
35
|
+
@abstractmethod
|
|
36
|
+
async def get(self, key: str) -> Any | None:
|
|
37
|
+
"""Get a value from the cache.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
key: The cache key.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
The cached value, or None if not found or expired.
|
|
44
|
+
"""
|
|
45
|
+
...
|
|
46
|
+
|
|
47
|
+
@abstractmethod
|
|
48
|
+
async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
|
|
49
|
+
"""Set a value in the cache.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
key: The cache key.
|
|
53
|
+
value: The value to cache (must be JSON-serializable).
|
|
54
|
+
ttl_seconds: Time-to-live in seconds. None means no expiration.
|
|
55
|
+
"""
|
|
56
|
+
...
|
|
57
|
+
|
|
58
|
+
@abstractmethod
|
|
59
|
+
async def delete(self, key: str) -> bool:
|
|
60
|
+
"""Delete a value from the cache.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
key: The cache key.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
True if the key was deleted, False if not found.
|
|
67
|
+
"""
|
|
68
|
+
...
|
|
69
|
+
|
|
70
|
+
@abstractmethod
|
|
71
|
+
async def clear(self) -> None:
|
|
72
|
+
"""Clear all values from the cache."""
|
|
73
|
+
...
|
|
74
|
+
|
|
75
|
+
@abstractmethod
|
|
76
|
+
async def close(self) -> None:
|
|
77
|
+
"""Close the cache connection and release resources."""
|
|
78
|
+
...
|
|
79
|
+
|
|
80
|
+
async def __aenter__(self) -> Self:
|
|
81
|
+
"""Enter the async context manager."""
|
|
82
|
+
return self
|
|
83
|
+
|
|
84
|
+
async def __aexit__(
|
|
85
|
+
self,
|
|
86
|
+
exc_type: type[BaseException] | None,
|
|
87
|
+
exc_val: BaseException | None,
|
|
88
|
+
exc_tb: TracebackType | None,
|
|
89
|
+
) -> None:
|
|
90
|
+
"""Exit the async context manager and close the cache."""
|
|
91
|
+
await self.close()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class NullCache(CacheProtocol):
    """Cache backend that never stores anything.

    Writes are discarded and every read misses, which makes it handy in tests
    or wherever caching should be switched off entirely.
    """

    async def get(self, key: str) -> Any | None:
        """Report a miss for every key, since nothing is ever stored."""
        return None

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Discard the value; the call has no effect."""

    async def delete(self, key: str) -> bool:
        """Report that nothing was removed, since nothing is ever stored."""
        return False

    async def clear(self) -> None:
        """No-op: there is never anything to clear."""

    async def close(self) -> None:
        """No-op: this backend holds no resources."""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class InMemoryCache(CacheProtocol):
    """In-memory cache implementation using a dictionary.

    Suitable for single-process deployments or development. Each entry is
    stored as ``(value, deadline)`` where ``deadline`` is a ``time.monotonic()``
    timestamp, or ``None`` for entries without a TTL. Deadlines use the
    monotonic clock so wall-clock adjustments (NTP, manual changes) cannot make
    entries expire early or live too long. Expired entries are evicted lazily
    on access and periodically by a background task, which is started by
    :meth:`start` or automatically on the first cache operation.

    Note: This cache is not shared across processes. For multi-process
    deployments, use RedisCache instead.
    """

    def __init__(self, cleanup_interval_seconds: int = 300) -> None:
        """Initialize the in-memory cache.

        Args:
            cleanup_interval_seconds: Seconds to wait between background sweeps
                that remove expired entries.
        """
        # key -> (value, monotonic deadline, or None for "never expires")
        self._cache: dict[str, tuple[Any, float | None]] = {}
        self._lock = asyncio.Lock()
        self._cleanup_interval = cleanup_interval_seconds
        self._cleanup_task: asyncio.Task[None] | None = None
        self._closed = False
        self._started = False
        # Record this instance in the module-level weak registry; a WeakSet
        # does not keep the instance alive.
        _active_caches.add(self)

    def __del__(self) -> None:
        """Best-effort cancellation of the cleanup task if close() was never awaited.

        While the cleanup task is pending it holds a strong reference to this
        instance through the bound ``_cleanup_loop`` coroutine, so this
        finalizer is usually only reached once the task (or its loop) is gone.
        Prefer calling :meth:`close` explicitly.
        """
        # getattr: __init__ may have failed before the attribute was assigned.
        task = getattr(self, "_cleanup_task", None)
        if task is not None and not task.done():
            try:
                task.cancel()
            except RuntimeError:
                # The task's event loop is already closed (e.g. interpreter
                # shutdown); there is nothing left to cancel against.
                pass
        self._cleanup_task = None
        self._closed = True

    @staticmethod
    def _is_expired(expiry: float | None, now: float) -> bool:
        """Return True if an entry with monotonic deadline ``expiry`` is expired at ``now``."""
        # ``<=`` so that a TTL of 0 never produces a hit, even when the clock
        # has not advanced between set() and get().
        return expiry is not None and expiry <= now

    async def start(self) -> None:
        """Start the background cleanup task.

        This method is idempotent - calling it multiple times has no effect,
        and it does nothing once the cache has been closed. It is also called
        automatically on the first cache operation.
        """
        if self._started or self._closed:
            return
        self._started = True
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _ensure_started(self) -> None:
        """Start the cleanup task unless it was already started or the cache is closed."""
        if not self._started and not self._closed:
            await self.start()

    async def _cleanup_loop(self) -> None:
        """Evict expired entries every interval until closed or cancelled."""
        while not self._closed:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # One failed sweep must not kill the task; log and retry next interval.
                logger.warning("Error during cache cleanup: %s", e)

    async def _cleanup_expired(self) -> None:
        """Remove every entry whose deadline has passed."""
        now = time.monotonic()
        async with self._lock:
            expired_keys = [
                key
                for key, (_, expiry) in self._cache.items()
                if self._is_expired(expiry, now)
            ]
            for key in expired_keys:
                del self._cache[key]

    async def get(self, key: str) -> Any | None:
        """Get a value from the cache.

        Args:
            key: The cache key.

        Returns:
            The cached value, or None if the key is absent or expired. An
            expired entry is evicted as a side effect.
        """
        await self._ensure_started()
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None

            value, expiry = entry
            if self._is_expired(expiry, time.monotonic()):
                del self._cache[key]
                return None

            return value

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Set a value in the cache.

        Args:
            key: The cache key.
            value: The value to store (kept as the same object, not serialized).
            ttl_seconds: Time-to-live in seconds; None means no expiration.
                A TTL of 0 or less stores an entry that is already expired.
        """
        await self._ensure_started()
        expiry = None if ttl_seconds is None else time.monotonic() + ttl_seconds
        async with self._lock:
            self._cache[key] = (value, expiry)

    async def delete(self, key: str) -> bool:
        """Delete a value from the cache.

        Args:
            key: The cache key.

        Returns:
            True if a live entry was removed; False if the key was absent or
            had already expired (a stale entry is still removed).
        """
        await self._ensure_started()
        async with self._lock:
            entry = self._cache.pop(key, None)
        return entry is not None and not self._is_expired(entry[1], time.monotonic())

    async def clear(self) -> None:
        """Clear all values from the cache."""
        await self._ensure_started()
        async with self._lock:
            self._cache.clear()

    async def close(self) -> None:
        """Stop the cleanup task and mark the cache closed.

        Stored entries are kept and get/set/delete keep working, but the
        background sweep is not restarted.
        """
        self._closed = True
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class RedisCache(CacheProtocol):
    """Redis-based cache implementation.

    Suitable for multi-process and distributed deployments. Values are stored
    as JSON strings under ``key_prefix + key``.

    Requires the `redis` package to be installed:
        pip install redis
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "byoi:",
        max_connections: int | None = None,
        **redis_kwargs: Any,
    ) -> None:
        """Initialize the Redis cache.

        No connection is opened here; the client and its pool are created
        lazily on first use.

        Args:
            url: Redis connection URL.
            key_prefix: Prefix for all cache keys.
            max_connections: Maximum number of connections in the pool.
                If None, uses the redis library default. Takes precedence over
                a ``max_connections`` entry in ``redis_kwargs``.
            **redis_kwargs: Additional arguments passed to the connection pool.
        """
        self._url = url
        self._key_prefix = key_prefix
        self._max_connections = max_connections
        self._redis_kwargs = redis_kwargs
        self._client: Any = None
        # Pool backing ``_client``. A client created with an explicit
        # ``connection_pool`` does not own it, so close() disconnects it here.
        self._pool: Any = None
        self._closed = False

    async def _get_client(self) -> Any:
        """Get or lazily create the Redis client and its connection pool.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        if self._client is None:
            try:
                import redis.asyncio as redis
                from redis.asyncio.connection import ConnectionPool
            except ImportError as e:
                raise ImportError(
                    "Redis support requires the 'redis' package. "
                    "Install it with: pip install redis"
                ) from e

            # Merge into one dict so an explicit max_connections overrides the
            # same key in redis_kwargs instead of raising a duplicate-keyword error.
            pool_kwargs: dict[str, Any] = dict(self._redis_kwargs)
            if self._max_connections is not None:
                pool_kwargs["max_connections"] = self._max_connections

            self._pool = ConnectionPool.from_url(self._url, **pool_kwargs)
            self._client = redis.Redis(connection_pool=self._pool)

        return self._client

    def _make_key(self, key: str) -> str:
        """Create a prefixed cache key."""
        return f"{self._key_prefix}{key}"

    async def get(self, key: str) -> Any | None:
        """Get a value from Redis.

        Returns:
            The JSON-decoded value, None if the key is missing, or the raw
            stored payload if it is not valid JSON.
        """
        client = await self._get_client()
        raw = await client.get(self._make_key(key))
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError: the
            # payload was not written by set(), so return it unchanged.
            return raw

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Set a value in Redis.

        Args:
            key: The cache key (the prefix is added automatically).
            value: A JSON-serializable value.
            ttl_seconds: Time-to-live in seconds; None means no expiration.
                A TTL of 0 or less means the entry is already expired, so any
                existing value for the key is removed instead of stored.

        Raises:
            TypeError: If ``value`` is not JSON-serializable.
        """
        client = await self._get_client()
        redis_key = self._make_key(key)
        serialized = json.dumps(value)
        if ttl_seconds is None:
            await client.set(redis_key, serialized)
        elif ttl_seconds > 0:
            await client.setex(redis_key, ttl_seconds, serialized)
        else:
            # SETEX rejects non-positive expire times; an already-expired entry
            # is equivalent to the key not existing.
            await client.delete(redis_key)

    async def delete(self, key: str) -> bool:
        """Delete a value from Redis.

        Returns:
            True if the key existed and was deleted, False otherwise.
        """
        client = await self._get_client()
        result = await client.delete(self._make_key(key))
        return result > 0

    async def clear(self) -> None:
        """Clear all BYOI keys from Redis.

        Warning: This uses SCAN to find keys with the prefix and deletes them.
        For large datasets, this may be slow. With an empty ``key_prefix``
        this deletes every key in the selected database.
        """
        client = await self._get_client()
        # Escape glob metacharacters so the prefix is matched literally
        # (e.g. "app[1]:" must not be read as a character class).
        escaped_prefix = "".join(
            "\\" + char if char in "\\*?[]" else char for char in self._key_prefix
        )
        pattern = f"{escaped_prefix}*"
        cursor = 0
        while True:
            cursor, keys = await client.scan(cursor, match=pattern, count=100)
            if keys:
                await client.delete(*keys)
            if cursor == 0:
                break

    async def close(self) -> None:
        """Close the Redis client and disconnect its connection pool.

        A later operation lazily creates a fresh client and pool.
        """
        self._closed = True
        client, self._client = self._client, None
        pool, self._pool = self._pool, None
        if client is not None:
            # redis-py >= 5.0.1 provides aclose(); older versions only close().
            aclose = getattr(client, "aclose", None)
            if aclose is not None:
                await aclose()
            else:
                await client.close()
        if pool is not None:
            # The client does not own a pool passed via connection_pool=, so
            # its connections must be released explicitly.
            await pool.disconnect()
|
byoi/config.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""BYOI - Bring Your Own Identity.
|
|
2
|
+
|
|
3
|
+
Main configuration and lifecycle management for the BYOI library.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from collections.abc import AsyncGenerator
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
from fastapi import FastAPI
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
13
|
+
|
|
14
|
+
from byoi.cache import CacheProtocol, InMemoryCache
|
|
15
|
+
from byoi.errors import ConfigurationError
|
|
16
|
+
from byoi.models import OIDCProviderConfig
|
|
17
|
+
from byoi.providers import ProviderManager
|
|
18
|
+
from byoi.service import AuthService
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from byoi.repositories import (
|
|
22
|
+
AuthStateRepositoryProtocol,
|
|
23
|
+
LinkedIdentityRepositoryProtocol,
|
|
24
|
+
UserRepositoryProtocol,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Public names exported by ``from byoi.config import *``.
__all__ = ("BYOI", "BYOIConfig")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BYOIConfig(BaseModel):
    """Settings consumed by :class:`BYOI` when it builds its components.

    Attributes:
        user_repository: Persistence for users; expected to implement
            ``UserRepositoryProtocol`` (typed ``Any``, so not validated here).
        identity_repository: Persistence for linked identities; expected to
            implement ``LinkedIdentityRepositoryProtocol``.
        auth_state_repository: Persistence for in-flight auth state (PKCE,
            nonce); expected to implement ``AuthStateRepositoryProtocol``.
        cache: Cache for JWKS and discovery documents. When ``None``, BYOI
            creates an ``InMemoryCache`` during setup.
        providers: OIDC provider configurations registered during setup.
        discovery_ttl_seconds: Lifetime of cached discovery documents (>= 0).
        jwks_ttl_seconds: Lifetime of cached JWKS (>= 0).
        state_expiration_minutes: Validity window of auth states (1-60).
        http_timeout_seconds: Timeout for HTTP calls to identity providers
            (1.0-120.0 seconds).
    """

    model_config = ConfigDict(frozen=False, arbitrary_types_allowed=True)

    # Persistence backends: required, duck-typed.
    user_repository: Any = Field(
        default=...,
        description="Repository for user data persistence (must implement UserRepositoryProtocol)",
    )
    identity_repository: Any = Field(
        default=...,
        description="Repository for linked identity data persistence (must implement LinkedIdentityRepositoryProtocol)",
    )
    auth_state_repository: Any = Field(
        default=...,
        description="Repository for auth state (PKCE, nonce) persistence (must implement AuthStateRepositoryProtocol)",
    )

    # Caching and providers.
    cache: CacheProtocol | None = Field(
        description="Cache implementation for JWKS and discovery docs. Defaults to InMemoryCache.",
        default=None,
    )
    providers: list[OIDCProviderConfig] = Field(
        description="List of OIDC provider configurations to register.",
        default_factory=list,
    )

    # Tuning knobs with validated ranges.
    discovery_ttl_seconds: int = Field(
        3600,
        ge=0,
        description="TTL for cached discovery documents in seconds.",
    )
    jwks_ttl_seconds: int = Field(
        3600,
        ge=0,
        description="TTL for cached JWKS in seconds.",
    )
    state_expiration_minutes: int = Field(
        10,
        ge=1,
        le=60,
        description="How long auth states are valid in minutes.",
    )
    http_timeout_seconds: float = Field(
        30.0,
        ge=1.0,
        le=120.0,
        description="Timeout for HTTP requests to identity providers in seconds.",
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class BYOI:
    """Main BYOI class for configuration and lifecycle management.

    This class provides a single point to initialize and configure BYOI,
    and can be integrated with FastAPI's lifespan for proper resource
    management.

    Example:
        ```python
        from contextlib import asynccontextmanager
        from fastapi import FastAPI
        from byoi import BYOI, BYOIConfig, OIDCProviderConfig

        # Create your repository implementations
        user_repo = MyUserRepository()
        identity_repo = MyIdentityRepository()
        auth_state_repo = MyAuthStateRepository()

        # Configure BYOI
        config = BYOIConfig(
            user_repository=user_repo,
            identity_repository=identity_repo,
            auth_state_repository=auth_state_repo,
            providers=[
                OIDCProviderConfig(
                    name="google",
                    display_name="Google",
                    issuer="https://accounts.google.com",
                    client_id="your-client-id",
                    client_secret="your-client-secret",
                ),
            ],
        )

        byoi = BYOI(config)

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            await byoi.setup(app)
            yield
            await byoi.shutdown()

        app = FastAPI(lifespan=lifespan)
        ```
    """

    def __init__(self, config: BYOIConfig) -> None:
        """Initialize BYOI.

        No resources are created here; call :meth:`setup` to build them.

        Args:
            config: The BYOI configuration.
        """
        self._config = config
        self._provider_manager: ProviderManager | None = None
        self._auth_service: AuthService | None = None
        self._http_client: httpx.AsyncClient | None = None
        self._cache: CacheProtocol | None = None
        self._is_setup = False

    @property
    def provider_manager(self) -> ProviderManager:
        """Get the provider manager.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if self._provider_manager is None:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self._provider_manager

    @property
    def auth_service(self) -> AuthService:
        """Get the auth service.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if self._auth_service is None:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self._auth_service

    async def setup(self, app: FastAPI | None = None) -> None:
        """Set up BYOI and register with FastAPI app.

        Creates the shared HTTP client, the cache (an ``InMemoryCache`` unless
        ``config.cache`` is set), the provider manager with every configured
        provider, and the auth service. When ``app`` is given, they are stored
        on ``app.state`` for dependency injection. Calling this again after a
        successful setup is a no-op.

        If any step fails, every component created so far is closed before the
        error propagates. This prevents a failed setup from leaking the HTTP
        client or the cache's background task.

        Args:
            app: Optional FastAPI application to register with.

        Raises:
            ConfigurationError: If a provider configuration is rejected.
        """
        if self._is_setup:
            return

        try:
            # Create HTTP client
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._config.http_timeout_seconds),
                follow_redirects=True,
            )

            # Set up cache
            if self._config.cache is not None:
                self._cache = self._config.cache
            else:
                cache = InMemoryCache()
                await cache.start()
                self._cache = cache

            # Create provider manager
            self._provider_manager = ProviderManager(
                cache=self._cache,
                http_client=self._http_client,
                discovery_ttl_seconds=self._config.discovery_ttl_seconds,
                jwks_ttl_seconds=self._config.jwks_ttl_seconds,
                http_timeout_seconds=self._config.http_timeout_seconds,
            )

            # Register providers
            for provider_config in self._config.providers:
                try:
                    self._provider_manager.register_provider(provider_config)
                except ValueError as e:
                    raise ConfigurationError(str(e)) from e

            # Create auth service
            self._auth_service = AuthService(
                provider_manager=self._provider_manager,
                user_repository=self._config.user_repository,
                identity_repository=self._config.identity_repository,
                auth_state_repository=self._config.auth_state_repository,
                state_expiration_minutes=self._config.state_expiration_minutes,
                http_client=self._http_client,
                http_timeout_seconds=self._config.http_timeout_seconds,
            )

            # Register with FastAPI app
            if app is not None:
                app.state.byoi_provider_manager = self._provider_manager
                app.state.byoi_auth_service = self._auth_service
                app.state.byoi = self
        except BaseException:
            # Undo the partial setup. A cleanup failure is deliberately ignored
            # so the caller sees the original error (e.g. ConfigurationError).
            try:
                await self._close_resources()
            except Exception:
                pass
            raise

        self._is_setup = True

    async def _close_resources(self) -> None:
        """Close and detach every component created by setup().

        Components are closed in reverse order of creation: auth service,
        provider manager, HTTP client, then cache. A failure while closing one
        component does not stop the others from being closed. The first such
        error is re-raised after all closes have been attempted.
        """
        auth_service, self._auth_service = self._auth_service, None
        provider_manager, self._provider_manager = self._provider_manager, None
        http_client, self._http_client = self._http_client, None
        cache, self._cache = self._cache, None

        closers = []
        if auth_service is not None:
            closers.append(auth_service.close)
        if provider_manager is not None:
            closers.append(provider_manager.close)
        if http_client is not None:
            closers.append(http_client.aclose)
        if cache is not None:
            closers.append(cache.close)

        first_error = None
        for close in closers:
            try:
                await close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def shutdown(self) -> None:
        """Shut down BYOI and release resources.

        This should be called when the application is shutting down. All
        components are closed even if one of them fails to close; the first
        failure is re-raised afterwards. BYOI is marked as not set up either
        way, so setup() may be called again.
        """
        if not self._is_setup:
            return

        try:
            await self._close_resources()
        finally:
            self._is_setup = False

    def add_provider(self, config: OIDCProviderConfig) -> None:
        """Add a provider after setup.

        Args:
            config: The provider configuration.

        Raises:
            RuntimeError: If BYOI is not set up.
            ValueError: If the provider is already registered.
        """
        if not self._is_setup:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        self.provider_manager.register_provider(config)

    def remove_provider(self, name: str) -> bool:
        """Remove a provider after setup.

        Args:
            name: The provider name.

        Returns:
            True if the provider was removed, False if not found.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if not self._is_setup:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self.provider_manager.unregister_provider(name)

    @asynccontextmanager
    async def lifespan(
        self,
        app: FastAPI,
    ) -> AsyncGenerator[None, None]:
        """FastAPI lifespan context manager for BYOI.

        Use this as your FastAPI lifespan to automatically set up
        and shut down BYOI. Shutdown runs even if the application body raises.

        Args:
            app: The FastAPI application.

        Example:
            ```python
            byoi = BYOI(config)
            app = FastAPI(lifespan=byoi.lifespan)
            ```
        """
        await self.setup(app)
        try:
            yield
        finally:
            await self.shutdown()

    @staticmethod
    def create_lifespan(
        config: BYOIConfig,
    ):
        """Create a lifespan context manager for a given config.

        This is a convenience method that creates a BYOI instance
        and returns its lifespan method.

        Args:
            config: The BYOI configuration.

        Returns:
            A lifespan context manager.

        Example:
            ```python
            app = FastAPI(lifespan=BYOI.create_lifespan(config))
            ```
        """
        byoi = BYOI(config)
        return byoi.lifespan
|