ragbits-chat 1.4.0.dev202512151244__py3-none-any.whl → 1.4.0.dev202601010248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ragbits/chat/api.py +364 -79
  2. ragbits/chat/auth/__init__.py +3 -1
  3. ragbits/chat/auth/backends.py +519 -81
  4. ragbits/chat/auth/base.py +8 -10
  5. ragbits/chat/auth/oauth2_providers.py +108 -0
  6. ragbits/chat/auth/provider_config.py +81 -0
  7. ragbits/chat/auth/session_store.py +178 -0
  8. ragbits/chat/auth/types.py +66 -29
  9. ragbits/chat/interface/types.py +15 -0
  10. ragbits/chat/providers/model_provider.py +10 -8
  11. ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
  12. ragbits/chat/ui-build/assets/{ChatHistory-Cp_DhrUx.js → ChatHistory-B2hLBYMJ.js} +2 -2
  13. ragbits/chat/ui-build/assets/{ChatOptionsForm-CNjzbIqN.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
  14. ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
  15. ragbits/chat/ui-build/assets/{FeedbackForm-CmRSbYPS.js → FeedbackForm-oSbly5oN.js} +1 -1
  16. ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
  17. ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
  18. ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
  19. ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
  20. ragbits/chat/ui-build/assets/{UsageButton-B-N1J-sZ.js → UsageButton-BABA7a-w.js} +1 -1
  21. ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
  22. ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-CsJAveMU.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
  23. ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-BJI2Gxdq.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
  24. ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
  25. ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
  26. ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
  27. ragbits/chat/ui-build/assets/{index-v15bx9Do.js → index-Ceq7Rkzy.js} +1 -1
  28. ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
  29. ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
  30. ragbits/chat/ui-build/assets/{useMenuTriggerState-CTz3KfPq.js → useMenuTriggerState-SaFmATkk.js} +1 -1
  31. ragbits/chat/ui-build/assets/{useSelectableItem-DK6eABKK.js → useSelectableItem-DhuFnc0W.js} +1 -1
  32. ragbits/chat/ui-build/index.html +2 -2
  33. {ragbits_chat-1.4.0.dev202512151244.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/METADATA +2 -2
  34. ragbits_chat-1.4.0.dev202601010248.dist-info/RECORD +58 -0
  35. ragbits/chat/ui-build/assets/AuthGuard-B326tmZN.js +0 -1
  36. ragbits/chat/ui-build/assets/Login-Djq6QJ18.js +0 -1
  37. ragbits/chat/ui-build/assets/LogoutButton-Cn2L63Hk.js +0 -1
  38. ragbits/chat/ui-build/assets/ShareButton-lYj0v67r.js +0 -1
  39. ragbits/chat/ui-build/assets/authStore-DATNN-ps.js +0 -1
  40. ragbits/chat/ui-build/assets/index-B3hlerKe.js +0 -131
  41. ragbits/chat/ui-build/assets/index-B7bSwAmw.js +0 -32
  42. ragbits/chat/ui-build/assets/index-C_JcEI3R.js +0 -1
  43. ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
  44. ragbits_chat-1.4.0.dev202512151244.dist-info/RECORD +0 -52
  45. {ragbits_chat-1.4.0.dev202512151244.dist-info → ragbits_chat-1.4.0.dev202601010248.dist-info}/WHEEL +0 -0
ragbits/chat/auth/base.py CHANGED
@@ -7,14 +7,13 @@ from pydantic import BaseModel
7
7
  from ragbits.core.options import Options
8
8
  from ragbits.core.utils.config_handling import ConfigurableComponent
9
9
 
10
- from .types import JWTToken, OAuth2Credentials, User, UserCredentials
10
+ from .types import OAuth2Credentials, User, UserCredentials
11
11
 
12
12
 
13
13
  class AuthOptions(Options):
14
14
  """Options for authentication backends."""
15
15
 
16
- jwt_algorithm: str = "HS256"
17
- token_expiry_minutes: int = 24 * 60
16
+ pass
18
17
 
19
18
 
20
19
  class AuthenticationResponse(BaseModel):
@@ -22,7 +21,7 @@ class AuthenticationResponse(BaseModel):
22
21
 
23
22
  success: bool
24
23
  user: User | None = None
25
- jwt_token: JWTToken | None = None
24
+ session_id: str | None = None
26
25
  error_message: str | None = None
27
26
 
28
27
 
@@ -60,26 +59,25 @@ class AuthenticationBackend(ConfigurableComponent[AuthOptions], ABC):
60
59
  pass
61
60
 
62
61
  @abstractmethod
63
- async def validate_token(self, token: str) -> AuthenticationResponse:
62
+ async def validate_session(self, session_id: str) -> AuthenticationResponse:
64
63
  """
65
- Validate a JWT jwt_token.
64
+ Validate a session.
66
65
 
67
66
  Args:
68
- token: The JWT jwt_token to validate
67
+ session_id: The session ID to validate
69
68
 
70
69
  Returns:
71
70
  AuthenticationResult with user if valid
72
71
  """
73
- # Default implementation for backward compatibility
74
72
  pass
75
73
 
76
74
  @abstractmethod
77
- async def revoke_token(self, token: str) -> bool:
75
+ async def revoke_session(self, session_id: str) -> bool:
78
76
  """
79
77
  Revoke/logout a session.
80
78
 
81
79
  Args:
82
- token: The jwt_token to revoke
80
+ session_id: The session ID to revoke
83
81
 
84
82
  Returns:
85
83
  True if successfully revoked
@@ -0,0 +1,108 @@
1
+ """OAuth2 provider implementations for various authentication services."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from ragbits.chat.auth.types import User
7
+
8
+
9
class OAuth2Provider(ABC):
    """
    Interface that every OAuth2 identity-provider integration implements.

    A concrete provider exposes its endpoint URLs and requested scopes as read-only
    properties, and converts the provider's user-info payload into a ``User``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Machine-readable provider identifier, e.g. ``'discord'``."""

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Label suitable for showing to people, e.g. ``'Discord'``."""

    @property
    @abstractmethod
    def authorize_url(self) -> str:
        """URL of the provider's OAuth2 authorization endpoint."""

    @property
    @abstractmethod
    def token_url(self) -> str:
        """URL of the endpoint that exchanges an authorization code for a token."""

    @property
    @abstractmethod
    def user_info_url(self) -> str:
        """URL of the endpoint that returns the signed-in user's profile."""

    @property
    @abstractmethod
    def scope(self) -> str:
        """Space-separated OAuth2 scopes requested during authorization."""

    @abstractmethod
    def create_user_from_data(self, user_data: dict[str, Any]) -> User:
        """
        Translate the provider's raw profile payload into a ``User``.

        Args:
            user_data: Provider-specific user data as returned by ``user_info_url``.

        Returns:
            The corresponding ``User`` instance.
        """
60
+
61
+
62
class DiscordOAuth2Provider(OAuth2Provider):
    """OAuth2 provider that signs users in with their Discord account."""

    @property
    def name(self) -> str:
        """Identifier of this provider (always ``'discord'``)."""
        return "discord"

    @property
    def display_name(self) -> str:
        """Human-facing label of this provider (always ``'Discord'``)."""
        return "Discord"

    @property
    def authorize_url(self) -> str:
        """Discord's OAuth2 authorization endpoint."""
        return "https://discord.com/api/oauth2/authorize"

    @property
    def token_url(self) -> str:
        """Discord's OAuth2 code-for-token exchange endpoint."""
        return "https://discord.com/api/oauth2/token"

    @property
    def user_info_url(self) -> str:
        """Discord endpoint describing the currently authenticated user."""
        return "https://discord.com/api/users/@me"

    @property
    def scope(self) -> str:
        """Scopes requested from Discord: basic identity plus e-mail address."""
        return "identify email"

    def create_user_from_data(self, user_data: dict[str, Any]) -> User:  # noqa: PLR6301
        """
        Build a ``User`` from Discord's user-info payload.

        Args:
            user_data: Decoded JSON from ``user_info_url``. It must contain an ``id`` key;
                every other key is optional.

        Returns:
            A ``User`` whose ID is the Discord ID prefixed with ``discord_``. It gets the
            single role ``user``, and the avatar and discriminator go into its metadata.

        Raises:
            KeyError: If ``user_data`` has no ``id`` entry.
        """
        discord_id = user_data["id"]
        provider_metadata = {
            "provider": "discord",
            "avatar": user_data.get("avatar"),
            "discriminator": user_data.get("discriminator"),
        }
        return User(
            user_id=f"discord_{discord_id}",
            username=user_data.get("username", ""),
            email=user_data.get("email"),
            full_name=user_data.get("global_name"),
            roles=["user"],
            metadata=provider_metadata,
        )
@@ -0,0 +1,81 @@
1
+ """OAuth2 provider visual configuration (icons, colors, etc.)."""
2
+
3
+ from typing import Any
4
+
5
+
6
+ class OAuth2ProviderVisualConfig:
7
+ """Visual configuration for OAuth2 providers (brand colors, icons, etc.)."""
8
+
9
+ def __init__(
10
+ self,
11
+ name: str,
12
+ display_name: str,
13
+ color: str,
14
+ button_color: str | None = None,
15
+ text_color: str = "#FFFFFF",
16
+ icon_svg: str | None = None,
17
+ ):
18
+ """
19
+ Initialize provider visual configuration.
20
+
21
+ Args:
22
+ name: Provider identifier (e.g., 'google', 'discord')
23
+ display_name: Human-readable provider name
24
+ color: Brand color for the provider
25
+ button_color: Optional button background color (defaults to color)
26
+ text_color: Button text color (defaults to white)
27
+ icon_svg: Optional SVG icon as string
28
+ """
29
+ self.name = name
30
+ self.display_name = display_name
31
+ self.color = color
32
+ self.button_color = button_color or color
33
+ self.text_color = text_color
34
+ self.icon_svg = icon_svg
35
+
36
+ def to_dict(self) -> dict[str, Any]:
37
+ """Convert to dictionary for API serialization."""
38
+ return {
39
+ "name": self.name,
40
+ "display_name": self.display_name,
41
+ "color": self.color,
42
+ "button_color": self.button_color,
43
+ "text_color": self.text_color,
44
+ "icon_svg": self.icon_svg,
45
+ }
46
+
47
+
48
# Built-in visual configurations, keyed by lower-case provider name.
# get_provider_visual_config() looks providers up here and falls back to a generic style.
PROVIDER_CONFIGS: dict[str, OAuth2ProviderVisualConfig] = {
    "discord": OAuth2ProviderVisualConfig(
        name="discord",
        display_name="Discord",
        color="#5865F2",
        # icon_svg is raw SVG markup, so attributes must use SVG spelling ("clip-path").
        # The JSX prop name "clipPath" is not a valid SVG attribute and is ignored.
        # The clip id is namespaced so that url(#...) cannot resolve to another inline SVG
        # on the same page that uses the common "clip0" id.
        icon_svg='<svg width="20" height="20" viewBox="0 0 71 55" fill="none" xmlns="http://www.w3.org/2000/svg"><g clip-path="url(#ragbits-discord-icon-clip)"><path d="M60.1045 4.8978C55.5792 2.8214 50.7265 1.2916 45.6527 0.41542C45.5603 0.39851 45.468 0.440769 45.4204 0.525289C44.7963 1.6353 44.105 3.0834 43.6209 4.2216C38.1637 3.4046 32.7345 3.4046 27.3892 4.2216C26.905 3.0581 26.1886 1.6353 25.5617 0.525289C25.5141 0.443589 25.4218 0.40133 25.3294 0.41542C20.2584 1.2888 15.4057 2.8186 10.8776 4.8978C10.8384 4.9147 10.8048 4.9429 10.7825 4.9795C1.57795 18.7309 -0.943561 32.1443 0.293408 45.3914C0.299005 45.4562 0.335386 45.5182 0.385761 45.5576C6.45866 50.0174 12.3413 52.7249 18.1147 54.5195C18.2071 54.5477 18.305 54.5139 18.3638 54.4378C19.7295 52.5728 20.9469 50.6063 21.9907 48.5383C22.0523 48.4172 21.9935 48.2735 21.8676 48.2256C19.9366 47.4931 18.0979 46.6 16.3292 45.5858C16.1893 45.5041 16.1781 45.304 16.3068 45.2082C16.679 44.9293 17.0513 44.6391 17.4067 44.3461C17.471 44.2926 17.5606 44.2813 17.6362 44.3151C29.2558 49.6202 41.8354 49.6202 53.3179 44.3151C53.3935 44.2785 53.4831 44.2898 53.5502 44.3433C53.9057 44.6363 54.2779 44.9293 54.6529 45.2082C54.7816 45.304 54.7732 45.5041 54.6333 45.5858C52.8646 46.6197 51.0259 47.4931 49.0921 48.2228C48.9662 48.2707 48.9102 48.4172 48.9718 48.5383C50.038 50.6034 51.2554 52.5699 52.5959 54.435C52.6519 54.5139 52.7526 54.5477 52.845 54.5195C58.6464 52.7249 64.529 50.0174 70.6019 45.5576C70.6551 45.5182 70.6887 45.459 70.6943 45.3942C72.1747 30.0791 68.2147 16.7757 60.1968 4.9823C60.1772 4.9429 60.1437 4.9147 60.1045 4.8978ZM23.7259 37.3253C20.2276 37.3253 17.3451 34.1136 17.3451 30.1693C17.3451 26.225 20.1717 23.0133 23.7259 23.0133C27.308 23.0133 30.1626 26.2532 30.1066 30.1693C30.1066 34.1136 27.28 37.3253 23.7259 37.3253ZM47.3178 37.3253C43.8196 37.3253 40.9371 34.1136 40.9371 30.1693C40.9371 26.225 43.7636 23.0133 47.3178 23.0133C50.9 23.0133 53.7545 26.2532 53.6986 30.1693C53.6986 34.1136 50.9 37.3253 47.3178 37.3253Z" fill="currentColor"/></g><defs><clipPath id="ragbits-discord-icon-clip"><rect width="71" height="55" fill="white"/></clipPath></defs></svg>',  # noqa: E501
    )
}
57
+
58
+
59
def get_provider_visual_config(provider_name: str) -> OAuth2ProviderVisualConfig:
    """
    Look up the login-button styling for an OAuth2 provider.

    The lookup in ``PROVIDER_CONFIGS`` is case-insensitive. An unregistered provider gets a
    neutral gray style with a generic icon, labelled with ``provider_name.capitalize()``.

    Args:
        provider_name: Provider identifier, e.g. ``"discord"``.

    Returns:
        The registered configuration, or a freshly built fallback configuration.
    """
    registered = PROVIDER_CONFIGS.get(provider_name.lower())
    if registered is not None:
        return registered

    fallback_icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8zm-1-13h2v6h-2zm0 8h2v2h-2z" fill="currentColor"/></svg>'  # noqa: E501
    return OAuth2ProviderVisualConfig(
        name=provider_name,
        display_name=provider_name.capitalize(),
        color="#6B7280",
        text_color="#FFFFFF",
        icon_svg=fallback_icon,
    )
@@ -0,0 +1,178 @@
1
+ """Session storage implementations."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import secrets
6
+ from datetime import datetime, timezone
7
+
8
+ from ragbits.chat.auth.types import Session, SessionStore
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Minimum length for session ID truncation in logs (for readability while maintaining some privacy)
13
+ SESSION_ID_LOG_LENGTH = 8
14
+
15
+
16
class InMemorySessionStore(SessionStore):
    """
    Session store that keeps sessions in a process-local dictionary.

    Sessions are lost when the process exits and are not shared between processes.
    Coroutine access to ``sessions`` is serialized with an ``asyncio.Lock``.
    """

    def __init__(self) -> None:
        """Initialize an empty in-memory session store."""
        self.sessions: dict[str, Session] = {}
        self.lock = asyncio.Lock()

    @staticmethod
    def _log_id(session_id: str) -> str:
        """Return the first ``SESSION_ID_LOG_LENGTH`` characters of a session ID for log output."""
        # Slicing already returns the whole string when it is shorter than the limit.
        return session_id[:SESSION_ID_LOG_LENGTH]

    @staticmethod
    def _is_expired(session: Session, now: datetime) -> bool:
        """
        Tell whether ``session`` has expired relative to ``now``.

        Args:
            session: Session whose ``expires_at`` is checked.
            now: Timezone-aware current time.

        Returns:
            True if ``now`` is strictly after the session's expiry time.
        """
        expires_at = session.expires_at
        if expires_at.tzinfo is None:
            # Comparing a naive with an aware datetime raises TypeError.
            # NOTE(review): naive timestamps are assumed to be UTC — confirm with session creators.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return now > expires_at

    async def create_session(self, session: Session) -> str:
        """
        Store a session under a newly generated random ID.

        The generated ID is also written to ``session.session_id``, so the passed object is mutated.

        Args:
            session: Session object to store

        Returns:
            Session ID
        """
        session_id = secrets.token_urlsafe(32)
        session.session_id = session_id

        async with self.lock:
            self.sessions[session_id] = session

        logger.debug("Created session for user: %s", session.user.username)
        return session_id

    async def get_session(self, session_id: str) -> Session | None:
        """
        Retrieve a session by ID, evicting it if it has expired.

        Args:
            session_id: Session ID to retrieve

        Returns:
            Session object if found and not expired, None otherwise
        """
        async with self.lock:
            session = self.sessions.get(session_id)

        if session is None:
            logger.debug("Session not found: %s...", self._log_id(session_id))
            return None

        if self._is_expired(session, datetime.now(timezone.utc)):
            logger.debug("Session expired, removing: %s...", self._log_id(session_id))
            # Called after leaving the ``async with`` above: asyncio.Lock is not reentrant and
            # delete_session acquires it itself.
            await self.delete_session(session_id)
            return None

        return session

    async def delete_session(self, session_id: str) -> bool:
        """
        Delete a session.

        Args:
            session_id: Session ID to delete

        Returns:
            True if session was deleted, False if not found
        """
        async with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]
                logger.debug("Session deleted: %s...", self._log_id(session_id))
                return True

        logger.debug("Session not found for deletion: %s...", self._log_id(session_id))
        return False

    def cleanup_expired_sessions(self) -> int:
        """
        Remove expired sessions from storage.

        This method is synchronous so that background task schedulers can call it directly.
        It never acquires or waits for ``self.lock``. If the lock is currently held, the run is
        skipped and 0 is returned. Any exception is logged and also results in 0.

        Returns:
            Number of sessions removed

        Example:
            To schedule periodic cleanup, you can use APScheduler or a similar library:

            ```python
            from apscheduler.schedulers.asyncio import AsyncIOScheduler
            from ragbits.chat.auth.session_store import InMemorySessionStore

            session_store = InMemorySessionStore()
            scheduler = AsyncIOScheduler()

            # Run cleanup every hour
            scheduler.add_job(
                session_store.cleanup_expired_sessions,
                "interval",
                hours=1,
                id="cleanup_sessions",
            )
            scheduler.start()
            ```

            Or using FastAPI's lifespan with asyncio:

            ```python
            import asyncio
            from contextlib import asynccontextmanager
            from fastapi import FastAPI


            @asynccontextmanager
            async def lifespan(app: FastAPI):
                async def cleanup_task():
                    while True:
                        await asyncio.sleep(3600)  # Run every hour
                        removed = session_store.cleanup_expired_sessions()
                        if removed > 0:
                            print(f"Cleaned up {removed} expired sessions")

                task = asyncio.create_task(cleanup_task())
                yield
                task.cancel()
            ```
        """
        now = datetime.now(timezone.utc)

        try:
            # Skip rather than wait when a coroutine holds the lock. On the event-loop thread
            # this check is reliable, because no coroutine can run while this synchronous method
            # executes. When called from another thread it is only advisory.
            if self.lock.locked():
                logger.debug("Session cleanup skipped: lock is held")
                return 0

            # Snapshot the items first so the dict is never iterated while it changes size.
            expired_ids = [sid for sid, session in list(self.sessions.items()) if self._is_expired(session, now)]
            for sid in expired_ids:
                self.sessions.pop(sid, None)

            if expired_ids:
                logger.info("Cleaned up %d expired sessions", len(expired_ids))

            return len(expired_ids)
        except Exception as e:
            # Deliberate best effort: a failed cleanup must not crash the scheduler that runs it.
            logger.exception("Error during session cleanup: %s", str(e))
            return 0
@@ -1,3 +1,4 @@
1
+ from abc import ABC, abstractmethod
1
2
  from datetime import datetime
2
3
  from typing import Any
3
4
 
@@ -23,51 +24,87 @@ class UserCredentials(BaseModel):
23
24
 
24
25
 
25
26
  class OAuth2Credentials(BaseModel):
26
- """Represents OAuth2 authentication data."""
27
+ """Represents OAuth2 authentication data from Discord."""
27
28
 
28
29
  access_token: str
29
30
  token_type: str = "bearer"
30
- refresh_token: str | None = None
31
- expires_at: datetime | None = None
32
- scope: str | None = None
33
31
 
34
32
 
35
- class JWTToken(BaseModel):
36
- """Represents a JWT authentication jwt_token."""
37
-
38
- access_token: str
39
- token_type: str = "bearer"
40
- expires_in: int # seconds until expiration
41
- refresh_token: str | None = None
42
- user: User
43
-
44
-
45
- class CredentialsLoginRequest(BaseModel):
46
- """
47
- Request body for user login
48
- """
49
-
50
- username: str = Field(..., description="Username")
51
- password: str = Field(..., description="Password")
52
-
53
-
54
- LoginRequest = CredentialsLoginRequest
33
+ LoginRequest = UserCredentials
55
34
 
56
35
 
57
36
  class LoginResponse(BaseModel):
58
37
  """
59
- Response body for successful login
38
+ Response body for login with session-based authentication.
39
+
40
+ The session ID is set as an HTTP-only cookie by the backend.
41
+ Frontend only receives user information.
60
42
  """
61
43
 
62
44
  success: bool = Field(..., description="Whether login was successful")
63
45
  user: User | None = Field(None, description="User information")
64
46
  error_message: str | None = Field(None, description="Error message if login failed")
65
- jwt_token: JWTToken | None = Field(..., description="Access jwt_token")
66
47
 
67
48
 
68
- class LogoutRequest(BaseModel):
49
+ class OAuth2AuthorizeResponse(BaseModel):
69
50
  """
70
- Request body for user logout
51
+ Response for OAuth2 authorization URL request
71
52
  """
72
53
 
73
- token: str = Field(..., description="Session ID to logout")
54
+ authorize_url: str = Field(..., description="URL to redirect user to for OAuth2 authorization")
55
+ state: str = Field(..., description="State parameter for CSRF protection")
56
+
57
+
58
class Session(BaseModel):
    """
    Represents a user session.

    ``InMemorySessionStore.create_session`` overwrites ``session_id`` with a freshly generated
    value, so callers may pass a placeholder there.
    """

    # Identifier of the session (the store generates it; see class docstring).
    session_id: str
    # The authenticated user this session belongs to.
    user: User
    # Name of the provider that authenticated the user — presumably matches OAuth2Provider.name; verify.
    provider: str
    oauth_token: str  # Provider's OAuth token
    # Type of ``oauth_token`` as reported by the provider (presumably e.g. "bearer").
    token_type: str
    # When the session was created.
    created_at: datetime
    # Expiry time. InMemorySessionStore compares it with datetime.now(timezone.utc), so it should be
    # timezone-aware.
    expires_at: datetime
68
+
69
+
70
class SessionStore(ABC):
    """Abstract interface for persisting, looking up, and removing user sessions."""

    @abstractmethod
    async def create_session(self, session: Session) -> str:
        """
        Persist ``session`` and report the ID it was stored under.

        Args:
            session: The session to persist.

        Returns:
            The session ID.
        """

    @abstractmethod
    async def get_session(self, session_id: str) -> Session | None:
        """
        Look up a previously stored session.

        Args:
            session_id: ID of the session to fetch.

        Returns:
            The stored session, or None when no matching session exists.
        """

    @abstractmethod
    async def delete_session(self, session_id: str) -> bool:
        """
        Remove a stored session.

        Args:
            session_id: ID of the session to remove.

        Returns:
            True when a session was removed, False when none matched.
        """
@@ -858,6 +858,18 @@ class AuthType(str, Enum):
858
858
  """Defines the available authentication types."""
859
859
 
860
860
  CREDENTIALS = "credentials"
861
+ OAUTH2 = "oauth2"
862
+
863
+
864
class OAuth2ProviderConfig(BaseModel):
    """
    Configuration for an OAuth2 provider including visual configuration.

    Listed in ``AuthenticationConfig.oauth2_providers``. Only ``name`` is required; every
    visual field defaults to ``None``.
    """

    # Provider identifier, e.g. "discord".
    name: str = Field(..., description="Provider name (e.g., 'discord')")
    display_name: str | None = Field(None, description="Display name for the provider (e.g., 'Discord')")
    color: str | None = Field(None, description="Brand color for the provider (e.g., '#5865F2')")
    # NOTE(review): the "defaults to color" / "defaults to white" behavior in the two descriptions
    # below is not implemented by this model (both fields default to None). Presumably whoever
    # fills or renders the config resolves it — confirm.
    button_color: str | None = Field(None, description="Button background color (defaults to color)")
    text_color: str | None = Field(None, description="Button text color (defaults to white)")
    # Raw inline SVG markup for the provider icon.
    icon_svg: str | None = Field(None, description="SVG icon as string")
861
873
 
862
874
 
863
875
  class AuthenticationConfig(BaseModel):
@@ -865,6 +877,9 @@ class AuthenticationConfig(BaseModel):
865
877
 
866
878
  enabled: bool = Field(default=False, description="Enable/disable authentication")
867
879
  auth_types: list[AuthType] = Field(default=[], description="List of available authentication types")
880
+ oauth2_providers: list[OAuth2ProviderConfig] = Field(
881
+ default_factory=list, description="List of available OAuth2 providers"
882
+ )
868
883
 
869
884
 
870
885
  class ConfigResponse(BaseModel):
@@ -42,12 +42,11 @@ class RagbitsChatModelProvider:
42
42
 
43
43
  try:
44
44
  from ragbits.chat.auth.types import (
45
- CredentialsLoginRequest,
46
- JWTToken,
47
45
  LoginRequest,
48
46
  LoginResponse,
49
- LogoutRequest,
47
+ OAuth2AuthorizeResponse,
50
48
  User,
49
+ UserCredentials,
51
50
  )
52
51
  from ragbits.chat.interface.forms import UserSettings
53
52
  from ragbits.chat.interface.types import (
@@ -73,6 +72,7 @@ class RagbitsChatModelProvider:
73
72
  MessageIdContent,
74
73
  MessageRole,
75
74
  MessageUsage,
75
+ OAuth2ProviderConfig,
76
76
  Reference,
77
77
  StateUpdate,
78
78
  TextContent,
@@ -123,17 +123,17 @@ class RagbitsChatModelProvider:
123
123
  # API response models
124
124
  "ConfigResponse": ConfigResponse,
125
125
  "FeedbackResponse": FeedbackResponse,
126
+ "OAuth2AuthorizeResponse": OAuth2AuthorizeResponse,
127
+ "OAuth2ProviderConfig": OAuth2ProviderConfig,
126
128
  # API request models
127
129
  "ChatRequest": ChatMessageRequest,
128
130
  "FeedbackRequest": FeedbackRequest,
129
131
  # Auth
130
132
  "AuthType": AuthType,
131
133
  "AuthenticationConfig": AuthenticationConfig,
132
- "CredentialsLoginRequest": CredentialsLoginRequest,
133
- "JWTToken": JWTToken,
134
+ "UserCredentials": UserCredentials,
134
135
  "LoginRequest": LoginRequest,
135
136
  "LoginResponse": LoginResponse,
136
- "LogoutRequest": LogoutRequest,
137
137
  "User": User,
138
138
  }
139
139
 
@@ -169,7 +169,6 @@ class RagbitsChatModelProvider:
169
169
  "ServerState",
170
170
  "FeedbackItem",
171
171
  "Image",
172
- "JWTToken",
173
172
  "User",
174
173
  "MessageUsage",
175
174
  "Task",
@@ -192,16 +191,19 @@ class RagbitsChatModelProvider:
192
191
  "UserSettings",
193
192
  "FeedbackConfig",
194
193
  "AuthenticationConfig",
194
+ "OAuth2ProviderConfig",
195
195
  ],
196
196
  "responses": [
197
197
  "FeedbackResponse",
198
198
  "ConfigResponse",
199
199
  "LoginResponse",
200
+ "OAuth2AuthorizeResponse",
200
201
  ],
201
202
  "requests": [
202
203
  "ChatRequest",
203
204
  "FeedbackRequest",
204
- "CredentialsLoginRequest",
205
+ "UserCredentials",
206
+ "OAuth2Credentials",
205
207
  "LoginRequest",
206
208
  "LogoutRequest",
207
209
  ],
@@ -0,0 +1 @@
1
// AuthGuard bundle (minified build output): gates the UI on the session-based auth state.
// Identifiers are minifier output; the comments describe what each one appears to be from its usage.
// NOTE(review): generated by the UI build — edit the source component, not this bundle.
import {
  r as h, // presumably React (createContext / useState / useEffect are used on it)
  j as t, // JSX runtime (jsx / jsxs)
  aW as i, // store hook called as i(store, selector) — presumably zustand's useStore
  aE as p, // returns a navigate function (called with a path)
  c as b, // API hook; b("/api/user").call() fetches the current user
  bo as f, // returns the current location (its .pathname is read)
  bp as l, // provider component configured with baseUrl and auth options
  bq as d, // base URL handed to l
  br as j, // rendered while auth state is hydrating — presumably a loading view
  bs as S // redirect component (props: to, replace)
} from "./index-BZLU40Mk.js";
// r: the auth store (logout, login, setHydrated, isAuthenticated, hasHydrated are read from it).
import { a as r } from "./authStore-BfGlL8rp.js";
// m: returns a function that initializes per-user state, called with the user's id.
import { u as m } from "./useInitializeUserStore-DyHP7g8x.js";

// Context that exposes the auth store to authenticated subtrees.
const v = h.createContext(null);

// Provides the auth store (captured once, on first render) through context v.
function y({ children: e }) {
  const [a] = h.useState(() => r);
  return t.jsx(v.Provider, { value: a, children: e });
}

// Render-nothing component that checks the session once on mount (empty deps array).
function A() {
  const { logout: e, login: a, setHydrated: c } = i(r, x=>x),
    u = m(),
    n = p(),
    s = b("/api/user"),
    g = f();
  return h.useEffect(() => {
    (async () => {
      try {
        const o = await s.call();
        // Truthy response: log the user in, initialize per-user state, and leave /login.
        // Otherwise: clear auth state.
        o ? (a(o), u(o.user_id), g.pathname === "/login" && n("/")) : e()
      } catch (o) {
        // Request failed: clear auth state and send the user to the login page.
        console.error("Failed to check session:", o), e(), n("/login")
      } finally {
        // Always mark hydration done so z stops rendering the loading view.
        c()
      }
    })()
  }, []), null
}

// Default export: route guard.
// - Before hydration: render A (session check) and j under l.
// - On /login: render children under l without an onUnauthorized handler.
// - Authenticated: render children inside y and an l whose onUnauthorized calls logout.
// - Otherwise: redirect to /login (replacing the history entry).
// Every l gets credentials:"include" — presumably forwarded to fetch so the session cookie is sent.
function z({ children: e }) {
  const a = f(),
    c = i(r, s=>s.isAuthenticated),
    u = i(r, s=>s.hasHydrated),
    n = i(r, s=>s.logout);
  return u
    ? a.pathname === "/login"
      ? t.jsx(l, { baseUrl: d, auth: { credentials: "include" }, children: e })
      : c
        ? t.jsx(y, { children: t.jsx(l, { baseUrl: d, auth: { onUnauthorized: n, credentials: "include" }, children: e }) })
        : t.jsx(S, { to: "/login", replace: !0 })
    : t.jsxs(l, { baseUrl: d, auth: { credentials: "include" }, children: [t.jsx(A, {}), t.jsx(j, {})] })
}
export { z as default };