ragbits-chat 1.4.0.dev202512160238__py3-none-any.whl → 1.4.0.dev202601010248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/chat/api.py +364 -79
- ragbits/chat/auth/__init__.py +3 -1
- ragbits/chat/auth/backends.py +519 -81
- ragbits/chat/auth/base.py +8 -10
- ragbits/chat/auth/oauth2_providers.py +108 -0
- ragbits/chat/auth/provider_config.py +81 -0
- ragbits/chat/auth/session_store.py +178 -0
- ragbits/chat/auth/types.py +66 -29
- ragbits/chat/interface/types.py +15 -0
- ragbits/chat/providers/model_provider.py +10 -8
- ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
- ragbits/chat/ui-build/assets/{ChatHistory--EyHdeYk.js → ChatHistory-B2hLBYMJ.js} +2 -2
- ragbits/chat/ui-build/assets/{ChatOptionsForm-BQJ-bYMu.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
- ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
- ragbits/chat/ui-build/assets/{FeedbackForm-BYee4uF-.js → FeedbackForm-oSbly5oN.js} +1 -1
- ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
- ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
- ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
- ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
- ragbits/chat/ui-build/assets/{UsageButton-C0lzhbc6.js → UsageButton-BABA7a-w.js} +1 -1
- ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
- ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-ZqFHUQCB.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
- ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-Dj8WqXIN.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
- ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
- ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
- ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
- ragbits/chat/ui-build/assets/{index-Bpba6d6u.js → index-Ceq7Rkzy.js} +1 -1
- ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
- ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
- ragbits/chat/ui-build/assets/{useMenuTriggerState-DaMXoDzf.js → useMenuTriggerState-SaFmATkk.js} +1 -1
- ragbits/chat/ui-build/assets/{useSelectableItem-D8MlMsUd.js → useSelectableItem-DhuFnc0W.js} +1 -1
- ragbits/chat/ui-build/index.html +2 -2
- {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/METADATA +2 -2
- ragbits_chat-1.4.0.dev202601010248.dist-info/RECORD +58 -0
- ragbits/chat/ui-build/assets/AuthGuard-Da8Duw5e.js +0 -1
- ragbits/chat/ui-build/assets/Login-DkcluI7q.js +0 -1
- ragbits/chat/ui-build/assets/LogoutButton-DSdCqsdC.js +0 -1
- ragbits/chat/ui-build/assets/ShareButton-YVgRY6bN.js +0 -1
- ragbits/chat/ui-build/assets/authStore-BFeUV-Bg.js +0 -1
- ragbits/chat/ui-build/assets/index-8hpVK8cj.js +0 -131
- ragbits/chat/ui-build/assets/index-BUbs7vFP.js +0 -1
- ragbits/chat/ui-build/assets/index-BZGp6GjF.js +0 -32
- ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
- ragbits_chat-1.4.0.dev202512160238.dist-info/RECORD +0 -52
- {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/WHEEL +0 -0
ragbits/chat/auth/backends.py
CHANGED
|
@@ -1,14 +1,21 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import secrets
|
|
1
3
|
import uuid
|
|
2
4
|
from datetime import datetime, timedelta, timezone
|
|
3
5
|
from typing import Any, cast
|
|
6
|
+
from urllib.parse import urlencode
|
|
4
7
|
|
|
5
8
|
import bcrypt
|
|
6
|
-
|
|
7
|
-
from jose.exceptions import ExpiredSignatureError, JWTError
|
|
9
|
+
import httpx
|
|
8
10
|
|
|
9
11
|
from ragbits.chat.auth.base import AuthenticationBackend, AuthenticationResponse, AuthOptions
|
|
10
|
-
from ragbits.chat.auth.
|
|
11
|
-
from ragbits.
|
|
12
|
+
from ragbits.chat.auth.oauth2_providers import OAuth2Provider
|
|
13
|
+
from ragbits.chat.auth.types import OAuth2Credentials, Session, SessionStore, User, UserCredentials
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
# Minimum length for session ID truncation in logs (for readability while maintaining some privacy)
|
|
18
|
+
SESSION_ID_LOG_LENGTH = 8
|
|
12
19
|
|
|
13
20
|
|
|
14
21
|
class ListAuthenticationBackend(AuthenticationBackend):
|
|
@@ -17,6 +24,8 @@ class ListAuthenticationBackend(AuthenticationBackend):
|
|
|
17
24
|
def __init__(
|
|
18
25
|
self,
|
|
19
26
|
users: list[dict[str, Any]],
|
|
27
|
+
session_store: SessionStore,
|
|
28
|
+
session_expiry_hours: int = 24,
|
|
20
29
|
default_options: AuthOptions | None = None,
|
|
21
30
|
):
|
|
22
31
|
"""
|
|
@@ -24,17 +33,16 @@ class ListAuthenticationBackend(AuthenticationBackend):
|
|
|
24
33
|
|
|
25
34
|
Args:
|
|
26
35
|
users: List of user dicts with 'username', 'password', and optional fields
|
|
27
|
-
|
|
36
|
+
session_store: Session storage backend
|
|
37
|
+
session_expiry_hours: Hours until session expires (default: 24)
|
|
28
38
|
default_options: Default options for the component
|
|
29
39
|
"""
|
|
30
40
|
if default_options is None:
|
|
31
41
|
default_options = AuthOptions()
|
|
32
42
|
super().__init__(default_options)
|
|
33
43
|
self.users = {}
|
|
34
|
-
self.
|
|
35
|
-
self.
|
|
36
|
-
self.token_expiry_minutes = default_options.token_expiry_minutes
|
|
37
|
-
self.revoked_tokens: set[str] = set() # Blacklist for revoked tokens
|
|
44
|
+
self.session_store = session_store
|
|
45
|
+
self.session_expiry_hours = session_expiry_hours
|
|
38
46
|
|
|
39
47
|
for user_data in users:
|
|
40
48
|
# Hash passwords with bcrypt for security
|
|
@@ -51,7 +59,7 @@ class ListAuthenticationBackend(AuthenticationBackend):
|
|
|
51
59
|
),
|
|
52
60
|
}
|
|
53
61
|
|
|
54
|
-
async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse:
|
|
62
|
+
async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse:
|
|
55
63
|
"""
|
|
56
64
|
Authenticate into backend using provided credentials
|
|
57
65
|
|
|
@@ -60,118 +68,548 @@ class ListAuthenticationBackend(AuthenticationBackend):
|
|
|
60
68
|
Returns:
|
|
61
69
|
AuthenticationResponse: Result of authentication
|
|
62
70
|
"""
|
|
71
|
+
logger.debug("Attempting credential authentication for user: %s", credentials.username)
|
|
72
|
+
|
|
63
73
|
user_data = self.users.get(credentials.username)
|
|
64
74
|
if not user_data:
|
|
75
|
+
logger.warning("Authentication failed: user '%s' not found", credentials.username)
|
|
65
76
|
return AuthenticationResponse(success=False, error_message="User not found")
|
|
66
77
|
|
|
67
78
|
# Verify password with bcrypt
|
|
68
79
|
password_hash = str(user_data["password_hash"])
|
|
69
80
|
if not bcrypt.checkpw(credentials.password.encode("utf-8"), password_hash.encode("utf-8")):
|
|
81
|
+
logger.warning("Authentication failed: invalid password for user '%s'", credentials.username)
|
|
70
82
|
return AuthenticationResponse(success=False, error_message="Invalid password")
|
|
71
83
|
|
|
72
84
|
user = cast(User, user_data["user"])
|
|
73
85
|
|
|
74
|
-
# Create
|
|
75
|
-
|
|
86
|
+
# Create session
|
|
87
|
+
now = datetime.now(timezone.utc)
|
|
88
|
+
session = Session(
|
|
89
|
+
session_id="", # Will be generated by session store
|
|
90
|
+
user=user,
|
|
91
|
+
provider="credentials",
|
|
92
|
+
oauth_token="", # Not applicable for credentials
|
|
93
|
+
token_type="",
|
|
94
|
+
created_at=now,
|
|
95
|
+
expires_at=now + timedelta(hours=self.session_expiry_hours),
|
|
96
|
+
)
|
|
97
|
+
session_id = await self.session_store.create_session(session)
|
|
98
|
+
logger.info("User '%s' authenticated successfully, session created", credentials.username)
|
|
99
|
+
return AuthenticationResponse(success=True, user=user, session_id=session_id)
|
|
100
|
+
|
|
101
|
+
async def validate_session(self, session_id: str) -> AuthenticationResponse:
|
|
102
|
+
"""
|
|
103
|
+
Validate a session.
|
|
76
104
|
|
|
77
|
-
|
|
105
|
+
Args:
|
|
106
|
+
session_id: The session ID to validate
|
|
78
107
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
"
|
|
90
|
-
"
|
|
91
|
-
|
|
92
|
-
|
|
108
|
+
Returns:
|
|
109
|
+
AuthenticationResponse with user if valid
|
|
110
|
+
"""
|
|
111
|
+
logger.debug(
|
|
112
|
+
"Validating session: %s...",
|
|
113
|
+
session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
|
|
114
|
+
)
|
|
115
|
+
session = await self.session_store.get_session(session_id)
|
|
116
|
+
|
|
117
|
+
if not session:
|
|
118
|
+
logger.debug("Session validation failed: session not found or expired")
|
|
119
|
+
return AuthenticationResponse(success=False, error_message="Invalid or expired session")
|
|
120
|
+
|
|
121
|
+
logger.debug("Session validated successfully for user: %s", session.user.username)
|
|
122
|
+
return AuthenticationResponse(success=True, user=session.user)
|
|
123
|
+
|
|
124
|
+
async def authenticate_with_oauth2( # noqa: PLR6301
|
|
125
|
+
self, oauth_credentials: OAuth2Credentials
|
|
126
|
+
) -> AuthenticationResponse:
|
|
127
|
+
"""
|
|
128
|
+
Authenticate user with OAuth2 credentials.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
oauth_credentials: OAuth2 credentials
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
AuthenticationResponse: Authentication failure as OAuth2 is not supported
|
|
135
|
+
"""
|
|
136
|
+
return AuthenticationResponse(success=False, error_message="OAuth2 not supported by ListAuthentication")
|
|
137
|
+
|
|
138
|
+
async def revoke_session(self, session_id: str) -> bool:
|
|
139
|
+
"""
|
|
140
|
+
Revoke a session.
|
|
141
|
+
|
|
142
|
+
Args:
|
|
143
|
+
session_id: The session ID to revoke
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
True if session was revoked
|
|
147
|
+
"""
|
|
148
|
+
logger.debug(
|
|
149
|
+
"Revoking session: %s...",
|
|
150
|
+
session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
|
|
151
|
+
)
|
|
152
|
+
success = await self.session_store.delete_session(session_id)
|
|
153
|
+
if success:
|
|
154
|
+
logger.info("Session revoked successfully")
|
|
155
|
+
else:
|
|
156
|
+
logger.debug("Session revocation failed: session not found")
|
|
157
|
+
return success
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class OAuth2AuthenticationBackend(AuthenticationBackend):
|
|
161
|
+
"""Generic OAuth2 authentication backend supporting multiple providers."""
|
|
162
|
+
|
|
163
|
+
def __init__(
|
|
164
|
+
self,
|
|
165
|
+
session_store: SessionStore,
|
|
166
|
+
provider: OAuth2Provider,
|
|
167
|
+
client_id: str | None = None,
|
|
168
|
+
client_secret: str | None = None,
|
|
169
|
+
redirect_uri: str | None = None,
|
|
170
|
+
session_expiry_hours: int = 24,
|
|
171
|
+
default_options: AuthOptions | None = None,
|
|
172
|
+
):
|
|
173
|
+
"""
|
|
174
|
+
Initialize OAuth2 authentication backend.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
session_store: Session storage backend
|
|
178
|
+
provider: OAuth2 provider implementation (e.g., DiscordOAuth2Provider, GoogleOAuth2Provider)
|
|
179
|
+
client_id: OAuth2 client ID (or set {PROVIDER}_CLIENT_ID env var)
|
|
180
|
+
client_secret: OAuth2 client secret (or set {PROVIDER}_CLIENT_SECRET env var)
|
|
181
|
+
redirect_uri: Callback URL for OAuth2 flow (or set OAUTH2_REDIRECT_URI env var,
|
|
182
|
+
defaults to 'http://localhost:8000/api/auth/callback/{provider_name}')
|
|
183
|
+
session_expiry_hours: Hours until session expires (default: 24)
|
|
184
|
+
default_options: Default options for the component
|
|
185
|
+
|
|
186
|
+
Note:
|
|
187
|
+
The default redirect_uri uses a provider-specific path for better isolation and debugging.
|
|
188
|
+
For Discord, it defaults to: http://localhost:8000/api/auth/callback/discord
|
|
189
|
+
"""
|
|
190
|
+
import os
|
|
191
|
+
|
|
192
|
+
if default_options is None:
|
|
193
|
+
default_options = AuthOptions()
|
|
194
|
+
super().__init__(default_options)
|
|
195
|
+
|
|
196
|
+
self.session_store = session_store
|
|
197
|
+
self.session_expiry_hours = session_expiry_hours
|
|
198
|
+
self.provider = provider
|
|
199
|
+
|
|
200
|
+
# Get credentials from args or environment variables
|
|
201
|
+
self.client_id = client_id or os.getenv(f"{self.provider.name.upper()}_CLIENT_ID")
|
|
202
|
+
self.client_secret = client_secret or os.getenv(f"{self.provider.name.upper()}_CLIENT_SECRET")
|
|
203
|
+
|
|
204
|
+
# Use provider-specific callback URL for better isolation and debugging
|
|
205
|
+
if not redirect_uri:
|
|
206
|
+
# Get base URL from environment variable or use default
|
|
207
|
+
base_url = os.getenv("OAUTH2_CALLBACK_BASE_URL", "http://localhost:8000")
|
|
208
|
+
# Remove trailing slash from base URL
|
|
209
|
+
base_url = base_url.rstrip("/")
|
|
210
|
+
# Construct redirect URI with base URL and provider name
|
|
211
|
+
redirect_uri = f"{base_url}/api/auth/callback/{self.provider.name}"
|
|
212
|
+
|
|
213
|
+
# remove trailing slash from redirect URI
|
|
214
|
+
redirect_uri = redirect_uri.rstrip("/")
|
|
215
|
+
# set redirect URI on the backend
|
|
216
|
+
self.redirect_uri = redirect_uri
|
|
217
|
+
|
|
218
|
+
if not self.client_id or not self.client_secret:
|
|
219
|
+
raise ValueError(
|
|
220
|
+
f"OAuth2 credentials not provided. Either pass client_id and client_secret to the constructor, "
|
|
221
|
+
f"or set {self.provider.name.upper()}_CLIENT_ID and {self.provider.name.upper()}_CLIENT_SECRET "
|
|
222
|
+
f"environment variables."
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
# State storage for CSRF protection (in production, use Redis or similar)
|
|
226
|
+
self.pending_states: dict[str, datetime] = {}
|
|
227
|
+
|
|
228
|
+
def generate_authorize_url(self) -> tuple[str, str]:
|
|
229
|
+
"""
|
|
230
|
+
Generate OAuth2 authorization URL with state parameter.
|
|
231
|
+
|
|
232
|
+
Returns:
|
|
233
|
+
Tuple of (authorize_url, state)
|
|
234
|
+
"""
|
|
235
|
+
state = secrets.token_urlsafe(32)
|
|
236
|
+
self.pending_states[state] = datetime.now(timezone.utc)
|
|
237
|
+
|
|
238
|
+
params = {
|
|
239
|
+
"client_id": self.client_id,
|
|
240
|
+
"redirect_uri": self.redirect_uri,
|
|
241
|
+
"response_type": "code",
|
|
242
|
+
"scope": self.provider.scope,
|
|
243
|
+
"state": state,
|
|
93
244
|
}
|
|
94
245
|
|
|
95
|
-
|
|
246
|
+
authorize_url = f"{self.provider.authorize_url}?{urlencode(params)}"
|
|
247
|
+
logger.debug("Generated OAuth2 authorization URL for provider '%s'", self.provider.name)
|
|
248
|
+
return authorize_url, state
|
|
249
|
+
|
|
250
|
+
def verify_state(self, state: str) -> bool:
|
|
251
|
+
"""
|
|
252
|
+
Verify OAuth2 state parameter for CSRF protection.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
state: State parameter to verify
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
True if state is valid, False otherwise
|
|
259
|
+
"""
|
|
260
|
+
if state not in self.pending_states:
|
|
261
|
+
logger.warning("OAuth2 state verification failed: state not found (possible CSRF attempt)")
|
|
262
|
+
return False
|
|
263
|
+
|
|
264
|
+
# Check if state is not expired (valid for 10 minutes)
|
|
265
|
+
created_at = self.pending_states[state]
|
|
266
|
+
if datetime.now(timezone.utc) - created_at > timedelta(minutes=10):
|
|
267
|
+
del self.pending_states[state]
|
|
268
|
+
logger.warning("OAuth2 state verification failed: state expired")
|
|
269
|
+
return False
|
|
270
|
+
|
|
271
|
+
# Remove state after verification (one-time use)
|
|
272
|
+
del self.pending_states[state]
|
|
273
|
+
logger.debug("OAuth2 state verified successfully")
|
|
274
|
+
return True
|
|
275
|
+
|
|
276
|
+
def cleanup_expired_states(self) -> int:
|
|
277
|
+
"""
|
|
278
|
+
Remove expired state tokens from storage.
|
|
279
|
+
|
|
280
|
+
This method removes state tokens that are older than 10 minutes to prevent
|
|
281
|
+
memory leaks from abandoned OAuth2 flows.
|
|
282
|
+
|
|
283
|
+
Returns:
|
|
284
|
+
Number of state tokens removed
|
|
285
|
+
|
|
286
|
+
Example:
|
|
287
|
+
To schedule periodic cleanup:
|
|
288
|
+
|
|
289
|
+
```python
|
|
290
|
+
import asyncio
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
async def cleanup_loop():
|
|
294
|
+
while True:
|
|
295
|
+
await asyncio.sleep(600) # Run every 10 minutes
|
|
296
|
+
removed = oauth2_backend.cleanup_expired_states()
|
|
297
|
+
if removed > 0:
|
|
298
|
+
logger.info(f"Cleaned up {removed} expired OAuth2 states")
|
|
299
|
+
```
|
|
300
|
+
"""
|
|
301
|
+
now = datetime.now(timezone.utc)
|
|
302
|
+
states_to_remove = []
|
|
303
|
+
|
|
304
|
+
# Find expired states (older than 10 minutes)
|
|
305
|
+
for state, created_at in list(self.pending_states.items()):
|
|
306
|
+
if now - created_at > timedelta(minutes=10):
|
|
307
|
+
states_to_remove.append(state)
|
|
96
308
|
|
|
97
|
-
|
|
98
|
-
|
|
309
|
+
# Remove expired states
|
|
310
|
+
for state in states_to_remove:
|
|
311
|
+
self.pending_states.pop(state, None)
|
|
99
312
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
# Check if token is blacklisted (revoked)
|
|
103
|
-
if token in self.revoked_tokens:
|
|
104
|
-
return AuthenticationResponse(success=False, error_message="Token has been revoked")
|
|
313
|
+
if states_to_remove:
|
|
314
|
+
logger.info("Cleaned up %d expired OAuth2 state tokens", len(states_to_remove))
|
|
105
315
|
|
|
316
|
+
return len(states_to_remove)
|
|
317
|
+
|
|
318
|
+
async def exchange_code_for_token(self, code: str) -> str | None:
|
|
319
|
+
"""
|
|
320
|
+
Exchange authorization code for access token.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
code: Authorization code from OAuth2 provider
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
Access token if successful, None otherwise
|
|
327
|
+
"""
|
|
328
|
+
logger.debug("Exchanging authorization code for token with provider '%s'", self.provider.name)
|
|
329
|
+
try:
|
|
330
|
+
async with httpx.AsyncClient() as client:
|
|
331
|
+
response = await client.post(
|
|
332
|
+
self.provider.token_url,
|
|
333
|
+
data={
|
|
334
|
+
"client_id": self.client_id,
|
|
335
|
+
"client_secret": self.client_secret,
|
|
336
|
+
"grant_type": "authorization_code",
|
|
337
|
+
"code": code,
|
|
338
|
+
"redirect_uri": self.redirect_uri,
|
|
339
|
+
},
|
|
340
|
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
if response.status_code != 200: # noqa: PLR2004
|
|
344
|
+
logger.error(
|
|
345
|
+
"Token exchange failed with provider '%s': HTTP %d",
|
|
346
|
+
self.provider.name,
|
|
347
|
+
response.status_code,
|
|
348
|
+
)
|
|
349
|
+
return None
|
|
350
|
+
|
|
351
|
+
token_data = response.json()
|
|
352
|
+
logger.debug("Token exchange successful with provider '%s'", self.provider.name)
|
|
353
|
+
return token_data.get("access_token")
|
|
354
|
+
|
|
355
|
+
except Exception as e:
|
|
356
|
+
logger.exception("Token exchange error with provider '%s': %s", self.provider.name, str(e))
|
|
357
|
+
return None
|
|
358
|
+
|
|
359
|
+
async def authenticate_with_oauth2(self, oauth_credentials: OAuth2Credentials) -> AuthenticationResponse:
|
|
360
|
+
"""
|
|
361
|
+
Authenticate user with OAuth2 access token.
|
|
362
|
+
|
|
363
|
+
Args:
|
|
364
|
+
oauth_credentials: OAuth2 credentials with access token
|
|
365
|
+
|
|
366
|
+
Returns:
|
|
367
|
+
AuthenticationResponse with user and session ID
|
|
368
|
+
"""
|
|
369
|
+
logger.debug("Authenticating user with OAuth2 provider '%s'", self.provider.name)
|
|
106
370
|
try:
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
371
|
+
# Fetch user info from provider
|
|
372
|
+
async with httpx.AsyncClient() as client:
|
|
373
|
+
response = await client.get(
|
|
374
|
+
self.provider.user_info_url,
|
|
375
|
+
headers={"Authorization": f"{oauth_credentials.token_type} {oauth_credentials.access_token}"},
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
if response.status_code != 200: # noqa: PLR2004
|
|
379
|
+
logger.error(
|
|
380
|
+
"Failed to fetch user info from '%s': HTTP %d",
|
|
381
|
+
self.provider.name,
|
|
382
|
+
response.status_code,
|
|
383
|
+
)
|
|
384
|
+
return AuthenticationResponse(
|
|
385
|
+
success=False,
|
|
386
|
+
error_message=f"Failed to fetch user info from {self.provider.name}: {response.status_code}",
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
user_data = response.json()
|
|
390
|
+
|
|
391
|
+
# Create User object from provider data
|
|
392
|
+
user = self.provider.create_user_from_data(user_data)
|
|
393
|
+
logger.debug("User info fetched successfully from '%s' for user: %s", self.provider.name, user.username)
|
|
394
|
+
|
|
395
|
+
# Create session
|
|
396
|
+
now = datetime.now(timezone.utc)
|
|
397
|
+
session = Session(
|
|
398
|
+
session_id="", # Will be generated by session store
|
|
399
|
+
user=user,
|
|
400
|
+
provider=self.provider.name,
|
|
401
|
+
oauth_token=oauth_credentials.access_token,
|
|
402
|
+
token_type=oauth_credentials.token_type,
|
|
403
|
+
created_at=now,
|
|
404
|
+
expires_at=now + timedelta(hours=self.session_expiry_hours),
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
session_id = await self.session_store.create_session(session)
|
|
408
|
+
logger.info(
|
|
409
|
+
"User '%s' authenticated successfully via OAuth2 provider '%s'",
|
|
410
|
+
user.username,
|
|
411
|
+
self.provider.name,
|
|
117
412
|
)
|
|
118
413
|
|
|
119
|
-
return AuthenticationResponse(success=True, user=user)
|
|
414
|
+
return AuthenticationResponse(success=True, user=user, session_id=session_id)
|
|
120
415
|
|
|
121
|
-
except
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
416
|
+
except Exception as e:
|
|
417
|
+
logger.exception("OAuth2 authentication failed with provider '%s': %s", self.provider.name, str(e))
|
|
418
|
+
return AuthenticationResponse(
|
|
419
|
+
success=False,
|
|
420
|
+
error_message=f"OAuth2 authentication failed: {str(e)}",
|
|
421
|
+
)
|
|
125
422
|
|
|
126
|
-
async def
|
|
423
|
+
async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse: # noqa: PLR6301
|
|
127
424
|
"""
|
|
128
|
-
|
|
425
|
+
OAuth2 backend does not support credential authentication.
|
|
129
426
|
|
|
130
427
|
Args:
|
|
131
|
-
|
|
428
|
+
credentials: User credentials
|
|
132
429
|
|
|
133
430
|
Returns:
|
|
134
|
-
AuthenticationResponse
|
|
431
|
+
AuthenticationResponse with error
|
|
135
432
|
"""
|
|
136
|
-
return AuthenticationResponse(
|
|
433
|
+
return AuthenticationResponse(
|
|
434
|
+
success=False,
|
|
435
|
+
error_message="Credential authentication not supported by OAuth2 backend",
|
|
436
|
+
)
|
|
137
437
|
|
|
138
|
-
async def
|
|
438
|
+
async def validate_session(self, session_id: str) -> AuthenticationResponse:
|
|
139
439
|
"""
|
|
140
|
-
|
|
440
|
+
Validate a session.
|
|
141
441
|
|
|
142
442
|
Args:
|
|
143
|
-
|
|
443
|
+
session_id: The session ID to validate
|
|
144
444
|
|
|
145
|
-
|
|
146
|
-
|
|
445
|
+
Returns:
|
|
446
|
+
AuthenticationResponse with user if valid
|
|
147
447
|
"""
|
|
148
|
-
|
|
149
|
-
"
|
|
150
|
-
|
|
151
|
-
"if you need to revoke tokens please consider using different backend or implementing your own."
|
|
448
|
+
logger.debug(
|
|
449
|
+
"Validating OAuth2 session: %s...",
|
|
450
|
+
session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
|
|
152
451
|
)
|
|
452
|
+
session = await self.session_store.get_session(session_id)
|
|
153
453
|
|
|
154
|
-
|
|
454
|
+
if not session:
|
|
455
|
+
logger.debug("OAuth2 session validation failed: session not found or expired")
|
|
456
|
+
return AuthenticationResponse(success=False, error_message="Invalid or expired session")
|
|
457
|
+
|
|
458
|
+
logger.debug("OAuth2 session validated successfully for user: %s", session.user.username)
|
|
459
|
+
return AuthenticationResponse(success=True, user=session.user)
|
|
460
|
+
|
|
461
|
+
async def revoke_session(self, session_id: str) -> bool:
|
|
155
462
|
"""
|
|
156
|
-
|
|
463
|
+
Revoke a session.
|
|
464
|
+
|
|
465
|
+
Args:
|
|
466
|
+
session_id: The session ID to revoke
|
|
157
467
|
|
|
158
468
|
Returns:
|
|
159
|
-
|
|
469
|
+
True if session was revoked
|
|
160
470
|
"""
|
|
161
|
-
|
|
471
|
+
logger.debug(
|
|
472
|
+
"Revoking OAuth2 session: %s...",
|
|
473
|
+
session_id[:SESSION_ID_LOG_LENGTH] if len(session_id) >= SESSION_ID_LOG_LENGTH else session_id,
|
|
474
|
+
)
|
|
475
|
+
success = await self.session_store.delete_session(session_id)
|
|
476
|
+
if success:
|
|
477
|
+
logger.info("OAuth2 session revoked successfully")
|
|
478
|
+
else:
|
|
479
|
+
logger.debug("OAuth2 session revocation failed: session not found")
|
|
480
|
+
return success
|
|
162
481
|
|
|
163
|
-
for token in self.revoked_tokens:
|
|
164
|
-
try:
|
|
165
|
-
# Try to decode the token - if it raises ExpiredSignatureError, it's expired
|
|
166
|
-
jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
|
167
|
-
except ExpiredSignatureError:
|
|
168
|
-
# Token is expired, safe to remove from blacklist
|
|
169
|
-
tokens_to_remove.append(token)
|
|
170
|
-
except JWTError:
|
|
171
|
-
# Token is invalid, remove it too
|
|
172
|
-
tokens_to_remove.append(token)
|
|
173
482
|
|
|
174
|
-
|
|
175
|
-
|
|
483
|
+
class MultiAuthenticationBackend(AuthenticationBackend):
|
|
484
|
+
"""
|
|
485
|
+
Authentication backend that supports multiple authentication methods.
|
|
176
486
|
|
|
177
|
-
|
|
487
|
+
This backend allows combining credentials-based and OAuth2 authentication,
|
|
488
|
+
enabling users to choose their preferred login method.
|
|
489
|
+
"""
|
|
490
|
+
|
|
491
|
+
def __init__(
|
|
492
|
+
self,
|
|
493
|
+
backends: list[AuthenticationBackend],
|
|
494
|
+
default_options: AuthOptions | None = None,
|
|
495
|
+
):
|
|
496
|
+
"""
|
|
497
|
+
Initialize multi-authentication backend.
|
|
498
|
+
|
|
499
|
+
Args:
|
|
500
|
+
backends: List of authentication backends to support
|
|
501
|
+
default_options: Default options for the component
|
|
502
|
+
"""
|
|
503
|
+
if not backends:
|
|
504
|
+
raise ValueError("At least one authentication backend must be provided")
|
|
505
|
+
|
|
506
|
+
if default_options is None:
|
|
507
|
+
default_options = AuthOptions()
|
|
508
|
+
super().__init__(default_options)
|
|
509
|
+
|
|
510
|
+
self.backends = backends
|
|
511
|
+
|
|
512
|
+
def get_oauth2_backends(self) -> list[OAuth2AuthenticationBackend]:
|
|
513
|
+
"""Get all OAuth2 backends."""
|
|
514
|
+
return [b for b in self.backends if isinstance(b, OAuth2AuthenticationBackend)]
|
|
515
|
+
|
|
516
|
+
def get_credentials_backends(self) -> list[AuthenticationBackend]:
|
|
517
|
+
"""Get all credentials-based backends."""
|
|
518
|
+
return [b for b in self.backends if not isinstance(b, OAuth2AuthenticationBackend)]
|
|
519
|
+
|
|
520
|
+
async def authenticate_with_credentials(self, credentials: UserCredentials) -> AuthenticationResponse:
|
|
521
|
+
"""
|
|
522
|
+
Try to authenticate with credentials using all credentials-based backends.
|
|
523
|
+
|
|
524
|
+
Args:
|
|
525
|
+
credentials: User credentials
|
|
526
|
+
|
|
527
|
+
Returns:
|
|
528
|
+
AuthenticationResponse from the first successful backend
|
|
529
|
+
"""
|
|
530
|
+
logger.debug(
|
|
531
|
+
"Multi-backend credential authentication for user '%s' with %d backends",
|
|
532
|
+
credentials.username,
|
|
533
|
+
len(self.get_credentials_backends()),
|
|
534
|
+
)
|
|
535
|
+
errors = []
|
|
536
|
+
|
|
537
|
+
for backend in self.get_credentials_backends():
|
|
538
|
+
result = await backend.authenticate_with_credentials(credentials)
|
|
539
|
+
if result.success:
|
|
540
|
+
logger.debug("Credential authentication succeeded with backend: %s", type(backend).__name__)
|
|
541
|
+
return result
|
|
542
|
+
if result.error_message:
|
|
543
|
+
errors.append(result.error_message)
|
|
544
|
+
|
|
545
|
+
# All backends failed
|
|
546
|
+
error_msg = "; ".join(errors) if errors else "Authentication failed"
|
|
547
|
+
logger.warning("All credential backends failed for user '%s': %s", credentials.username, error_msg)
|
|
548
|
+
return AuthenticationResponse(success=False, error_message=error_msg)
|
|
549
|
+
|
|
550
|
+
async def authenticate_with_oauth2(self, oauth_credentials: OAuth2Credentials) -> AuthenticationResponse:
|
|
551
|
+
"""
|
|
552
|
+
Try to authenticate with OAuth2 using all OAuth2 backends.
|
|
553
|
+
|
|
554
|
+
Args:
|
|
555
|
+
oauth_credentials: OAuth2 credentials
|
|
556
|
+
|
|
557
|
+
Returns:
|
|
558
|
+
AuthenticationResponse from the first successful backend
|
|
559
|
+
"""
|
|
560
|
+
logger.debug("Multi-backend OAuth2 authentication with %d backends", len(self.get_oauth2_backends()))
|
|
561
|
+
errors = []
|
|
562
|
+
|
|
563
|
+
for backend in self.get_oauth2_backends():
|
|
564
|
+
result = await backend.authenticate_with_oauth2(oauth_credentials)
|
|
565
|
+
if result.success:
|
|
566
|
+
logger.debug("OAuth2 authentication succeeded with backend: %s", type(backend).__name__)
|
|
567
|
+
return result
|
|
568
|
+
if result.error_message:
|
|
569
|
+
errors.append(result.error_message)
|
|
570
|
+
|
|
571
|
+
# All backends failed
|
|
572
|
+
error_msg = "; ".join(errors) if errors else "OAuth2 authentication failed"
|
|
573
|
+
logger.warning("All OAuth2 backends failed: %s", error_msg)
|
|
574
|
+
return AuthenticationResponse(success=False, error_message=error_msg)
|
|
575
|
+
|
|
576
|
+
async def validate_session(self, session_id: str) -> AuthenticationResponse:
|
|
577
|
+
"""
|
|
578
|
+
Validate a session.
|
|
579
|
+
|
|
580
|
+
Args:
|
|
581
|
+
session_id: The session ID to validate
|
|
582
|
+
|
|
583
|
+
Returns:
|
|
584
|
+
AuthenticationResponse with user if valid
|
|
585
|
+
"""
|
|
586
|
+
# Try to get session from backends that have session stores
|
|
587
|
+
for backend in self.backends:
|
|
588
|
+
if hasattr(backend, "session_store") and backend.session_store:
|
|
589
|
+
result = await backend.validate_session(session_id)
|
|
590
|
+
if result.success:
|
|
591
|
+
return result
|
|
592
|
+
|
|
593
|
+
return AuthenticationResponse(success=False, error_message="Invalid or expired session")
|
|
594
|
+
|
|
595
|
+
async def revoke_session(self, session_id: str) -> bool:
|
|
596
|
+
"""
|
|
597
|
+
Revoke a session.
|
|
598
|
+
|
|
599
|
+
Args:
|
|
600
|
+
session_id: The session ID to revoke
|
|
601
|
+
|
|
602
|
+
Returns:
|
|
603
|
+
True if any backend successfully revoked the session
|
|
604
|
+
"""
|
|
605
|
+
for backend in self.backends:
|
|
606
|
+
if hasattr(backend, "session_store") and backend.session_store:
|
|
607
|
+
try:
|
|
608
|
+
success = await backend.revoke_session(session_id)
|
|
609
|
+
if success:
|
|
610
|
+
return True
|
|
611
|
+
except Exception: # noqa: S112
|
|
612
|
+
# Silently continue to next backend if one fails
|
|
613
|
+
continue
|
|
614
|
+
|
|
615
|
+
return False
|