byoi 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
byoi/cache.py ADDED
@@ -0,0 +1,346 @@
1
+ """Cache implementations for the BYOI library.
2
+
3
+ Provides caching for JWKS and OIDC discovery documents.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import json
10
+ import logging
11
+ import time
12
+ import weakref
13
+ from abc import ABC, abstractmethod
14
+ from typing import TYPE_CHECKING, Any, Self
15
+
16
+ if TYPE_CHECKING:
17
+ from types import TracebackType
18
+
19
# Module logger; the explicit name matches this module's path (byoi/cache.py).
logger = logging.getLogger("byoi.cache")

# Track all InMemoryCache instances for cleanup on interpreter shutdown.
# A WeakSet is used so that being tracked here does not keep an instance alive.
# NOTE(review): InMemoryCache.__init__ only adds to this set and nothing in this
# module reads it; presumably another module iterates it at shutdown — confirm,
# otherwise this is dead state.
_active_caches: weakref.WeakSet[InMemoryCache] = weakref.WeakSet()

# Public API of this module.
__all__ = (
    "CacheProtocol",
    "InMemoryCache",
    "RedisCache",
    "NullCache",
)
30
+
31
+
32
+ class CacheProtocol(ABC):
33
+ """Abstract base class for cache implementations."""
34
+
35
+ @abstractmethod
36
+ async def get(self, key: str) -> Any | None:
37
+ """Get a value from the cache.
38
+
39
+ Args:
40
+ key: The cache key.
41
+
42
+ Returns:
43
+ The cached value, or None if not found or expired.
44
+ """
45
+ ...
46
+
47
+ @abstractmethod
48
+ async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
49
+ """Set a value in the cache.
50
+
51
+ Args:
52
+ key: The cache key.
53
+ value: The value to cache (must be JSON-serializable).
54
+ ttl_seconds: Time-to-live in seconds. None means no expiration.
55
+ """
56
+ ...
57
+
58
+ @abstractmethod
59
+ async def delete(self, key: str) -> bool:
60
+ """Delete a value from the cache.
61
+
62
+ Args:
63
+ key: The cache key.
64
+
65
+ Returns:
66
+ True if the key was deleted, False if not found.
67
+ """
68
+ ...
69
+
70
+ @abstractmethod
71
+ async def clear(self) -> None:
72
+ """Clear all values from the cache."""
73
+ ...
74
+
75
+ @abstractmethod
76
+ async def close(self) -> None:
77
+ """Close the cache connection and release resources."""
78
+ ...
79
+
80
+ async def __aenter__(self) -> Self:
81
+ """Enter the async context manager."""
82
+ return self
83
+
84
+ async def __aexit__(
85
+ self,
86
+ exc_type: type[BaseException] | None,
87
+ exc_val: BaseException | None,
88
+ exc_tb: TracebackType | None,
89
+ ) -> None:
90
+ """Exit the async context manager and close the cache."""
91
+ await self.close()
92
+
93
+
94
class NullCache(CacheProtocol):
    """Cache that never stores anything, so every lookup is a miss.

    Handy in tests, or wherever caching should be switched off entirely.
    """

    async def get(self, key: str) -> Any | None:
        """Report a miss for every key (always None)."""
        return None

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Discard the value; nothing is kept."""
        return None

    async def delete(self, key: str) -> bool:
        """Report that nothing was removed (always False)."""
        return False

    async def clear(self) -> None:
        """No-op: there is never anything to clear."""
        return None

    async def close(self) -> None:
        """No-op: no resources are ever acquired."""
        return None
119
+
120
+
121
class InMemoryCache(CacheProtocol):
    """In-memory cache implementation using a dictionary.

    Suitable for single-process deployments or development.

    Entries are stored as ``(value, expiry)`` pairs, where ``expiry`` is a
    ``time.monotonic()`` timestamp, or None for entries that never expire.
    A monotonic clock is used so that wall-clock adjustments (NTP sync,
    manual changes) cannot prematurely expire or prolong cached entries.

    A background task, started on first use, periodically purges expired
    entries. That task holds only a weak reference to the cache. An instance
    that is dropped without ``close()`` can therefore still be garbage
    collected, and ``__del__`` then cancels the task.

    Note: This cache is not shared across processes. For multi-process
    deployments, use RedisCache instead.
    """

    def __init__(self, cleanup_interval_seconds: int = 300) -> None:
        """Initialize the in-memory cache.

        Args:
            cleanup_interval_seconds: How often to run cleanup of expired
                entries. Must be positive.

        Raises:
            ValueError: If ``cleanup_interval_seconds`` is not positive. Such
                a value would make the cleanup loop run back-to-back,
                keeping the event loop busy.
        """
        if cleanup_interval_seconds <= 0:
            raise ValueError("cleanup_interval_seconds must be positive")
        self._cache: dict[str, tuple[Any, float | None]] = {}
        self._lock = asyncio.Lock()
        self._cleanup_interval = cleanup_interval_seconds
        self._cleanup_task: asyncio.Task[None] | None = None
        self._closed = False
        self._started = False
        # Track this instance for cleanup on interpreter shutdown
        _active_caches.add(self)

    def __del__(self) -> None:
        """Cancel the cleanup task if the cache is collected without close().

        ``getattr`` guards against a partially initialized instance (e.g. when
        ``__init__`` raised). ``Task.cancel()`` can raise RuntimeError if the
        task's event loop is already closed, and that error would otherwise
        be reported as an ignored exception from the finalizer.
        """
        task = getattr(self, "_cleanup_task", None)
        if task is not None and not task.done():
            try:
                task.cancel()
            except RuntimeError:
                # The owning loop is closed; the task can never run again anyway.
                pass
        self._cleanup_task = None
        self._closed = True

    async def start(self) -> None:
        """Start the background cleanup task.

        This method is idempotent - calling it multiple times has no effect,
        and it does nothing once the cache is closed. It is also called
        automatically on first cache operation.
        """
        if self._started or self._closed:
            return
        if self._cleanup_task is None:
            # Hand the worker a weak reference only. A pending task that held
            # `self` would keep the cache alive, and __del__ could never run.
            self._cleanup_task = asyncio.create_task(
                self._cleanup_worker(weakref.ref(self), self._cleanup_interval)
            )
        self._started = True

    async def _ensure_started(self) -> None:
        """Ensure the cache cleanup task is started."""
        if not self._started and not self._closed:
            await self.start()

    @staticmethod
    async def _cleanup_worker(
        cache_ref: "weakref.ReferenceType[InMemoryCache]",
        interval: float,
    ) -> None:
        """Purge expired entries every ``interval`` seconds.

        Stops when the cache has been closed or garbage collected. A
        cancellation (from ``close()`` or ``__del__``) propagates out of the
        sleep and ends the task.

        Args:
            cache_ref: Weak reference to the cache being maintained.
            interval: Seconds to sleep between purges.
        """
        while True:
            await asyncio.sleep(interval)
            cache = cache_ref()
            if cache is None or cache._closed:
                return
            try:
                await cache._cleanup_expired()
            except Exception as e:
                # Log error but continue cleanup loop
                logger.warning("Error during cache cleanup: %s", e)
            # Drop the strong reference before sleeping again.
            del cache

    async def _cleanup_expired(self) -> None:
        """Remove every entry whose expiry time has passed."""
        async with self._lock:
            now = time.monotonic()
            expired_keys = [
                key
                for key, (_, expiry) in self._cache.items()
                if expiry is not None and expiry < now
            ]
            for key in expired_keys:
                del self._cache[key]

    async def get(self, key: str) -> Any | None:
        """Get a value from the cache.

        An expired entry is removed on access and reported as missing.
        """
        await self._ensure_started()
        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None

            value, expiry = entry
            if expiry is not None and expiry < time.monotonic():
                del self._cache[key]
                return None

            return value

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Set a value in the cache.

        Args:
            key: The cache key.
            value: The value to store. It is kept as-is; no serialization
                is performed.
            ttl_seconds: Time-to-live in seconds. None means no expiration.
        """
        await self._ensure_started()
        expiry = None if ttl_seconds is None else time.monotonic() + ttl_seconds
        async with self._lock:
            self._cache[key] = (value, expiry)

    async def delete(self, key: str) -> bool:
        """Delete a value from the cache.

        Returns:
            True if an entry existed (expired or not), False otherwise.
        """
        await self._ensure_started()
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    async def clear(self) -> None:
        """Clear all values from the cache."""
        await self._ensure_started()
        async with self._lock:
            self._cache.clear()

    async def close(self) -> None:
        """Stop the cleanup task and close the cache.

        Entries remain readable after closing; only background cleanup stops.
        """
        self._closed = True
        task, self._cleanup_task = self._cleanup_task, None
        if task is None or task.done():
            return
        task.cancel()
        # asyncio.wait() does not raise the task's own CancelledError. It does
        # propagate a cancellation of the coroutine calling close(), which a
        # blanket `except CancelledError` around `await task` would swallow.
        await asyncio.wait({task})
239
+
240
+
241
class RedisCache(CacheProtocol):
    """Redis-based cache implementation.

    Suitable for multi-process and distributed deployments. Values are
    JSON-encoded and stored under ``key_prefix + key``.

    Requires the `redis` package to be installed:
        pip install redis
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "byoi:",
        max_connections: int | None = None,
        **redis_kwargs: Any,
    ) -> None:
        """Initialize the Redis cache.

        No connection is opened here. The client and its connection pool are
        created lazily on first use.

        Args:
            url: Redis connection URL.
            key_prefix: Prefix for all cache keys. ``clear()`` deletes every
                key starting with this prefix, so an empty prefix means every
                key in the selected database.
            max_connections: Maximum number of connections in the pool.
                If None, uses the redis library default.
            **redis_kwargs: Additional keyword arguments passed to
                ``ConnectionPool.from_url`` when the pool is created.
        """
        self._url = url
        self._key_prefix = key_prefix
        self._max_connections = max_connections
        self._redis_kwargs = redis_kwargs
        self._client: Any = None
        # Pool backing `_client`. It is kept so close() can disconnect it:
        # redis-py does not close a pool that was passed to the client explicitly.
        self._pool: Any = None
        # Set by close(); not otherwise consulted - a later call recreates the client.
        self._closed = False

    async def _get_client(self) -> Any:
        """Return the Redis client, creating it and its pool on first use.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        if self._client is not None:
            return self._client

        try:
            import redis.asyncio as redis
            from redis.asyncio.connection import ConnectionPool
        except ImportError as e:
            raise ImportError(
                "Redis support requires the 'redis' package. "
                "Install it with: pip install redis"
            ) from e

        # Create connection pool with max_connections if specified
        pool_kwargs: dict[str, Any] = {}
        if self._max_connections is not None:
            pool_kwargs["max_connections"] = self._max_connections

        self._pool = ConnectionPool.from_url(self._url, **pool_kwargs, **self._redis_kwargs)
        self._client = redis.Redis(connection_pool=self._pool)
        return self._client

    def _make_key(self, key: str) -> str:
        """Create a prefixed cache key."""
        return f"{self._key_prefix}{key}"

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so ``text`` matches literally."""
        return "".join(f"\\{ch}" if ch in "\\*?[]" else ch for ch in text)

    async def get(self, key: str) -> Any | None:
        """Get a value from Redis.

        Returns:
            The JSON-decoded value, the raw stored value if it is not valid
            JSON, or None if the key does not exist.
        """
        client = await self._get_client()
        value = await client.get(self._make_key(key))
        if value is None:
            return None
        try:
            return json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return value

    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None:
        """Set a value in Redis.

        Args:
            key: The cache key.
            value: The value to store; it is JSON-encoded, so it must be
                JSON-serializable.
            ttl_seconds: Time-to-live in seconds. None means no expiration.
                A non-positive TTL removes the key: the entry would expire
                immediately, and SETEX rejects such expirations.
        """
        client = await self._get_client()
        redis_key = self._make_key(key)
        # Serialize first so non-serializable values fail regardless of TTL.
        serialized = json.dumps(value)
        if ttl_seconds is None:
            await client.set(redis_key, serialized)
        elif ttl_seconds > 0:
            await client.setex(redis_key, ttl_seconds, serialized)
        else:
            await client.delete(redis_key)

    async def delete(self, key: str) -> bool:
        """Delete a value from Redis.

        Returns:
            True if the key existed and was removed, False otherwise.
        """
        client = await self._get_client()
        result = await client.delete(self._make_key(key))
        return result > 0

    async def clear(self) -> None:
        """Clear all BYOI keys from Redis.

        Warning: This uses SCAN to find keys with the prefix and deletes them.
        For large datasets, this may be slow. Glob metacharacters in the
        prefix are escaped, so only keys that literally start with it match.
        """
        client = await self._get_client()
        pattern = f"{self._escape_glob(self._key_prefix)}*"
        cursor = 0
        while True:
            cursor, keys = await client.scan(cursor, match=pattern, count=100)
            if keys:
                await client.delete(*keys)
            if cursor == 0:
                break

    async def close(self) -> None:
        """Close the Redis client and disconnect its connection pool."""
        self._closed = True
        client, self._client = self._client, None
        pool, self._pool = self._pool, None
        try:
            if client is not None:
                # redis-py 5.0.1 deprecated close() in favour of aclose(); support both.
                closer = getattr(client, "aclose", None) or client.close
                await closer()
        finally:
            if pool is not None:
                await pool.disconnect()
byoi/config.py ADDED
@@ -0,0 +1,349 @@
1
+ """BYOI - Bring Your Own Identity.
2
+
3
+ Main configuration and lifecycle management for the BYOI library.
4
+ """
5
+
6
+ from collections.abc import AsyncGenerator
7
+ from contextlib import asynccontextmanager
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import httpx
11
+ from fastapi import FastAPI
12
+ from pydantic import BaseModel, ConfigDict, Field
13
+
14
+ from byoi.cache import CacheProtocol, InMemoryCache
15
+ from byoi.errors import ConfigurationError
16
+ from byoi.models import OIDCProviderConfig
17
+ from byoi.providers import ProviderManager
18
+ from byoi.service import AuthService
19
+
20
+ if TYPE_CHECKING:
21
+ from byoi.repositories import (
22
+ AuthStateRepositoryProtocol,
23
+ LinkedIdentityRepositoryProtocol,
24
+ UserRepositoryProtocol,
25
+ )
26
+
27
# Public API of this module.
__all__ = (
    "BYOI",
    "BYOIConfig",
)
31
+
32
+
33
class BYOIConfig(BaseModel):
    """Configuration for the BYOI library.

    This Pydantic model holds all configuration needed to initialize BYOI.

    Attributes:
        user_repository: Repository for user data persistence.
        identity_repository: Repository for linked identity data persistence.
        auth_state_repository: Repository for auth state (PKCE, nonce) persistence.
        cache: Cache implementation for JWKS and discovery docs. When None,
            BYOI creates an InMemoryCache during setup.
        providers: List of OIDC provider configurations to register.
        discovery_ttl_seconds: TTL for cached discovery documents (>= 0).
        jwks_ttl_seconds: TTL for cached JWKS (>= 0).
        state_expiration_minutes: How long auth states are valid (1-60).
        http_timeout_seconds: Timeout for HTTP requests to identity
            providers, in seconds (1-120).
    """

    # arbitrary_types_allowed lets fields be annotated with non-pydantic classes
    # such as CacheProtocol; frozen=False keeps instances mutable after creation.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        frozen=False,
    )

    # The repositories are typed as Any, so pydantic performs no validation on
    # them; conformance to the protocols named in the descriptions is not checked here.
    user_repository: Any = Field(
        ...,
        description="Repository for user data persistence (must implement UserRepositoryProtocol)",
    )
    identity_repository: Any = Field(
        ...,
        description="Repository for linked identity data persistence (must implement LinkedIdentityRepositoryProtocol)",
    )
    auth_state_repository: Any = Field(
        ...,
        description="Repository for auth state (PKCE, nonce) persistence (must implement AuthStateRepositoryProtocol)",
    )
    cache: CacheProtocol | None = Field(
        default=None,
        description="Cache implementation for JWKS and discovery docs. Defaults to InMemoryCache.",
    )
    providers: list[OIDCProviderConfig] = Field(
        default_factory=list,
        description="List of OIDC provider configurations to register.",
    )
    discovery_ttl_seconds: int = Field(
        default=3600,
        ge=0,
        description="TTL for cached discovery documents in seconds.",
    )
    jwks_ttl_seconds: int = Field(
        default=3600,
        ge=0,
        description="TTL for cached JWKS in seconds.",
    )
    state_expiration_minutes: int = Field(
        default=10,
        ge=1,
        le=60,
        description="How long auth states are valid in minutes.",
    )
    http_timeout_seconds: float = Field(
        default=30.0,
        ge=1.0,
        le=120.0,
        description="Timeout for HTTP requests to identity providers in seconds.",
    )
96
+
97
+
98
class BYOI:
    """Main BYOI class for configuration and lifecycle management.

    This class provides a single point to initialize and configure BYOI,
    and can be integrated with FastAPI's lifespan for proper resource
    management.

    ``setup()`` creates the HTTP client, the cache, the provider manager and
    the auth service, and ``shutdown()`` releases them. If ``setup()`` fails
    part-way, whatever it already created is released before the error
    propagates.

    Example:
        ```python
        from contextlib import asynccontextmanager
        from fastapi import FastAPI
        from byoi import BYOI, BYOIConfig, OIDCProviderConfig

        # Create your repository implementations
        user_repo = MyUserRepository()
        identity_repo = MyIdentityRepository()
        auth_state_repo = MyAuthStateRepository()

        # Configure BYOI
        config = BYOIConfig(
            user_repository=user_repo,
            identity_repository=identity_repo,
            auth_state_repository=auth_state_repo,
            providers=[
                OIDCProviderConfig(
                    name="google",
                    display_name="Google",
                    issuer="https://accounts.google.com",
                    client_id="your-client-id",
                    client_secret="your-client-secret",
                ),
            ],
        )

        byoi = BYOI(config)

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            await byoi.setup(app)
            yield
            await byoi.shutdown()

        app = FastAPI(lifespan=lifespan)
        ```
    """

    def __init__(self, config: BYOIConfig) -> None:
        """Initialize BYOI.

        No resources are created here; ``setup()`` creates them.

        Args:
            config: The BYOI configuration.
        """
        self._config = config
        self._provider_manager: ProviderManager | None = None
        self._auth_service: AuthService | None = None
        self._http_client: httpx.AsyncClient | None = None
        self._cache: CacheProtocol | None = None
        self._is_setup = False

    @property
    def provider_manager(self) -> ProviderManager:
        """Get the provider manager.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if self._provider_manager is None:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self._provider_manager

    @property
    def auth_service(self) -> AuthService:
        """Get the auth service.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if self._auth_service is None:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self._auth_service

    async def setup(self, app: FastAPI | None = None) -> None:
        """Set up BYOI and register with FastAPI app.

        This method initializes all BYOI components and stores them
        in the FastAPI app state for dependency injection. Calling it
        again after a successful setup is a no-op.

        If initialization fails, every resource created so far is released
        before the error propagates. A failed setup therefore leaves no
        open HTTP client or running cache cleanup task behind.

        Args:
            app: Optional FastAPI application to register with.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        if self._is_setup:
            return

        try:
            await self._create_components()
        except BaseException as exc:
            try:
                await self._release_resources()
            except Exception as cleanup_error:
                # Keep the original failure as the raised error (callers may catch
                # ConfigurationError); record the cleanup failure on it instead.
                exc.add_note(
                    f"Releasing BYOI resources after the failed setup also failed: {cleanup_error!r}"
                )
            raise

        # Register with FastAPI app
        if app is not None:
            app.state.byoi_provider_manager = self._provider_manager
            app.state.byoi_auth_service = self._auth_service
            app.state.byoi = self

        self._is_setup = True

    async def _create_components(self) -> None:
        """Create the HTTP client, cache, provider manager and auth service.

        Each component is stored on ``self`` as soon as it exists, so that
        ``_release_resources()`` can clean up after a partial failure.

        Raises:
            ConfigurationError: If a provider configuration is rejected.
        """
        # Create HTTP client
        self._http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(self._config.http_timeout_seconds),
            follow_redirects=True,
        )

        # Set up cache
        if self._config.cache is not None:
            self._cache = self._config.cache
        else:
            cache = InMemoryCache()
            # Store before starting so the cache is released even if start() fails.
            self._cache = cache
            await cache.start()

        # Create provider manager
        self._provider_manager = ProviderManager(
            cache=self._cache,
            http_client=self._http_client,
            discovery_ttl_seconds=self._config.discovery_ttl_seconds,
            jwks_ttl_seconds=self._config.jwks_ttl_seconds,
            http_timeout_seconds=self._config.http_timeout_seconds,
        )

        # Register providers
        for provider_config in self._config.providers:
            try:
                self._provider_manager.register_provider(provider_config)
            except ValueError as e:
                raise ConfigurationError(str(e)) from e

        # Create auth service
        self._auth_service = AuthService(
            provider_manager=self._provider_manager,
            user_repository=self._config.user_repository,
            identity_repository=self._config.identity_repository,
            auth_state_repository=self._config.auth_state_repository,
            state_expiration_minutes=self._config.state_expiration_minutes,
            http_client=self._http_client,
            http_timeout_seconds=self._config.http_timeout_seconds,
        )

    async def _release_resources(self) -> None:
        """Close every component that exists, consumers before their dependencies.

        Order: auth service, provider manager, HTTP client, cache. All
        components are detached from ``self`` up front. Every close is
        attempted even if an earlier one fails, and the first failure is
        re-raised afterwards.
        """
        auth_service, self._auth_service = self._auth_service, None
        provider_manager, self._provider_manager = self._provider_manager, None
        http_client, self._http_client = self._http_client, None
        cache, self._cache = self._cache, None

        closers = []
        if auth_service is not None:
            closers.append(auth_service.close)
        if provider_manager is not None:
            closers.append(provider_manager.close)
        if http_client is not None:
            closers.append(http_client.aclose)
        if cache is not None:
            closers.append(cache.close)

        first_error: Exception | None = None
        for close in closers:
            try:
                await close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def shutdown(self) -> None:
        """Shut down BYOI and release resources.

        This should be called when the application is shutting down. Every
        resource is closed even if closing an earlier one raises. The first
        such error is re-raised after BYOI has been marked as not set up.
        """
        if not self._is_setup:
            return

        try:
            await self._release_resources()
        finally:
            self._is_setup = False

    def add_provider(self, config: OIDCProviderConfig) -> None:
        """Add a provider after setup.

        Args:
            config: The provider configuration.

        Raises:
            RuntimeError: If BYOI is not set up.
            ValueError: If the provider is already registered.
        """
        if not self._is_setup:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        self.provider_manager.register_provider(config)

    def remove_provider(self, name: str) -> bool:
        """Remove a provider after setup.

        Args:
            name: The provider name.

        Returns:
            True if the provider was removed, False if not found.

        Raises:
            RuntimeError: If BYOI is not set up.
        """
        if not self._is_setup:
            raise RuntimeError("BYOI is not set up. Call setup() first.")
        return self.provider_manager.unregister_provider(name)

    @asynccontextmanager
    async def lifespan(
        self,
        app: FastAPI,
    ) -> AsyncGenerator[None, None]:
        """FastAPI lifespan context manager for BYOI.

        Use this as your FastAPI lifespan to automatically set up
        and shut down BYOI.

        Args:
            app: The FastAPI application.

        Example:
            ```python
            byoi = BYOI(config)
            app = FastAPI(lifespan=byoi.lifespan)
            ```
        """
        await self.setup(app)
        try:
            yield
        finally:
            await self.shutdown()

    @staticmethod
    def create_lifespan(
        config: BYOIConfig,
    ):
        """Create a lifespan context manager for a given config.

        This is a convenience method that creates a BYOI instance
        and returns its lifespan method.

        Args:
            config: The BYOI configuration.

        Returns:
            A lifespan context manager.

        Example:
            ```python
            app = FastAPI(lifespan=BYOI.create_lifespan(config))
            ```
        """
        byoi = BYOI(config)
        return byoi.lifespan