ragbits-chat 1.4.0.dev202512160238__py3-none-any.whl → 1.4.0.dev202601130240__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. ragbits/chat/api.py +364 -79
  2. ragbits/chat/auth/__init__.py +3 -1
  3. ragbits/chat/auth/backends.py +519 -81
  4. ragbits/chat/auth/base.py +8 -10
  5. ragbits/chat/auth/oauth2_providers.py +108 -0
  6. ragbits/chat/auth/provider_config.py +81 -0
  7. ragbits/chat/auth/session_store.py +178 -0
  8. ragbits/chat/auth/types.py +66 -29
  9. ragbits/chat/interface/types.py +15 -0
  10. ragbits/chat/persistence/sql.py +1 -0
  11. ragbits/chat/providers/model_provider.py +10 -8
  12. ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
  13. ragbits/chat/ui-build/assets/{ChatHistory--EyHdeYk.js → ChatHistory-B2hLBYMJ.js} +2 -2
  14. ragbits/chat/ui-build/assets/{ChatOptionsForm-BQJ-bYMu.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
  15. ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
  16. ragbits/chat/ui-build/assets/{FeedbackForm-BYee4uF-.js → FeedbackForm-oSbly5oN.js} +1 -1
  17. ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
  18. ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
  19. ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
  20. ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
  21. ragbits/chat/ui-build/assets/{UsageButton-C0lzhbc6.js → UsageButton-BABA7a-w.js} +1 -1
  22. ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
  23. ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-ZqFHUQCB.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
  24. ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-Dj8WqXIN.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
  25. ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
  26. ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
  27. ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
  28. ragbits/chat/ui-build/assets/{index-Bpba6d6u.js → index-Ceq7Rkzy.js} +1 -1
  29. ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
  30. ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
  31. ragbits/chat/ui-build/assets/{useMenuTriggerState-DaMXoDzf.js → useMenuTriggerState-SaFmATkk.js} +1 -1
  32. ragbits/chat/ui-build/assets/{useSelectableItem-D8MlMsUd.js → useSelectableItem-DhuFnc0W.js} +1 -1
  33. ragbits/chat/ui-build/index.html +2 -2
  34. {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/METADATA +2 -2
  35. ragbits_chat-1.4.0.dev202601130240.dist-info/RECORD +58 -0
  36. ragbits/chat/ui-build/assets/AuthGuard-Da8Duw5e.js +0 -1
  37. ragbits/chat/ui-build/assets/Login-DkcluI7q.js +0 -1
  38. ragbits/chat/ui-build/assets/LogoutButton-DSdCqsdC.js +0 -1
  39. ragbits/chat/ui-build/assets/ShareButton-YVgRY6bN.js +0 -1
  40. ragbits/chat/ui-build/assets/authStore-BFeUV-Bg.js +0 -1
  41. ragbits/chat/ui-build/assets/index-8hpVK8cj.js +0 -131
  42. ragbits/chat/ui-build/assets/index-BUbs7vFP.js +0 -1
  43. ragbits/chat/ui-build/assets/index-BZGp6GjF.js +0 -32
  44. ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
  45. ragbits_chat-1.4.0.dev202512160238.dist-info/RECORD +0 -52
  46. {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/WHEEL +0 -0
@@ -1,14 +1,21 @@
1
+ import logging
2
+ import secrets
1
3
  import uuid
2
4
  from datetime import datetime, timedelta, timezone
3
5
  from typing import Any, cast
6
+ from urllib.parse import urlencode
4
7
 
5
8
  import bcrypt
6
- from jose import jwt
7
- from jose.exceptions import ExpiredSignatureError, JWTError
9
+ import httpx
8
10
 
9
11
  from ragbits.chat.auth.base import AuthenticationBackend, AuthenticationResponse, AuthOptions
10
- from ragbits.chat.auth.types import JWTToken, OAuth2Credentials, User, UserCredentials
11
- from ragbits.core.utils import get_secret_key
12
+ from ragbits.chat.auth.oauth2_providers import OAuth2Provider
13
+ from ragbits.chat.auth.types import OAuth2Credentials, Session, SessionStore, User, UserCredentials
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Minimum length for session ID truncation in logs (for readability while maintaining some privacy)
18
+ SESSION_ID_LOG_LENGTH = 8
12
19
 
13
20
 
14
21
  class ListAuthenticationBackend(AuthenticationBackend):
@@ -17,6 +24,8 @@ class ListAuthenticationBackend(AuthenticationBackend):
17
24
  def __init__(
18
25
  self,
19
26
  users: list[dict[str, Any]],
27
+ session_store: SessionStore,
28
+ session_expiry_hours: int = 24,
20
29
  default_options: AuthOptions | None = None,
21
30
  ):
22
31
  """
@@ -24,17 +33,16 @@ class ListAuthenticationBackend(AuthenticationBackend):
24
33
 
25
34
  Args:
26
35
  users: List of user dicts with 'username', 'password', and optional fields
27
- jwt_secret: Secret key for JWT jwt_token signing (generates random if not provided)
36
+ session_store: Session storage backend
37
+ session_expiry_hours: Hours until session expires (default: 24)
28
38
  default_options: Default options for the component
29
39
  """
30
40
  if default_options is None:
31
41
  default_options = AuthOptions()
32
42
  super().__init__(default_options)
33
43
  self.users = {}
34
- self.jwt_secret = get_secret_key()
35
- self.jwt_algorithm = default_options.jwt_algorithm
36
- self.token_expiry_minutes = default_options.token_expiry_minutes
37
- self.revoked_tokens: set[str] = set() # Blacklist for revoked tokens
44
+ self.session_store = session_store
45
+ self.session_expiry_hours = session_expiry_hours
38
46
 
39
47
  for user_data in users:
40
48
  # Hash passwords with bcrypt for security
@@ -51,7 +59,7 @@ class ListAuthenticationBackend(AuthenticationBackend):
51
59
  ),
52
60
  }
53
61
 
54
- async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse: # noqa: PLR6301
62
+ async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse:
55
63
  """
56
64
  Authenticate into backend using provided credentials
57
65
 
@@ -60,118 +68,548 @@ class ListAuthenticationBackend(AuthenticationBackend):
60
68
  Returns:
61
69
  AuthenticationResponse: Result of authentication
62
70
  """
71
+ logger.debug("Attempting credential authentication for user: %s", credentials.username)
72
+
63
73
  user_data = self.users.get(credentials.username)
64
74
  if not user_data:
75
+ logger.warning("Authentication failed: user '%s' not found", credentials.username)
65
76
  return AuthenticationResponse(success=False, error_message="User not found")
66
77
 
67
78
  # Verify password with bcrypt
68
79
  password_hash = str(user_data["password_hash"])
69
80
  if not bcrypt.checkpw(credentials.password.encode("utf-8"), password_hash.encode("utf-8")):
81
+ logger.warning("Authentication failed: invalid password for user '%s'", credentials.username)
70
82
  return AuthenticationResponse(success=False, error_message="Invalid password")
71
83
 
72
84
  user = cast(User, user_data["user"])
73
85
 
74
- # Create JWT jwt_token
75
- jwt_token = self._create_jwt_token(user)
86
+ # Create session
87
+ now = datetime.now(timezone.utc)
88
+ session = Session(
89
+ session_id="", # Will be generated by session store
90
+ user=user,
91
+ provider="credentials",
92
+ oauth_token="", # Not applicable for credentials
93
+ token_type="",
94
+ created_at=now,
95
+ expires_at=now + timedelta(hours=self.session_expiry_hours),
96
+ )
97
+ session_id = await self.session_store.create_session(session)
98
+ logger.info("User '%s' authenticated successfully, session created", credentials.username)
99
+ return AuthenticationResponse(success=True, user=user, session_id=session_id)
100
+
101
+ async def validate_session(self, session_id: str) -> AuthenticationResponse:
102
+ """
103
+ Validate a session.
76
104
 
77
- return AuthenticationResponse(success=True, user=user, jwt_token=jwt_token)
105
+ Args:
106
+ session_id: The session ID to validate
78
107
 
79
- def _create_jwt_token(self, user: User) -> JWTToken:
80
- """Create a JWT jwt_token for the user."""
81
- now = datetime.now(timezone.utc)
82
- expires_in = self.token_expiry_minutes * 60 # Convert to seconds
83
-
84
- payload = {
85
- "user_id": user.user_id,
86
- "username": user.username,
87
- "email": user.email,
88
- "full_name": user.full_name,
89
- "roles": user.roles,
90
- "metadata": user.metadata,
91
- "iat": now,
92
- "exp": now + timedelta(minutes=self.token_expiry_minutes),
108
+ Returns:
109
+ AuthenticationResponse with user if valid
110
+ """
111
+ logger.debug(
112
+ "Validating session: %s...",
113
+ session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
114
+ )
115
+ session = await self.session_store.get_session(session_id)
116
+
117
+ if not session:
118
+ logger.debug("Session validation failed: session not found or expired")
119
+ return AuthenticationResponse(success=False, error_message="Invalid or expired session")
120
+
121
+ logger.debug("Session validated successfully for user: %s", session.user.username)
122
+ return AuthenticationResponse(success=True, user=session.user)
123
+
124
+ async def authenticate_with_oauth2( # noqa: PLR6301
125
+ self, oauth_credentials: OAuth2Credentials
126
+ ) -> AuthenticationResponse:
127
+ """
128
+ Authenticate user with OAuth2 credentials.
129
+
130
+ Args:
131
+ oauth_credentials: OAuth2 credentials
132
+
133
+ Returns:
134
+ AuthenticationResponse: Authentication failure as OAuth2 is not supported
135
+ """
136
+ return AuthenticationResponse(success=False, error_message="OAuth2 not supported by ListAuthentication")
137
+
138
+ async def revoke_session(self, session_id: str) -> bool:
139
+ """
140
+ Revoke a session.
141
+
142
+ Args:
143
+ session_id: The session ID to revoke
144
+
145
+ Returns:
146
+ True if session was revoked
147
+ """
148
+ logger.debug(
149
+ "Revoking session: %s...",
150
+ session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
151
+ )
152
+ success = await self.session_store.delete_session(session_id)
153
+ if success:
154
+ logger.info("Session revoked successfully")
155
+ else:
156
+ logger.debug("Session revocation failed: session not found")
157
+ return success
158
+
159
+
160
+ class OAuth2AuthenticationBackend(AuthenticationBackend):
161
+ """Generic OAuth2 authentication backend supporting multiple providers."""
162
+
163
+ def __init__(
164
+ self,
165
+ session_store: SessionStore,
166
+ provider: OAuth2Provider,
167
+ client_id: str | None = None,
168
+ client_secret: str | None = None,
169
+ redirect_uri: str | None = None,
170
+ session_expiry_hours: int = 24,
171
+ default_options: AuthOptions | None = None,
172
+ ):
173
+ """
174
+ Initialize OAuth2 authentication backend.
175
+
176
+ Args:
177
+ session_store: Session storage backend
178
+ provider: OAuth2 provider implementation (e.g., DiscordOAuth2Provider, GoogleOAuth2Provider)
179
+ client_id: OAuth2 client ID (or set {PROVIDER}_CLIENT_ID env var)
180
+ client_secret: OAuth2 client secret (or set {PROVIDER}_CLIENT_SECRET env var)
181
+ redirect_uri: Callback URL for OAuth2 flow (or set OAUTH2_REDIRECT_URI env var,
182
+ defaults to 'http://localhost:8000/api/auth/callback/{provider_name}')
183
+ session_expiry_hours: Hours until session expires (default: 24)
184
+ default_options: Default options for the component
185
+
186
+ Note:
187
+ The default redirect_uri uses a provider-specific path for better isolation and debugging.
188
+ For Discord, it defaults to: http://localhost:8000/api/auth/callback/discord
189
+ """
190
+ import os
191
+
192
+ if default_options is None:
193
+ default_options = AuthOptions()
194
+ super().__init__(default_options)
195
+
196
+ self.session_store = session_store
197
+ self.session_expiry_hours = session_expiry_hours
198
+ self.provider = provider
199
+
200
+ # Get credentials from args or environment variables
201
+ self.client_id = client_id or os.getenv(f"{self.provider.name.upper()}_CLIENT_ID")
202
+ self.client_secret = client_secret or os.getenv(f"{self.provider.name.upper()}_CLIENT_SECRET")
203
+
204
+ # Use provider-specific callback URL for better isolation and debugging
205
+ if not redirect_uri:
206
+ # Get base URL from environment variable or use default
207
+ base_url = os.getenv("OAUTH2_CALLBACK_BASE_URL", "http://localhost:8000")
208
+ # Remove trailing slash from base URL
209
+ base_url = base_url.rstrip("/")
210
+ # Construct redirect URI with base URL and provider name
211
+ redirect_uri = f"{base_url}/api/auth/callback/{self.provider.name}"
212
+
213
+ # remove trailing slash from redirect URI
214
+ redirect_uri = redirect_uri.rstrip("/")
215
+ # set redirect URI on the backend
216
+ self.redirect_uri = redirect_uri
217
+
218
+ if not self.client_id or not self.client_secret:
219
+ raise ValueError(
220
+ f"OAuth2 credentials not provided. Either pass client_id and client_secret to the constructor, "
221
+ f"or set {self.provider.name.upper()}_CLIENT_ID and {self.provider.name.upper()}_CLIENT_SECRET "
222
+ f"environment variables."
223
+ )
224
+
225
+ # State storage for CSRF protection (in production, use Redis or similar)
226
+ self.pending_states: dict[str, datetime] = {}
227
+
228
+ def generate_authorize_url(self) -> tuple[str, str]:
229
+ """
230
+ Generate OAuth2 authorization URL with state parameter.
231
+
232
+ Returns:
233
+ Tuple of (authorize_url, state)
234
+ """
235
+ state = secrets.token_urlsafe(32)
236
+ self.pending_states[state] = datetime.now(timezone.utc)
237
+
238
+ params = {
239
+ "client_id": self.client_id,
240
+ "redirect_uri": self.redirect_uri,
241
+ "response_type": "code",
242
+ "scope": self.provider.scope,
243
+ "state": state,
93
244
  }
94
245
 
95
- access_token = jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)
246
+ authorize_url = f"{self.provider.authorize_url}?{urlencode(params)}"
247
+ logger.debug("Generated OAuth2 authorization URL for provider '%s'", self.provider.name)
248
+ return authorize_url, state
249
+
250
+ def verify_state(self, state: str) -> bool:
251
+ """
252
+ Verify OAuth2 state parameter for CSRF protection.
253
+
254
+ Args:
255
+ state: State parameter to verify
256
+
257
+ Returns:
258
+ True if state is valid, False otherwise
259
+ """
260
+ if state not in self.pending_states:
261
+ logger.warning("OAuth2 state verification failed: state not found (possible CSRF attempt)")
262
+ return False
263
+
264
+ # Check if state is not expired (valid for 10 minutes)
265
+ created_at = self.pending_states[state]
266
+ if datetime.now(timezone.utc) - created_at > timedelta(minutes=10):
267
+ del self.pending_states[state]
268
+ logger.warning("OAuth2 state verification failed: state expired")
269
+ return False
270
+
271
+ # Remove state after verification (one-time use)
272
+ del self.pending_states[state]
273
+ logger.debug("OAuth2 state verified successfully")
274
+ return True
275
+
276
+ def cleanup_expired_states(self) -> int:
277
+ """
278
+ Remove expired state tokens from storage.
279
+
280
+ This method removes state tokens that are older than 10 minutes to prevent
281
+ memory leaks from abandoned OAuth2 flows.
282
+
283
+ Returns:
284
+ Number of state tokens removed
285
+
286
+ Example:
287
+ To schedule periodic cleanup:
288
+
289
+ ```python
290
+ import asyncio
291
+
292
+
293
+ async def cleanup_loop():
294
+ while True:
295
+ await asyncio.sleep(600) # Run every 10 minutes
296
+ removed = oauth2_backend.cleanup_expired_states()
297
+ if removed > 0:
298
+ logger.info(f"Cleaned up {removed} expired OAuth2 states")
299
+ ```
300
+ """
301
+ now = datetime.now(timezone.utc)
302
+ states_to_remove = []
303
+
304
+ # Find expired states (older than 10 minutes)
305
+ for state, created_at in list(self.pending_states.items()):
306
+ if now - created_at > timedelta(minutes=10):
307
+ states_to_remove.append(state)
96
308
 
97
- token_type = "bearer" # noqa: S105
98
- return JWTToken(access_token=access_token, token_type=token_type, expires_in=expires_in, user=user)
309
+ # Remove expired states
310
+ for state in states_to_remove:
311
+ self.pending_states.pop(state, None)
99
312
 
100
- async def validate_token(self, token: str) -> AuthenticationResponse:
101
- """Validate a JWT jwt_token."""
102
- # Check if token is blacklisted (revoked)
103
- if token in self.revoked_tokens:
104
- return AuthenticationResponse(success=False, error_message="Token has been revoked")
313
+ if states_to_remove:
314
+ logger.info("Cleaned up %d expired OAuth2 state tokens", len(states_to_remove))
105
315
 
316
+ return len(states_to_remove)
317
+
318
+ async def exchange_code_for_token(self, code: str) -> str | None:
319
+ """
320
+ Exchange authorization code for access token.
321
+
322
+ Args:
323
+ code: Authorization code from OAuth2 provider
324
+
325
+ Returns:
326
+ Access token if successful, None otherwise
327
+ """
328
+ logger.debug("Exchanging authorization code for token with provider '%s'", self.provider.name)
329
+ try:
330
+ async with httpx.AsyncClient() as client:
331
+ response = await client.post(
332
+ self.provider.token_url,
333
+ data={
334
+ "client_id": self.client_id,
335
+ "client_secret": self.client_secret,
336
+ "grant_type": "authorization_code",
337
+ "code": code,
338
+ "redirect_uri": self.redirect_uri,
339
+ },
340
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
341
+ )
342
+
343
+ if response.status_code != 200: # noqa: PLR2004
344
+ logger.error(
345
+ "Token exchange failed with provider '%s': HTTP %d",
346
+ self.provider.name,
347
+ response.status_code,
348
+ )
349
+ return None
350
+
351
+ token_data = response.json()
352
+ logger.debug("Token exchange successful with provider '%s'", self.provider.name)
353
+ return token_data.get("access_token")
354
+
355
+ except Exception as e:
356
+ logger.exception("Token exchange error with provider '%s': %s", self.provider.name, str(e))
357
+ return None
358
+
359
+ async def authenticate_with_oauth2(self, oauth_credentials: OAuth2Credentials) -> AuthenticationResponse:
360
+ """
361
+ Authenticate user with OAuth2 access token.
362
+
363
+ Args:
364
+ oauth_credentials: OAuth2 credentials with access token
365
+
366
+ Returns:
367
+ AuthenticationResponse with user and session ID
368
+ """
369
+ logger.debug("Authenticating user with OAuth2 provider '%s'", self.provider.name)
106
370
  try:
107
- payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
108
-
109
- # Reconstruct user from jwt_token payload
110
- user = User(
111
- user_id=payload["user_id"],
112
- username=payload["username"],
113
- email=payload.get("email"),
114
- full_name=payload.get("full_name"),
115
- roles=payload.get("roles", []),
116
- metadata=payload.get("metadata", {}),
371
+ # Fetch user info from provider
372
+ async with httpx.AsyncClient() as client:
373
+ response = await client.get(
374
+ self.provider.user_info_url,
375
+ headers={"Authorization": f"{oauth_credentials.token_type} {oauth_credentials.access_token}"},
376
+ )
377
+
378
+ if response.status_code != 200: # noqa: PLR2004
379
+ logger.error(
380
+ "Failed to fetch user info from '%s': HTTP %d",
381
+ self.provider.name,
382
+ response.status_code,
383
+ )
384
+ return AuthenticationResponse(
385
+ success=False,
386
+ error_message=f"Failed to fetch user info from {self.provider.name}: {response.status_code}",
387
+ )
388
+
389
+ user_data = response.json()
390
+
391
+ # Create User object from provider data
392
+ user = self.provider.create_user_from_data(user_data)
393
+ logger.debug("User info fetched successfully from '%s' for user: %s", self.provider.name, user.username)
394
+
395
+ # Create session
396
+ now = datetime.now(timezone.utc)
397
+ session = Session(
398
+ session_id="", # Will be generated by session store
399
+ user=user,
400
+ provider=self.provider.name,
401
+ oauth_token=oauth_credentials.access_token,
402
+ token_type=oauth_credentials.token_type,
403
+ created_at=now,
404
+ expires_at=now + timedelta(hours=self.session_expiry_hours),
405
+ )
406
+
407
+ session_id = await self.session_store.create_session(session)
408
+ logger.info(
409
+ "User '%s' authenticated successfully via OAuth2 provider '%s'",
410
+ user.username,
411
+ self.provider.name,
117
412
  )
118
413
 
119
- return AuthenticationResponse(success=True, user=user)
414
+ return AuthenticationResponse(success=True, user=user, session_id=session_id)
120
415
 
121
- except ExpiredSignatureError:
122
- return AuthenticationResponse(success=False, error_message="Token expired")
123
- except JWTError:
124
- return AuthenticationResponse(success=False, error_message="Invalid jwt_token")
416
+ except Exception as e:
417
+ logger.exception("OAuth2 authentication failed with provider '%s': %s", self.provider.name, str(e))
418
+ return AuthenticationResponse(
419
+ success=False,
420
+ error_message=f"OAuth2 authentication failed: {str(e)}",
421
+ )
125
422
 
126
- async def authenticate_with_oauth2(self, oauth_credentials: OAuth2Credentials) -> AuthenticationResponse: # noqa: PLR6301
423
+ async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse: # noqa: PLR6301
127
424
  """
128
- Authenticate user with OAuth2 credentials.
425
+ OAuth2 backend does not support credential authentication.
129
426
 
130
427
  Args:
131
- oauth_credentials: OAuth2 credentials
428
+ credentials: User credentials
132
429
 
133
430
  Returns:
134
- AuthenticationResponse: Authentication failure as OAuth2 is not supported
431
+ AuthenticationResponse with error
135
432
  """
136
- return AuthenticationResponse(success=False, error_message="OAuth2 not supported by ListAuthentication")
433
+ return AuthenticationResponse(
434
+ success=False,
435
+ error_message="Credential authentication not supported by OAuth2 backend",
436
+ )
137
437
 
138
- async def revoke_token(self, token: str) -> bool: # noqa: PLR6301
438
+ async def validate_session(self, session_id: str) -> AuthenticationResponse:
139
439
  """
140
- Revoke a JWT token.
440
+ Validate a session.
141
441
 
142
442
  Args:
143
- token: The JWT token to revoke
443
+ session_id: The session ID to validate
144
444
 
145
- Raises:
146
- NotImplementedError: This method is not implemented
445
+ Returns:
446
+ AuthenticationResponse with user if valid
147
447
  """
148
- raise NotImplementedError(
149
- "ListAuthenticationBackend is designed to run in development / testing scenarios. "
150
- "Revoking tokens is not implemented in this backend, "
151
- "if you need to revoke tokens please consider using different backend or implementing your own."
448
+ logger.debug(
449
+ "Validating OAuth2 session: %s...",
450
+ session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
152
451
  )
452
+ session = await self.session_store.get_session(session_id)
153
453
 
154
- def cleanup_expired_tokens(self) -> int:
454
+ if not session:
455
+ logger.debug("OAuth2 session validation failed: session not found or expired")
456
+ return AuthenticationResponse(success=False, error_message="Invalid or expired session")
457
+
458
+ logger.debug("OAuth2 session validated successfully for user: %s", session.user.username)
459
+ return AuthenticationResponse(success=True, user=session.user)
460
+
461
+ async def revoke_session(self, session_id: str) -> bool:
155
462
  """
156
- Remove expired tokens from the blacklist to prevent memory bloat.
463
+ Revoke a session.
464
+
465
+ Args:
466
+ session_id: The session ID to revoke
157
467
 
158
468
  Returns:
159
- Number of tokens removed
469
+ True if session was revoked
160
470
  """
161
- tokens_to_remove = []
471
+ logger.debug(
472
+ "Revoking OAuth2 session: %s...",
473
+ session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
474
+ )
475
+ success = await self.session_store.delete_session(session_id)
476
+ if success:
477
+ logger.info("OAuth2 session revoked successfully")
478
+ else:
479
+ logger.debug("OAuth2 session revocation failed: session not found")
480
+ return success
162
481
 
163
- for token in self.revoked_tokens:
164
- try:
165
- # Try to decode the token - if it raises ExpiredSignatureError, it's expired
166
- jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
167
- except ExpiredSignatureError:
168
- # Token is expired, safe to remove from blacklist
169
- tokens_to_remove.append(token)
170
- except JWTError:
171
- # Token is invalid, remove it too
172
- tokens_to_remove.append(token)
173
482
 
174
- for token in tokens_to_remove:
175
- self.revoked_tokens.remove(token)
483
+ class MultiAuthenticationBackend(AuthenticationBackend):
484
+ """
485
+ Authentication backend that supports multiple authentication methods.
176
486
 
177
- return len(tokens_to_remove)
487
+ This backend allows combining credentials-based and OAuth2 authentication,
488
+ enabling users to choose their preferred login method.
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ backends: list[AuthenticationBackend],
494
+ default_options: AuthOptions | None = None,
495
+ ):
496
+ """
497
+ Initialize multi-authentication backend.
498
+
499
+ Args:
500
+ backends: List of authentication backends to support
501
+ default_options: Default options for the component
502
+ """
503
+ if not backends:
504
+ raise ValueError("At least one authentication backend must be provided")
505
+
506
+ if default_options is None:
507
+ default_options = AuthOptions()
508
+ super().__init__(default_options)
509
+
510
+ self.backends = backends
511
+
512
+ def get_oauth2_backends(self) -> list[OAuth2AuthenticationBackend]:
513
+ """Get all OAuth2 backends."""
514
+ return [b for b in self.backends if isinstance(b, OAuth2AuthenticationBackend)]
515
+
516
+ def get_credentials_backends(self) -> list[AuthenticationBackend]:
517
+ """Get all credentials-based backends."""
518
+ return [b for b in self.backends if not isinstance(b, OAuth2AuthenticationBackend)]
519
+
520
+ async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse:
521
+ """
522
+ Try to authenticate with credentials using all credentials-based backends.
523
+
524
+ Args:
525
+ credentials: User credentials
526
+
527
+ Returns:
528
+ AuthenticationResponse from the first successful backend
529
+ """
530
+ logger.debug(
531
+ "Multi-backend credential authentication for user '%s' with %d backends",
532
+ credentials.username,
533
+ len(self.get_credentials_backends()),
534
+ )
535
+ errors = []
536
+
537
+ for backend in self.get_credentials_backends():
538
+ result = await backend.authenticate_with_credentials(credentials)
539
+ if result.success:
540
+ logger.debug("Credential authentication succeeded with backend: %s", type(backend).__name__)
541
+ return result
542
+ if result.error_message:
543
+ errors.append(result.error_message)
544
+
545
+ # All backends failed
546
+ error_msg = "; ".join(errors) if errors else "Authentication failed"
547
+ logger.warning("All credential backends failed for user '%s': %s", credentials.username, error_msg)
548
+ return AuthenticationResponse(success=False, error_message=error_msg)
549
+
550
+ async def authenticate_with_oauth2(self, oauth_credentials: OAuth2Credentials) -> AuthenticationResponse:
551
+ """
552
+ Try to authenticate with OAuth2 using all OAuth2 backends.
553
+
554
+ Args:
555
+ oauth_credentials: OAuth2 credentials
556
+
557
+ Returns:
558
+ AuthenticationResponse from the first successful backend
559
+ """
560
+ logger.debug("Multi-backend OAuth2 authentication with %d backends", len(self.get_oauth2_backends()))
561
+ errors = []
562
+
563
+ for backend in self.get_oauth2_backends():
564
+ result = await backend.authenticate_with_oauth2(oauth_credentials)
565
+ if result.success:
566
+ logger.debug("OAuth2 authentication succeeded with backend: %s", type(backend).__name__)
567
+ return result
568
+ if result.error_message:
569
+ errors.append(result.error_message)
570
+
571
+ # All backends failed
572
+ error_msg = "; ".join(errors) if errors else "OAuth2 authentication failed"
573
+ logger.warning("All OAuth2 backends failed: %s", error_msg)
574
+ return AuthenticationResponse(success=False, error_message=error_msg)
575
+
576
+ async def validate_session(self, session_id: str) -> AuthenticationResponse:
577
+ """
578
+ Validate a session.
579
+
580
+ Args:
581
+ session_id: The session ID to validate
582
+
583
+ Returns:
584
+ AuthenticationResponse with user if valid
585
+ """
586
+ # Try to get session from backends that have session stores
587
+ for backend in self.backends:
588
+ if hasattr(backend, "session_store") and backend.session_store:
589
+ result = await backend.validate_session(session_id)
590
+ if result.success:
591
+ return result
592
+
593
+ return AuthenticationResponse(success=False, error_message="Invalid or expired session")
594
+
595
+ async def revoke_session(self, session_id: str) -> bool:
596
+ """
597
+ Revoke a session.
598
+
599
+ Args:
600
+ session_id: The session ID to revoke
601
+
602
+ Returns:
603
+ True if any backend successfully revoked the session
604
+ """
605
+ for backend in self.backends:
606
+ if hasattr(backend, "session_store") and backend.session_store:
607
+ try:
608
+ success = await backend.revoke_session(session_id)
609
+ if success:
610
+ return True
611
+ except Exception: # noqa: S112
612
+ # Silently continue to next backend if one fails
613
+ continue
614
+
615
+ return False