ragbits-chat 1.4.0.dev202512151244__py3-none-any.whl → 1.4.0.dev202601010248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/chat/api.py +364 -79
- ragbits/chat/auth/__init__.py +3 -1
- ragbits/chat/auth/backends.py +519 -81
- ragbits/chat/auth/base.py +8 -10
- ragbits/chat/auth/oauth2_providers.py +108 -0
- ragbits/chat/auth/provider_config.py +81 -0
- ragbits/chat/auth/session_store.py +178 -0
- ragbits/chat/auth/types.py +66 -29
- ragbits/chat/interface/types.py +15 -0
- ragbits/chat/providers/model_provider.py +10 -8
- ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
- ragbits/chat/ui-build/assets/{ChatHistory-Cp_DhrUx.js → ChatHistory-B2hLBYMJ.js} +2 -2
- ragbits/chat/ui-build/assets/{ChatOptionsForm-CNjzbIqN.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
- ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
- ragbits/chat/ui-build/assets/{FeedbackForm-CmRSbYPS.js → FeedbackForm-oSbly5oN.js} +1 -1
- ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
- ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
- ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
- ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
- ragbits/chat/ui-build/assets/{UsageButton-B-N1J-sZ.js → UsageButton-BABA7a-w.js} +1 -1
- ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
- ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-CsJAveMU.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
- ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-BJI2Gxdq.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
- ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
- ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
- ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
- ragbits/chat/ui-build/assets/{index-v15bx9Do.js → index-Ceq7Rkzy.js} +1 -1
- ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
- ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
- ragbits/chat/ui-build/assets/{useMenuTriggerState-CTz3KfPq.js → useMenuTriggerState-SaFmATkk.js} +1 -1
- ragbits/chat/ui-build/assets/{useSelectableItem-DK6eABKK.js → useSelectableItem-DhuFnc0W.js} +1 -1
- ragbits/chat/ui-build/index.html +2 -2
- {ragbits_chat-1.4.0.dev202512151244.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/METADATA +2 -2
- ragbits_chat-1.4.0.dev202601010248.dist-info/RECORD +58 -0
- ragbits/chat/ui-build/assets/AuthGuard-B326tmZN.js +0 -1
- ragbits/chat/ui-build/assets/Login-Djq6QJ18.js +0 -1
- ragbits/chat/ui-build/assets/LogoutButton-Cn2L63Hk.js +0 -1
- ragbits/chat/ui-build/assets/ShareButton-lYj0v67r.js +0 -1
- ragbits/chat/ui-build/assets/authStore-DATNN-ps.js +0 -1
- ragbits/chat/ui-build/assets/index-B3hlerKe.js +0 -131
- ragbits/chat/ui-build/assets/index-B7bSwAmw.js +0 -32
- ragbits/chat/ui-build/assets/index-C_JcEI3R.js +0 -1
- ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
- ragbits_chat-1.4.0.dev202512151244.dist-info/RECORD +0 -52
- {ragbits_chat-1.4.0.dev202512151244.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/WHEEL +0 -0
ragbits/chat/auth/base.py
CHANGED
|
@@ -7,14 +7,13 @@ from pydantic import BaseModel
|
|
|
7
7
|
from ragbits.core.options import Options
|
|
8
8
|
from ragbits.core.utils.config_handling import ConfigurableComponent
|
|
9
9
|
|
|
10
|
-
from .types import
|
|
10
|
+
from .types import OAuth2Credentials, User, UserCredentials
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class AuthOptions(Options):
|
|
14
14
|
"""Options for authentication backends."""
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
token_expiry_minutes: int = 24 * 60
|
|
16
|
+
pass
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
class AuthenticationResponse(BaseModel):
|
|
@@ -22,7 +21,7 @@ class AuthenticationResponse(BaseModel):
|
|
|
22
21
|
|
|
23
22
|
success: bool
|
|
24
23
|
user: User | None = None
|
|
25
|
-
|
|
24
|
+
session_id: str | None = None
|
|
26
25
|
error_message: str | None = None
|
|
27
26
|
|
|
28
27
|
|
|
@@ -60,26 +59,25 @@ class AuthenticationBackend(ConfigurableComponent[AuthOptions], ABC):
|
|
|
60
59
|
pass
|
|
61
60
|
|
|
62
61
|
@abstractmethod
|
|
63
|
-
async def
|
|
62
|
+
async def validate_session(self, session_id: str) -> AuthenticationResponse:
|
|
64
63
|
"""
|
|
65
|
-
Validate a
|
|
64
|
+
Validate a session.
|
|
66
65
|
|
|
67
66
|
Args:
|
|
68
|
-
|
|
67
|
+
session_id: The session ID to validate
|
|
69
68
|
|
|
70
69
|
Returns:
|
|
71
70
|
AuthenticationResult with user if valid
|
|
72
71
|
"""
|
|
73
|
-
# Default implementation for backward compatibility
|
|
74
72
|
pass
|
|
75
73
|
|
|
76
74
|
@abstractmethod
|
|
77
|
-
async def
|
|
75
|
+
async def revoke_session(self, session_id: str) -> bool:
|
|
78
76
|
"""
|
|
79
77
|
Revoke/logout a session.
|
|
80
78
|
|
|
81
79
|
Args:
|
|
82
|
-
|
|
80
|
+
session_id: The session ID to revoke
|
|
83
81
|
|
|
84
82
|
Returns:
|
|
85
83
|
True if successfully revoked
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""OAuth2 provider implementations for various authentication services."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ragbits.chat.auth.types import User
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OAuth2Provider(ABC):
    """
    Interface implemented by every OAuth2 identity-provider integration.

    A concrete provider supplies its endpoint URLs, the scope string to request,
    and a conversion from the provider's raw user payload into a ``User``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Machine identifier of the provider, e.g. ``'discord'``."""
        ...

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Label suitable for showing to end users, e.g. ``'Discord'``."""
        ...

    @property
    @abstractmethod
    def authorize_url(self) -> str:
        """URL of the provider's OAuth2 authorization endpoint."""
        ...

    @property
    @abstractmethod
    def token_url(self) -> str:
        """URL of the provider's OAuth2 token-exchange endpoint."""
        ...

    @property
    @abstractmethod
    def user_info_url(self) -> str:
        """URL of the endpoint that returns the authenticated user's profile."""
        ...

    @property
    @abstractmethod
    def scope(self) -> str:
        """Scope string sent with the authorization request."""
        ...

    @abstractmethod
    def create_user_from_data(self, user_data: dict[str, Any]) -> User:
        """
        Translate the provider's user payload into a ``User``.

        Args:
            user_data: Raw user data as returned by the provider

        Returns:
            The corresponding User object
        """
        ...
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DiscordOAuth2Provider(OAuth2Provider):
    """OAuth2 integration for Discord (``discord.com`` endpoints)."""

    @property
    def name(self) -> str:
        """Machine identifier used for this provider."""
        return "discord"

    @property
    def display_name(self) -> str:
        """Label shown to end users."""
        return "Discord"

    @property
    def authorize_url(self) -> str:
        """Discord's OAuth2 authorization endpoint."""
        return "https://discord.com/api/oauth2/authorize"

    @property
    def token_url(self) -> str:
        """Discord's OAuth2 token-exchange endpoint."""
        return "https://discord.com/api/oauth2/token"

    @property
    def user_info_url(self) -> str:
        """Discord endpoint returning the current user's profile."""
        return "https://discord.com/api/users/@me"

    @property
    def scope(self) -> str:
        """Space-separated scopes requested from Discord."""
        return "identify email"

    def create_user_from_data(self, user_data: dict[str, Any]) -> User:  # noqa: PLR6301
        """
        Build a ``User`` from a Discord profile payload.

        ``user_data["id"]`` is required (a missing key raises ``KeyError``); the
        other fields are optional and default to ``None`` (``""`` for username).
        """
        discord_id = user_data["id"]
        extra = {
            "provider": "discord",
            "avatar": user_data.get("avatar"),
            "discriminator": user_data.get("discriminator"),
        }
        return User(
            user_id=f"discord_{discord_id}",
            username=user_data.get("username", ""),
            email=user_data.get("email"),
            full_name=user_data.get("global_name"),
            roles=["user"],
            metadata=extra,
        )
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""OAuth2 provider visual configuration (icons, colors, etc.)."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OAuth2ProviderVisualConfig:
|
|
7
|
+
"""Visual configuration for OAuth2 providers (brand colors, icons, etc.)."""
|
|
8
|
+
|
|
9
|
+
def __init__(
|
|
10
|
+
self,
|
|
11
|
+
name: str,
|
|
12
|
+
display_name: str,
|
|
13
|
+
color: str,
|
|
14
|
+
button_color: str | None = None,
|
|
15
|
+
text_color: str = "#FFFFFF",
|
|
16
|
+
icon_svg: str | None = None,
|
|
17
|
+
):
|
|
18
|
+
"""
|
|
19
|
+
Initialize provider visual configuration.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
name: Provider identifier (e.g., 'google', 'discord')
|
|
23
|
+
display_name: Human-readable provider name
|
|
24
|
+
color: Brand color for the provider
|
|
25
|
+
button_color: Optional button background color (defaults to color)
|
|
26
|
+
text_color: Button text color (defaults to white)
|
|
27
|
+
icon_svg: Optional SVG icon as string
|
|
28
|
+
"""
|
|
29
|
+
self.name = name
|
|
30
|
+
self.display_name = display_name
|
|
31
|
+
self.color = color
|
|
32
|
+
self.button_color = button_color or color
|
|
33
|
+
self.text_color = text_color
|
|
34
|
+
self.icon_svg = icon_svg
|
|
35
|
+
|
|
36
|
+
def to_dict(self) -> dict[str, Any]:
|
|
37
|
+
"""Convert to dictionary for API serialization."""
|
|
38
|
+
return {
|
|
39
|
+
"name": self.name,
|
|
40
|
+
"display_name": self.display_name,
|
|
41
|
+
"color": self.color,
|
|
42
|
+
"button_color": self.button_color,
|
|
43
|
+
"text_color": self.text_color,
|
|
44
|
+
"icon_svg": self.icon_svg,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Provider visual configurations
|
|
49
|
+
PROVIDER_CONFIGS: dict[str, OAuth2ProviderVisualConfig] = {
|
|
50
|
+
"discord": OAuth2ProviderVisualConfig(
|
|
51
|
+
name="discord",
|
|
52
|
+
display_name="Discord",
|
|
53
|
+
color="#5865F2",
|
|
54
|
+
icon_svg='<svg width="20" height="20" viewBox="0 0 71 55" fill="none" xmlns="http://www.w3.org/2000/svg"><g clipPath="url(#clip0)"><path d="M60.1045 4.8978C55.5792 2.8214 50.7265 1.2916 45.6527 0.41542C45.5603 0.39851 45.468 0.440769 45.4204 0.525289C44.7963 1.6353 44.105 3.0834 43.6209 4.2216C38.1637 3.4046 32.7345 3.4046 27.3892 4.2216C26.905 3.0581 26.1886 1.6353 25.5617 0.525289C25.5141 0.443589 25.4218 0.40133 25.3294 0.41542C20.2584 1.2888 15.4057 2.8186 10.8776 4.8978C10.8384 4.9147 10.8048 4.9429 10.7825 4.9795C1.57795 18.7309 -0.943561 32.1443 0.293408 45.3914C0.299005 45.4562 0.335386 45.5182 0.385761 45.5576C6.45866 50.0174 12.3413 52.7249 18.1147 54.5195C18.2071 54.5477 18.305 54.5139 18.3638 54.4378C19.7295 52.5728 20.9469 50.6063 21.9907 48.5383C22.0523 48.4172 21.9935 48.2735 21.8676 48.2256C19.9366 47.4931 18.0979 46.6 16.3292 45.5858C16.1893 45.5041 16.1781 45.304 16.3068 45.2082C16.679 44.9293 17.0513 44.6391 17.4067 44.3461C17.471 44.2926 17.5606 44.2813 17.6362 44.3151C29.2558 49.6202 41.8354 49.6202 53.3179 44.3151C53.3935 44.2785 53.4831 44.2898 53.5502 44.3433C53.9057 44.6363 54.2779 44.9293 54.6529 45.2082C54.7816 45.304 54.7732 45.5041 54.6333 45.5858C52.8646 46.6197 51.0259 47.4931 49.0921 48.2228C48.9662 48.2707 48.9102 48.4172 48.9718 48.5383C50.038 50.6034 51.2554 52.5699 52.5959 54.435C52.6519 54.5139 52.7526 54.5477 52.845 54.5195C58.6464 52.7249 64.529 50.0174 70.6019 45.5576C70.6551 45.5182 70.6887 45.459 70.6943 45.3942C72.1747 30.0791 68.2147 16.7757 60.1968 4.9823C60.1772 4.9429 60.1437 4.9147 60.1045 4.8978ZM23.7259 37.3253C20.2276 37.3253 17.3451 34.1136 17.3451 30.1693C17.3451 26.225 20.1717 23.0133 23.7259 23.0133C27.308 23.0133 30.1626 26.2532 30.1066 30.1693C30.1066 34.1136 27.28 37.3253 23.7259 37.3253ZM47.3178 37.3253C43.8196 37.3253 40.9371 34.1136 40.9371 30.1693C40.9371 26.225 43.7636 23.0133 47.3178 23.0133C50.9 23.0133 53.7545 26.2532 53.6986 30.1693C53.6986 34.1136 50.9 37.3253 47.3178 37.3253Z" 
fill="currentColor"/></g><defs><clipPath id="clip0"><rect width="71" height="55" fill="white"/></clipPath></defs></svg>', # noqa: E501
|
|
55
|
+
)
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_provider_visual_config(provider_name: str) -> OAuth2ProviderVisualConfig:
|
|
60
|
+
"""
|
|
61
|
+
Get visual configuration for a provider.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
provider_name: Provider identifier
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
Visual configuration for the provider, or default config if not found
|
|
68
|
+
"""
|
|
69
|
+
config = PROVIDER_CONFIGS.get(provider_name.lower())
|
|
70
|
+
if config:
|
|
71
|
+
return config
|
|
72
|
+
|
|
73
|
+
# Return default configuration for unknown providers
|
|
74
|
+
capitalized_name = provider_name.capitalize()
|
|
75
|
+
return OAuth2ProviderVisualConfig(
|
|
76
|
+
name=provider_name,
|
|
77
|
+
display_name=capitalized_name,
|
|
78
|
+
color="#6B7280",
|
|
79
|
+
text_color="#FFFFFF",
|
|
80
|
+
icon_svg='<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8zm-1-13h2v6h-2zm0 8h2v2h-2z" fill="currentColor"/></svg>', # noqa: E501
|
|
81
|
+
)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Session storage implementations."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import secrets
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
|
|
8
|
+
from ragbits.chat.auth.types import Session, SessionStore
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)

# Number of leading session-ID characters included in log messages: enough to
# correlate log lines without writing a full session ID to the logs.
SESSION_ID_LOG_LENGTH = 8


class InMemorySessionStore(SessionStore):
    """
    In-memory session store implementation.

    Sessions are kept in a per-instance dict. They do not survive a process restart
    and are not shared between processes. The async methods serialize access through
    an ``asyncio.Lock``.
    """

    def __init__(self) -> None:
        """Initialize the in-memory session store."""
        self.sessions: dict[str, Session] = {}
        self.lock = asyncio.Lock()

    @staticmethod
    def _log_id(session_id: str) -> str:
        """Return the truncated form of ``session_id`` used in log messages."""
        # Slicing already returns short IDs unchanged, so no length check is needed.
        return session_id[:SESSION_ID_LOG_LENGTH]

    def _remove_locked(self, session_id: str) -> bool:
        """
        Remove a session; the caller must already hold ``self.lock``.

        Args:
            session_id: Session ID to delete

        Returns:
            True if session was deleted, False if not found
        """
        if session_id not in self.sessions:
            logger.debug("Session not found for deletion: %s...", self._log_id(session_id))
            return False
        del self.sessions[session_id]
        logger.debug("Session deleted: %s...", self._log_id(session_id))
        return True

    async def create_session(self, session: Session) -> str:
        """
        Create a new session.

        A fresh random ID (``secrets.token_urlsafe(32)``) is generated and written to
        ``session.session_id``, overwriting whatever value the caller set.

        Args:
            session: Session object to store

        Returns:
            Session ID
        """
        session_id = secrets.token_urlsafe(32)
        session.session_id = session_id

        async with self.lock:
            self.sessions[session_id] = session

        logger.debug("Created session for user: %s", session.user.username)
        return session_id

    async def get_session(self, session_id: str) -> Session | None:
        """
        Retrieve a session by ID.

        Expired sessions are removed from the store and reported as missing.

        Args:
            session_id: Session ID to retrieve

        Returns:
            Session object if found and not expired, None otherwise
        """
        async with self.lock:
            session = self.sessions.get(session_id)

            if session is None:
                logger.debug("Session not found: %s...", self._log_id(session_id))
                return None

            # NOTE(review): expires_at must be timezone-aware; comparing it with an
            # aware "now" raises TypeError if a naive datetime was stored.
            if datetime.now(timezone.utc) > session.expires_at:
                logger.debug("Session expired, removing: %s...", self._log_id(session_id))
                # Remove while still holding the lock. Awaiting delete_session() here
                # would deadlock, because asyncio.Lock is not reentrant.
                self._remove_locked(session_id)
                return None

            return session

    async def delete_session(self, session_id: str) -> bool:
        """
        Delete a session.

        Args:
            session_id: Session ID to delete

        Returns:
            True if session was deleted, False if not found
        """
        async with self.lock:
            return self._remove_locked(session_id)

    def cleanup_expired_sessions(self) -> int:
        """
        Remove expired sessions from storage.

        This method is synchronous so it can be handed directly to schedulers that expect a
        plain callable. It does NOT acquire ``self.lock``: if the lock is currently held, the
        pass is skipped and 0 is returned; otherwise expired entries are removed directly.
        When called on the event-loop thread it cannot interleave with the async methods.
        When a scheduler runs it on a worker thread, it is not synchronized with them.

        Any exception raised during the pass is logged and swallowed (0 is returned), so a
        periodic job keeps running.

        Returns:
            Number of sessions removed

        Example:
            To schedule periodic cleanup, you can use APScheduler or a similar library
            (check which thread your scheduler's executor runs plain functions on):

            ```python
            from apscheduler.schedulers.asyncio import AsyncIOScheduler
            from ragbits.chat.auth.session_store import InMemorySessionStore

            session_store = InMemorySessionStore()
            scheduler = AsyncIOScheduler()

            # Run cleanup every hour
            scheduler.add_job(
                session_store.cleanup_expired_sessions,
                "interval",
                hours=1,
                id="cleanup_sessions",
            )
            scheduler.start()
            ```

            Or using FastAPI's lifespan with asyncio (runs on the event-loop thread):

            ```python
            import asyncio
            from contextlib import asynccontextmanager
            from fastapi import FastAPI


            @asynccontextmanager
            async def lifespan(app: FastAPI):
                async def cleanup_task():
                    while True:
                        await asyncio.sleep(3600)  # Run every hour
                        removed = session_store.cleanup_expired_sessions()
                        if removed > 0:
                            print(f"Cleaned up {removed} expired sessions")

                task = asyncio.create_task(cleanup_task())
                yield
                task.cancel()
            ```
        """
        now = datetime.now(timezone.utc)

        try:
            if self.lock.locked():
                logger.debug("Session cleanup skipped: lock is held")
                return 0

            # Iterate over a snapshot so removal cannot invalidate the iteration.
            expired_ids = [sid for sid, session in list(self.sessions.items()) if now > session.expires_at]
            for sid in expired_ids:
                self.sessions.pop(sid, None)

            if expired_ids:
                logger.info("Cleaned up %d expired sessions", len(expired_ids))

            return len(expired_ids)
        except Exception as e:
            logger.exception("Error during session cleanup: %s", str(e))
            return 0
|
ragbits/chat/auth/types.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
1
2
|
from datetime import datetime
|
|
2
3
|
from typing import Any
|
|
3
4
|
|
|
@@ -23,51 +24,87 @@ class UserCredentials(BaseModel):
|
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class OAuth2Credentials(BaseModel):
|
|
26
|
-
"""Represents OAuth2 authentication data."""
|
|
27
|
+
"""Represents OAuth2 authentication data from Discord."""
|
|
27
28
|
|
|
28
29
|
access_token: str
|
|
29
30
|
token_type: str = "bearer"
|
|
30
|
-
refresh_token: str | None = None
|
|
31
|
-
expires_at: datetime | None = None
|
|
32
|
-
scope: str | None = None
|
|
33
31
|
|
|
34
32
|
|
|
35
|
-
|
|
36
|
-
"""Represents a JWT authentication jwt_token."""
|
|
37
|
-
|
|
38
|
-
access_token: str
|
|
39
|
-
token_type: str = "bearer"
|
|
40
|
-
expires_in: int # seconds until expiration
|
|
41
|
-
refresh_token: str | None = None
|
|
42
|
-
user: User
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class CredentialsLoginRequest(BaseModel):
|
|
46
|
-
"""
|
|
47
|
-
Request body for user login
|
|
48
|
-
"""
|
|
49
|
-
|
|
50
|
-
username: str = Field(..., description="Username")
|
|
51
|
-
password: str = Field(..., description="Password")
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
LoginRequest = CredentialsLoginRequest
|
|
33
|
+
LoginRequest = UserCredentials
|
|
55
34
|
|
|
56
35
|
|
|
57
36
|
class LoginResponse(BaseModel):
|
|
58
37
|
"""
|
|
59
|
-
Response body for
|
|
38
|
+
Response body for login with session-based authentication.
|
|
39
|
+
|
|
40
|
+
The session ID is set as an HTTP-only cookie by the backend.
|
|
41
|
+
Frontend only receives user information.
|
|
60
42
|
"""
|
|
61
43
|
|
|
62
44
|
success: bool = Field(..., description="Whether login was successful")
|
|
63
45
|
user: User | None = Field(None, description="User information")
|
|
64
46
|
error_message: str | None = Field(None, description="Error message if login failed")
|
|
65
|
-
jwt_token: JWTToken | None = Field(..., description="Access jwt_token")
|
|
66
47
|
|
|
67
48
|
|
|
68
|
-
class
|
|
49
|
+
class OAuth2AuthorizeResponse(BaseModel):
|
|
69
50
|
"""
|
|
70
|
-
|
|
51
|
+
Response for OAuth2 authorization URL request
|
|
71
52
|
"""
|
|
72
53
|
|
|
73
|
-
|
|
54
|
+
authorize_url: str = Field(..., description="URL to redirect user to for OAuth2 authorization")
|
|
55
|
+
state: str = Field(..., description="State parameter for CSRF protection")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Session(BaseModel):
    """
    Server-side record of an authenticated user session.

    Instances are persisted and looked up through a ``SessionStore``.
    """

    # Store key for this session. InMemorySessionStore.create_session overwrites it
    # with a freshly generated random token.
    session_id: str
    user: User
    # Identifier of the provider that authenticated the user; presumably matches
    # OAuth2Provider.name (e.g. "discord") -- confirm against callers.
    provider: str
    oauth_token: str  # Provider's OAuth token
    token_type: str
    created_at: datetime
    # InMemorySessionStore compares this against datetime.now(timezone.utc), so it
    # should be timezone-aware.
    expires_at: datetime


class SessionStore(ABC):
    """
    Storage backend interface for ``Session`` objects.

    Every operation is a coroutine so implementations are free to perform async I/O.
    """

    @abstractmethod
    async def create_session(self, session: Session) -> str:
        """
        Persist a session.

        Args:
            session: Session object to store

        Returns:
            The ID under which the session was stored
        """
        ...

    @abstractmethod
    async def get_session(self, session_id: str) -> Session | None:
        """
        Look up a stored session.

        Args:
            session_id: Session ID to retrieve

        Returns:
            The Session if one exists for the ID, otherwise None
        """
        ...

    @abstractmethod
    async def delete_session(self, session_id: str) -> bool:
        """
        Remove a stored session.

        Args:
            session_id: Session ID to delete

        Returns:
            True when a session was removed, False when none existed for the ID
        """
        ...
|
ragbits/chat/interface/types.py
CHANGED
|
@@ -858,6 +858,18 @@ class AuthType(str, Enum):
|
|
|
858
858
|
"""Defines the available authentication types."""
|
|
859
859
|
|
|
860
860
|
CREDENTIALS = "credentials"
|
|
861
|
+
OAUTH2 = "oauth2"
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
class OAuth2ProviderConfig(BaseModel):
    """
    One entry of ``AuthenticationConfig.oauth2_providers``.

    Holds the provider identifier plus optional presentation hints (label, colors,
    icon). Unset hints stay ``None``; the fallbacks mentioned in the field
    descriptions are not applied by this model itself.
    """

    name: str = Field(default=..., description="Provider name (e.g., 'discord')")
    display_name: str | None = Field(
        default=None, description="Display name for the provider (e.g., 'Discord')"
    )
    color: str | None = Field(default=None, description="Brand color for the provider (e.g., '#5865F2')")
    button_color: str | None = Field(default=None, description="Button background color (defaults to color)")
    text_color: str | None = Field(default=None, description="Button text color (defaults to white)")
    icon_svg: str | None = Field(default=None, description="SVG icon as string")
|
|
861
873
|
|
|
862
874
|
|
|
863
875
|
class AuthenticationConfig(BaseModel):
|
|
@@ -865,6 +877,9 @@ class AuthenticationConfig(BaseModel):
|
|
|
865
877
|
|
|
866
878
|
enabled: bool = Field(default=False, description="Enable/disable authentication")
|
|
867
879
|
auth_types: list[AuthType] = Field(default=[], description="List of available authentication types")
|
|
880
|
+
oauth2_providers: list[OAuth2ProviderConfig] = Field(
|
|
881
|
+
default_factory=list, description="List of available OAuth2 providers"
|
|
882
|
+
)
|
|
868
883
|
|
|
869
884
|
|
|
870
885
|
class ConfigResponse(BaseModel):
|
|
@@ -42,12 +42,11 @@ class RagbitsChatModelProvider:
|
|
|
42
42
|
|
|
43
43
|
try:
|
|
44
44
|
from ragbits.chat.auth.types import (
|
|
45
|
-
CredentialsLoginRequest,
|
|
46
|
-
JWTToken,
|
|
47
45
|
LoginRequest,
|
|
48
46
|
LoginResponse,
|
|
49
|
-
|
|
47
|
+
OAuth2AuthorizeResponse,
|
|
50
48
|
User,
|
|
49
|
+
UserCredentials,
|
|
51
50
|
)
|
|
52
51
|
from ragbits.chat.interface.forms import UserSettings
|
|
53
52
|
from ragbits.chat.interface.types import (
|
|
@@ -73,6 +72,7 @@ class RagbitsChatModelProvider:
|
|
|
73
72
|
MessageIdContent,
|
|
74
73
|
MessageRole,
|
|
75
74
|
MessageUsage,
|
|
75
|
+
OAuth2ProviderConfig,
|
|
76
76
|
Reference,
|
|
77
77
|
StateUpdate,
|
|
78
78
|
TextContent,
|
|
@@ -123,17 +123,17 @@ class RagbitsChatModelProvider:
|
|
|
123
123
|
# API response models
|
|
124
124
|
"ConfigResponse": ConfigResponse,
|
|
125
125
|
"FeedbackResponse": FeedbackResponse,
|
|
126
|
+
"OAuth2AuthorizeResponse": OAuth2AuthorizeResponse,
|
|
127
|
+
"OAuth2ProviderConfig": OAuth2ProviderConfig,
|
|
126
128
|
# API request models
|
|
127
129
|
"ChatRequest": ChatMessageRequest,
|
|
128
130
|
"FeedbackRequest": FeedbackRequest,
|
|
129
131
|
# Auth
|
|
130
132
|
"AuthType": AuthType,
|
|
131
133
|
"AuthenticationConfig": AuthenticationConfig,
|
|
132
|
-
"
|
|
133
|
-
"JWTToken": JWTToken,
|
|
134
|
+
"UserCredentials": UserCredentials,
|
|
134
135
|
"LoginRequest": LoginRequest,
|
|
135
136
|
"LoginResponse": LoginResponse,
|
|
136
|
-
"LogoutRequest": LogoutRequest,
|
|
137
137
|
"User": User,
|
|
138
138
|
}
|
|
139
139
|
|
|
@@ -169,7 +169,6 @@ class RagbitsChatModelProvider:
|
|
|
169
169
|
"ServerState",
|
|
170
170
|
"FeedbackItem",
|
|
171
171
|
"Image",
|
|
172
|
-
"JWTToken",
|
|
173
172
|
"User",
|
|
174
173
|
"MessageUsage",
|
|
175
174
|
"Task",
|
|
@@ -192,16 +191,19 @@ class RagbitsChatModelProvider:
|
|
|
192
191
|
"UserSettings",
|
|
193
192
|
"FeedbackConfig",
|
|
194
193
|
"AuthenticationConfig",
|
|
194
|
+
"OAuth2ProviderConfig",
|
|
195
195
|
],
|
|
196
196
|
"responses": [
|
|
197
197
|
"FeedbackResponse",
|
|
198
198
|
"ConfigResponse",
|
|
199
199
|
"LoginResponse",
|
|
200
|
+
"OAuth2AuthorizeResponse",
|
|
200
201
|
],
|
|
201
202
|
"requests": [
|
|
202
203
|
"ChatRequest",
|
|
203
204
|
"FeedbackRequest",
|
|
204
|
-
"
|
|
205
|
+
"UserCredentials",
|
|
206
|
+
"OAuth2Credentials",
|
|
205
207
|
"LoginRequest",
|
|
206
208
|
"LogoutRequest",
|
|
207
209
|
],
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
/*
 * AuthGuard chunk (minified build output with a content-hashed filename; edit the UI
 * source rather than this file). Default export `z({children})`:
 *  - Auth store not hydrated yet: renders a session check (`A`) next to component `j`
 *    (presumably a loading placeholder -- unverified). `A` calls the "/api/user"
 *    endpoint once on mount:
 *      - a user is returned: logs the user in, initializes the user store with
 *        `user_id`, and navigates to "/" if currently on "/login";
 *      - an empty result: logs out;
 *      - an error: logs it, logs out, and navigates to "/login".
 *    In every case it marks the store hydrated.
 *  - On "/login": renders children without an onUnauthorized handler.
 *  - Authenticated: renders children inside a context provider, with
 *    onUnauthorized bound to logout.
 *  - Otherwise: redirects to "/login" (replace).
 * Every API wrapper is configured with credentials: "include" so cookies are sent.
 */
import{r as h,j as t,aW as i,aE as p,c as b,bo as f,bp as l,bq as d,br as j,bs as S}from"./index-BZLU40Mk.js";import{a as r}from"./authStore-BfGlL8rp.js";import{u as m}from"./useInitializeUserStore-DyHP7g8x.js";const v=h.createContext(null);function y({children:e}){const[a]=h.useState(()=>r);return t.jsx(v.Provider,{value:a,children:e})}function A(){const{logout:e,login:a,setHydrated:c}=i(r,x=>x),u=m(),n=p(),s=b("/api/user"),g=f();return h.useEffect(()=>{(async()=>{try{const o=await s.call();o?(a(o),u(o.user_id),g.pathname==="/login"&&n("/")):e()}catch(o){console.error("Failed to check session:",o),e(),n("/login")}finally{c()}})()},[]),null}function z({children:e}){const a=f(),c=i(r,s=>s.isAuthenticated),u=i(r,s=>s.hasHydrated),n=i(r,s=>s.logout);return u?a.pathname==="/login"?t.jsx(l,{baseUrl:d,auth:{credentials:"include"},children:e}):c?t.jsx(y,{children:t.jsx(l,{baseUrl:d,auth:{onUnauthorized:n,credentials:"include"},children:e})}):t.jsx(S,{to:"/login",replace:!0}):t.jsxs(l,{baseUrl:d,auth:{credentials:"include"},children:[t.jsx(A,{}),t.jsx(j,{})]})}export{z as default};
|