letta-nightly 0.9.0.dev20250725104508__py3-none-any.whl → 0.9.1.dev20250727063635__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. letta/__init__.py +1 -1
  2. letta/agents/base_agent.py +1 -1
  3. letta/agents/letta_agent.py +6 -0
  4. letta/helpers/datetime_helpers.py +1 -1
  5. letta/helpers/json_helpers.py +1 -1
  6. letta/orm/agent.py +2 -3
  7. letta/orm/agents_tags.py +1 -0
  8. letta/orm/block.py +2 -2
  9. letta/orm/group.py +2 -2
  10. letta/orm/identity.py +3 -4
  11. letta/orm/mcp_oauth.py +62 -0
  12. letta/orm/step.py +2 -4
  13. letta/schemas/agent_file.py +31 -5
  14. letta/schemas/block.py +3 -0
  15. letta/schemas/enums.py +4 -0
  16. letta/schemas/group.py +3 -0
  17. letta/schemas/mcp.py +70 -0
  18. letta/schemas/memory.py +35 -0
  19. letta/schemas/message.py +98 -91
  20. letta/schemas/providers/openai.py +1 -1
  21. letta/server/rest_api/app.py +19 -21
  22. letta/server/rest_api/middleware/__init__.py +4 -0
  23. letta/server/rest_api/middleware/check_password.py +24 -0
  24. letta/server/rest_api/middleware/profiler_context.py +25 -0
  25. letta/server/rest_api/routers/v1/blocks.py +2 -0
  26. letta/server/rest_api/routers/v1/groups.py +1 -1
  27. letta/server/rest_api/routers/v1/sources.py +26 -0
  28. letta/server/rest_api/routers/v1/tools.py +224 -23
  29. letta/services/agent_manager.py +15 -9
  30. letta/services/agent_serialization_manager.py +84 -3
  31. letta/services/block_manager.py +4 -0
  32. letta/services/file_manager.py +23 -13
  33. letta/services/file_processor/file_processor.py +12 -10
  34. letta/services/mcp/base_client.py +20 -28
  35. letta/services/mcp/oauth_utils.py +433 -0
  36. letta/services/mcp/sse_client.py +12 -1
  37. letta/services/mcp/streamable_http_client.py +17 -5
  38. letta/services/mcp/types.py +9 -0
  39. letta/services/mcp_manager.py +304 -42
  40. letta/services/provider_manager.py +2 -2
  41. letta/services/tool_executor/tool_executor.py +6 -2
  42. letta/services/tool_manager.py +8 -4
  43. letta/services/tool_sandbox/base.py +3 -3
  44. letta/services/tool_sandbox/e2b_sandbox.py +1 -1
  45. letta/services/tool_sandbox/local_sandbox.py +16 -9
  46. letta/settings.py +11 -1
  47. letta/system.py +1 -1
  48. letta/templates/template_helper.py +25 -1
  49. letta/utils.py +19 -35
  50. {letta_nightly-0.9.0.dev20250725104508.dist-info → letta_nightly-0.9.1.dev20250727063635.dist-info}/METADATA +3 -2
  51. {letta_nightly-0.9.0.dev20250725104508.dist-info → letta_nightly-0.9.1.dev20250727063635.dist-info}/RECORD +54 -49
  52. {letta_nightly-0.9.0.dev20250725104508.dist-info → letta_nightly-0.9.1.dev20250727063635.dist-info}/LICENSE +0 -0
  53. {letta_nightly-0.9.0.dev20250725104508.dist-info → letta_nightly-0.9.1.dev20250727063635.dist-info}/WHEEL +0 -0
  54. {letta_nightly-0.9.0.dev20250725104508.dist-info → letta_nightly-0.9.1.dev20250727063635.dist-info}/entry_points.txt +0 -0
@@ -1,9 +1,9 @@
1
- import asyncio
2
1
  from contextlib import AsyncExitStack
3
2
  from typing import Optional, Tuple
4
3
 
5
4
  from mcp import ClientSession
6
5
  from mcp import Tool as MCPTool
6
+ from mcp.client.auth import OAuthClientProvider
7
7
  from mcp.types import TextContent
8
8
 
9
9
  from letta.functions.mcp_client.types import BaseServerConfig
@@ -14,14 +14,12 @@ logger = get_logger(__name__)
14
14
 
15
15
  # TODO: Get rid of Async prefix on this class name once we deprecate old sync code
16
16
  class AsyncBaseMCPClient:
17
- def __init__(self, server_config: BaseServerConfig):
17
+ def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
18
18
  self.server_config = server_config
19
+ self.oauth_provider = oauth_provider
19
20
  self.exit_stack = AsyncExitStack()
20
21
  self.session: Optional[ClientSession] = None
21
22
  self.initialized = False
22
- # Track the task that created this client
23
- self._creation_task = asyncio.current_task()
24
- self._cleanup_queue = asyncio.Queue(maxsize=1)
25
23
 
26
24
  async def connect_to_server(self):
27
25
  try:
@@ -48,9 +46,25 @@ class AsyncBaseMCPClient:
48
46
  async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
49
47
  raise NotImplementedError("Subclasses must implement _initialize_connection")
50
48
 
51
- async def list_tools(self) -> list[MCPTool]:
49
+ async def list_tools(self, serialize: bool = False) -> list[MCPTool]:
52
50
  self._check_initialized()
53
51
  response = await self.session.list_tools()
52
+ if serialize:
53
+ serializable_tools = []
54
+ for tool in response.tools:
55
+ if hasattr(tool, "model_dump"):
56
+ # Pydantic model - use model_dump
57
+ serializable_tools.append(tool.model_dump())
58
+ elif hasattr(tool, "dict"):
59
+ # Older Pydantic model - use dict()
60
+ serializable_tools.append(tool.dict())
61
+ elif hasattr(tool, "__dict__"):
62
+ # Regular object - use __dict__
63
+ serializable_tools.append(tool.__dict__)
64
+ else:
65
+ # Fallback - convert to string
66
+ serializable_tools.append(str(tool))
67
+ return serializable_tools
54
68
  return response.tools
55
69
 
56
70
  async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
@@ -79,29 +93,7 @@ class AsyncBaseMCPClient:
79
93
 
80
94
  # TODO: still hitting some async errors for voice agents, need to fix
81
95
  async def cleanup(self):
82
- """Clean up resources - ensure this runs in the same task"""
83
- if hasattr(self, "_cleanup_task"):
84
- # If we're in a different task, schedule cleanup in original task
85
- current_task = asyncio.current_task()
86
- if current_task != self._creation_task:
87
- # Create a future to signal completion
88
- cleanup_done = asyncio.Future()
89
- self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
90
- await cleanup_done
91
- return
92
-
93
- # Normal cleanup
94
96
  await self.exit_stack.aclose()
95
97
 
96
98
  def to_sync_client(self):
97
99
  raise NotImplementedError("Subclasses must implement to_sync_client")
98
-
99
- async def __aenter__(self):
100
- """Enter the async context manager."""
101
- await self.connect_to_server()
102
- return self
103
-
104
- async def __aexit__(self, exc_type, exc_val, exc_tb):
105
- """Exit the async context manager."""
106
- await self.cleanup()
107
- return False # Don't suppress exceptions
@@ -0,0 +1,433 @@
1
+ """OAuth utilities for MCP server authentication."""
2
+
3
+ import asyncio
4
+ import json
5
+ import secrets
6
+ import time
7
+ import uuid
8
+ from datetime import datetime, timedelta
9
+ from typing import Callable, Optional, Tuple
10
+
11
+ from mcp.client.auth import OAuthClientProvider, TokenStorage
12
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
13
+ from sqlalchemy import select
14
+
15
+ from letta.log import get_logger
16
+ from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
17
+ from letta.schemas.mcp import MCPOAuthSessionUpdate
18
+ from letta.schemas.user import User as PydanticUser
19
+ from letta.server.db import db_registry
20
+ from letta.services.mcp.types import OauthStreamEvent
21
+ from letta.services.mcp_manager import MCPManager
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class DatabaseTokenStorage(TokenStorage):
    """Database-backed token storage using MCPOAuth table via mcp_manager.

    Implements the MCP SDK ``TokenStorage`` interface on top of a single OAuth
    session row (``session_id``), read and written through ``MCPManager`` on
    behalf of ``actor``.
    """

    def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser):
        self.session_id = session_id
        self.mcp_manager = mcp_manager
        self.actor = actor

    async def get_tokens(self) -> Optional[OAuthToken]:
        """Retrieve tokens from the database.

        Returns:
            The stored token set, or ``None`` when the session does not exist or has
            no access token yet. ``expires_in`` is recomputed relative to now (it can
            be negative for an already-expired token) and is ``None`` when no expiry
            is stored.
        """
        oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
        if not oauth_session or not oauth_session.access_token:
            return None

        # A token issued without ``expires_in`` is stored with no expiry; calling
        # ``.timestamp()`` on None would raise, so report an unknown lifetime instead.
        expires_in = None
        if oauth_session.expires_at is not None:
            expires_in = int(oauth_session.expires_at.timestamp() - time.time())

        return OAuthToken(
            access_token=oauth_session.access_token,
            refresh_token=oauth_session.refresh_token,
            token_type=oauth_session.token_type,
            expires_in=expires_in,
            scope=oauth_session.scope,
        )

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store tokens in the database and mark the session AUTHORIZED.

        ``expires_in`` (seconds from now) is converted to an absolute ``expires_at``.
        It is stored as ``None`` when the authorization server omitted it.
        """
        # Naive local datetime, which round-trips with the ``.timestamp()`` call in get_tokens.
        expires_at = None if tokens.expires_in is None else datetime.fromtimestamp(tokens.expires_in + time.time())
        session_update = MCPOAuthSessionUpdate(
            access_token=tokens.access_token,
            refresh_token=tokens.refresh_token,
            token_type=tokens.token_type,
            expires_at=expires_at,
            scope=tokens.scope,
            status=OAuthSessionStatus.AUTHORIZED,
        )
        await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)

    async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
        """Retrieve the registered OAuth client information from the database.

        Returns:
            The client info, or ``None`` when the session does not exist or no
            client has been registered (no ``client_id`` stored).
        """
        oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
        if not oauth_session or not oauth_session.client_id:
            return None

        # NOTE(review): only the first redirect URI is persisted (see set_client_info);
        # confirm OAuthClientInformationFull accepts an empty list when none is stored.
        return OAuthClientInformationFull(
            client_id=oauth_session.client_id,
            client_secret=oauth_session.client_secret,
            redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [],
        )

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store client information (id, secret, first redirect URI) in the database."""
        session_update = MCPOAuthSessionUpdate(
            client_id=client_info.client_id,
            client_secret=client_info.client_secret,
            redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None,
        )
        await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
80
+
81
+
82
class MCPOAuthSession:
    """Legacy OAuth session class - deprecated, use mcp_manager directly.

    Wraps the MCPOAuth row whose id is ``session_id``. Each helper opens its own
    database session. On a failed read or write, read helpers return a fallback
    value and update helpers log a warning without raising.
    """

    def __init__(
        self,
        session_id: Optional[str] = None,
        *,
        server_url: Optional[str] = None,
        server_name: Optional[str] = None,
        user_id: Optional[str] = None,
        organization_id: Optional[str] = None,
    ):
        """Bind to an existing session, or describe a new one for ``create_session``.

        Args:
            session_id: Id of an existing MCPOAuth row. When omitted, a fresh UUID is
                generated (use this together with the keyword fields to create a row).
            server_url: MCP server URL; only used by ``create_session``.
            server_name: MCP server name; only used by ``create_session``.
            user_id: Owning user id, if any; only used by ``create_session``.
            organization_id: Owning organization id; only used by ``create_session``.
        """
        # One constructor serves both uses. With two separate ``__init__`` definitions,
        # the later one silently replaces the earlier, leaving ``create_session``
        # without ``state`` or the server/user fields.
        self.session_id = session_id if session_id is not None else str(uuid.uuid4())
        self.server_url = server_url
        self.server_name = server_name
        self.user_id = user_id
        self.organization_id = organization_id
        self.state = secrets.token_urlsafe(32)

    # TODO: consolidate / deprecate this in favor of mcp_manager access
    async def create_session(self) -> str:
        """Insert a new PENDING OAuth session row and return its id."""
        async with db_registry.async_session() as session:
            oauth_record = MCPOAuth(
                id=self.session_id,
                state=self.state,
                server_url=self.server_url,
                server_name=self.server_name,
                user_id=self.user_id,
                organization_id=self.organization_id,
                status=OAuthSessionStatus.PENDING,
                created_at=datetime.now(),
                updated_at=datetime.now(),
            )
            await oauth_record.create_async(session, actor=None)

        return self.session_id

    async def get_session_status(self) -> OAuthSessionStatus:
        """Get the current status of the OAuth session (ERROR if it cannot be read)."""
        async with db_registry.async_session() as session:
            try:
                oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
                return oauth_record.status
            except Exception as e:
                logger.warning(f"Failed to read status of OAuth session {self.session_id}: {e}")
                return OAuthSessionStatus.ERROR

    async def update_session_status(self, status: OAuthSessionStatus) -> None:
        """Update the session status (best-effort; failures are logged, not raised)."""
        async with db_registry.async_session() as session:
            try:
                oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
                oauth_record.status = status
                oauth_record.updated_at = datetime.now()
                await oauth_record.update_async(db_session=session, actor=None)
            except Exception as e:
                logger.warning(f"Failed to update status of OAuth session {self.session_id}: {e}")

    async def store_authorization_code(self, code: str, state: str) -> bool:
        """Store the authorization code and state from the OAuth callback.

        Returns:
            True if the row was updated, False if it could not be read or written.
        """
        async with db_registry.async_session() as session:
            try:
                oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
                # The callback's state is stored as-is rather than compared here:
                # create_oauth_provider's callback_handler returns it alongside the code,
                # presumably so OAuthClientProvider can check it against the state it
                # generated (verify against the MCP SDK).
                oauth_record.authorization_code = code
                oauth_record.state = state
                oauth_record.status = OAuthSessionStatus.AUTHORIZED
                oauth_record.updated_at = datetime.now()
                await oauth_record.update_async(db_session=session, actor=None)
                return True
            except Exception as e:
                logger.warning(f"Failed to store authorization code for OAuth session {self.session_id}: {e}")
                return False

    async def get_authorization_url(self) -> Optional[str]:
        """Get the authorization URL for this session, or None if it cannot be read."""
        async with db_registry.async_session() as session:
            try:
                oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
                return oauth_record.authorization_url
            except Exception as e:
                logger.warning(f"Failed to read authorization URL of OAuth session {self.session_id}: {e}")
                return None

    async def set_authorization_url(self, url: str) -> None:
        """Set the authorization URL for this session (best-effort; failures are logged)."""
        async with db_registry.async_session() as session:
            try:
                oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
                oauth_record.authorization_url = url
                oauth_record.updated_at = datetime.now()
                await oauth_record.update_async(db_session=session, actor=None)
            except Exception as e:
                logger.warning(f"Failed to set authorization URL of OAuth session {self.session_id}: {e}")
172
+
173
+
174
async def create_oauth_provider(
    session_id: str,
    server_url: str,
    redirect_uri: str,
    mcp_manager: MCPManager,
    actor: PydanticUser,
    url_callback: Optional[Callable[[str], None]] = None,
    callback_timeout: float = 300,
    poll_interval: float = 1.0,
) -> OAuthClientProvider:
    """Create an OAuth provider for MCP server authentication.

    Args:
        session_id: Id of the OAuth session row used to persist tokens, client info,
            the authorization URL and the authorization code.
        server_url: MCP server endpoint URL. Trailing slashes and a trailing ``/sse``
            or ``/mcp`` segment are stripped to get the URL passed to the provider.
        redirect_uri: Redirect URI registered in the client metadata.
        mcp_manager: Manager used for every session read/write.
        actor: User on whose behalf the session is accessed.
        url_callback: Optional synchronous hook called with the authorization URL
            after it has been stored (e.g. to yield it to an SSE stream).
        callback_timeout: Seconds to wait for the authorization code before failing.
        poll_interval: Seconds between database polls while waiting for the code.

    Returns:
        An ``OAuthClientProvider`` wired to database-backed storage and handlers.
    """

    client_metadata_dict = {
        "client_name": "Letta MCP Client",
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "client_secret_post",
    }

    # Use manager-based storage
    storage = DatabaseTokenStorage(session_id, mcp_manager, actor)

    # Extract base URL (strip trailing slashes and a trailing /sse or /mcp transport endpoint)
    oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp")

    async def redirect_handler(authorization_url: str) -> None:
        """Handle OAuth redirect by storing the authorization URL."""
        logger.info(f"OAuth redirect handler called with URL: {authorization_url}")
        session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url)
        await mcp_manager.update_oauth_session(session_id, session_update, actor)
        logger.info(f"OAuth authorization URL stored: {authorization_url}")

        # Call the callback if provided (e.g., to yield URL to SSE stream)
        if url_callback:
            url_callback(authorization_url)

    async def callback_handler() -> Tuple[str, Optional[str]]:
        """Poll the session row until the OAuth callback stores an authorization code.

        Returns:
            ``(authorization_code, state)`` as stored on the session.

        Raises:
            Exception: if the session is marked ERROR or the timeout elapses.
        """
        logger.info(f"Waiting for authorization code for session {session_id}")
        # Monotonic clock: system clock adjustments cannot shorten or extend the wait.
        deadline = time.monotonic() + callback_timeout
        while time.monotonic() < deadline:
            oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
            if oauth_session and oauth_session.authorization_code:
                return oauth_session.authorization_code, oauth_session.state
            elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
                raise Exception("OAuth authorization failed")
            await asyncio.sleep(poll_interval)

        raise Exception(f"Timeout waiting for OAuth callback after {callback_timeout} seconds")

    return OAuthClientProvider(
        server_url=oauth_server_url,
        client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
        storage=storage,
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
    )
232
+
233
+
234
async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None:
    """Hard-delete every OAuth session row created more than ``max_age_hours`` ago.

    The age threshold is computed from ``datetime.now()`` (naive local time) and
    compared against ``created_at``. Matching rows are deleted one by one within a
    single database session. The number of deleted rows is logged when non-zero.
    """
    oldest_allowed = datetime.now() - timedelta(hours=max_age_hours)

    async with db_registry.async_session() as db_session:
        stale_query = select(MCPOAuth).where(MCPOAuth.created_at < oldest_allowed)
        stale_rows = (await db_session.execute(stale_query)).scalars().all()

        for row in stale_rows:
            await row.hard_delete_async(db_session=db_session, actor=None)

        if stale_rows:
            logger.info(f"Cleaned up {len(stale_rows)} expired OAuth sessions")
247
+
248
+
249
def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str:
    """Serialize an OAuth progress event as one Server-Sent Events ``data:`` frame.

    The JSON payload starts with ``"event"`` (the enum member's value) followed by
    any extra keyword fields in call order. The frame ends with a blank line.
    """
    payload = {"event": event.value, **kwargs}
    return "data: " + json.dumps(payload) + "\n\n"
253
+
254
+
255
def drill_down_exception(exception: BaseException, depth: int = 0, max_depth: int = 5) -> str:
    """Render an exception and its nested exceptions as a multi-line diagnostic report.

    For each exception the report lists its type, message and module. It then
    recurses into exception-group members (``.exceptions``), the explicit cause
    (``__cause__``) and the implicit context (``__context__``, unless suppressed
    or identical to the cause). It ends with the last three traceback entries.
    Nesting is shown by indenting one space per depth level.

    Args:
        exception: The exception to describe.
        depth: Current nesting level (used for indentation and recursion limit).
        max_depth: Deepest level that is still expanded.

    Returns:
        The report as a single newline-separated string.
    """
    import traceback

    indent = " " * depth
    error_details = [
        f"{indent}Exception at depth {depth}:",
        f"{indent} Type: {type(exception).__name__}",
        f"{indent} Message: {str(exception)}",
        f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}",
    ]

    # Check for exception groups (TaskGroup errors)
    sub_exceptions = getattr(exception, "exceptions", None)
    if sub_exceptions:
        error_details.append(f"{indent} ExceptionGroup with {len(sub_exceptions)} sub-exceptions:")
        for i, sub_exc in enumerate(sub_exceptions):
            error_details.append(f"{indent} Sub-exception {i}:")
            if depth < max_depth:
                error_details.append(drill_down_exception(sub_exc, depth + 1, max_depth))

    # Explicit chaining: ``raise X from cause``
    cause = getattr(exception, "__cause__", None)
    if cause is not None and depth < max_depth:
        error_details.append(f"{indent} Caused by:")
        error_details.append(drill_down_exception(cause, depth + 1, max_depth))

    # Implicit chaining. Like Python's own traceback display, skip it when suppressed
    # (``raise ... from ...``); otherwise the cause would be reported twice.
    context = getattr(exception, "__context__", None)
    suppressed = getattr(exception, "__suppress_context__", False)
    if context is not None and context is not cause and not suppressed and depth < max_depth:
        error_details.append(f"{indent} Context:")
        error_details.append(drill_down_exception(context, depth + 1, max_depth))

    # Add traceback info
    tb = getattr(exception, "__traceback__", None)
    if tb:
        error_details.append(f"{indent} Traceback:")
        for line in traceback.format_tb(tb)[-3:]:  # Show last 3 traceback entries
            error_details.append(f"{indent} {line.strip()}")

    return "\n".join(error_details)
293
+
294
+
295
def get_oauth_success_html() -> str:
    """Generate HTML for successful OAuth authorization.

    The page is self-contained. It includes inline CSS with a dark-mode variant
    via ``prefers-color-scheme``, an inline SVG logo, and SVG background images
    embedded as data URLs (each kept on a single line, because a raw newline
    inside a CSS string invalidates the declaration).
    """
    # The heading is an <h1> so the ``h1`` rules in the stylesheet (size, weight and
    # the dark-mode text color) actually apply to it.
    return """
<!DOCTYPE html>
<html>
<head>
    <title>Authorization Successful - Letta</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            margin: 0;
            background-color: #f5f5f5;
            background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14823_146864)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14823_146864'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
            background-size: cover;
            background-position: center;
            background-repeat: no-repeat;
        }

        .card {
            text-align: center;
            padding: 48px;
            background: white;
            border-radius: 8px;
            border: 1px solid #E1E2E3;
            max-width: 400px;
            width: 90%;
            position: relative;
            z-index: 1;
        }

        .logo {
            width: 48px;
            height: 48px;
            margin: 0 auto 24px;
            display: block;
        }

        .logo svg {
            width: 100%;
            height: 100%;
        }

        h1 {
            font-size: 20px;
            font-weight: 600;
            color: #101010;
            margin-bottom: 12px;
            line-height: 1.2;
        }

        .subtitle {
            color: #666;
            font-size: 12px;
            margin-top: 10px;
            margin-bottom: 24px;
            line-height: 1.5;
        }

        .close-info {
            font-size: 12px;
            color: #999;
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 8px;
        }

        .spinner {
            width: 16px;
            height: 16px;
            border: 2px solid #E1E2E3;
            border-top: 2px solid #333;
            border-radius: 50%;
            animation: spin 1s linear infinite;
        }

        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }

        /* Dark mode styles */
        @media (prefers-color-scheme: dark) {
            body {
                background-color: #101010;
                background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14833_149362)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14833_149362'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
            }

            .card {
                background-color: #141414;
                border-color: #202020;
            }

            h1 {
                color: #E1E2E3;
            }

            .subtitle {
                color: #999;
            }

            .logo svg path {
                fill: #E1E2E3;
            }

            .spinner {
                border-color: #46484A;
                border-top-color: #E1E2E3;
            }
        }
    </style>
</head>
<body>
    <div class="card">
        <div class="logo">
            <svg width="48" height="48" viewBox="0 0 18 18" fill="none" xmlns="http://www.w3.org/2000/svg">
                <path d="M10.7134 7.30028H7.28759V10.7002H10.7134V7.30028Z" fill="#333"/>
                <path d="M14.1391 2.81618V0.5H3.86131V2.81618C3.86131 3.41495 3.37266 3.89991 2.76935 3.89991H0.435547V14.1001H2.76935C3.37266 14.1001 3.86131 14.5851 3.86131 15.1838V17.5H14.1391V15.1838C14.1391 14.5851 14.6277 14.1001 15.231 14.1001H17.5648V3.89991H15.231C14.6277 3.89991 14.1391 3.41495 14.1391 2.81618ZM14.1391 13.0159C14.1391 13.6147 13.6504 14.0996 13.0471 14.0996H4.95375C4.35043 14.0996 3.86179 13.6147 3.86179 13.0159V4.98363C3.86179 4.38486 4.35043 3.89991 4.95375 3.89991H13.0471C13.6504 3.89991 14.1391 4.38486 14.1391 4.98363V13.0159Z" fill="#333"/>
            </svg>
        </div>
        <h1>Authorization Successful</h1>
        <p class="subtitle">You have successfully connected your MCP server.</p>
        <div class="close-info">
            <span>You can now close this window.</span>
        </div>
    </div>
</body>
</html>
"""
@@ -1,4 +1,7 @@
1
+ from typing import Optional
2
+
1
3
  from mcp import ClientSession
4
+ from mcp.client.auth import OAuthClientProvider
2
5
  from mcp.client.sse import sse_client
3
6
 
4
7
  from letta.functions.mcp_client.types import SSEServerConfig
@@ -13,6 +16,9 @@ logger = get_logger(__name__)
13
16
 
14
17
  # TODO: Get rid of Async prefix on this class name once we deprecate old sync code
15
18
  class AsyncSSEMCPClient(AsyncBaseMCPClient):
19
+ def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
20
+ super().__init__(server_config, oauth_provider)
21
+
16
22
  async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
17
23
  headers = {}
18
24
  if server_config.custom_headers:
@@ -21,7 +27,12 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
21
27
  if server_config.auth_header and server_config.auth_token:
22
28
  headers[server_config.auth_header] = server_config.auth_token
23
29
 
24
- sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
30
+ # Use OAuth provider if available, otherwise use regular headers
31
+ if self.oauth_provider:
32
+ sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)
33
+ else:
34
+ sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
35
+
25
36
  sse_transport = await self.exit_stack.enter_async_context(sse_cm)
26
37
  self.stdio, self.write = sse_transport
27
38
 
@@ -1,4 +1,7 @@
1
+ from typing import Optional
2
+
1
3
  from mcp import ClientSession
4
+ from mcp.client.auth import OAuthClientProvider
2
5
  from mcp.client.streamable_http import streamablehttp_client
3
6
 
4
7
  from letta.functions.mcp_client.types import BaseServerConfig, StreamableHTTPServerConfig
@@ -9,10 +12,12 @@ logger = get_logger(__name__)
9
12
 
10
13
 
11
14
  class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
15
+ def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
16
+ super().__init__(server_config, oauth_provider)
17
+
12
18
  async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
13
19
  if not isinstance(server_config, StreamableHTTPServerConfig):
14
20
  raise ValueError("Expected StreamableHTTPServerConfig")
15
-
16
21
  try:
17
22
  # Prepare headers for authentication
18
23
  headers = {}
@@ -23,11 +28,18 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
23
28
  if server_config.auth_header and server_config.auth_token:
24
29
  headers[server_config.auth_header] = server_config.auth_token
25
30
 
26
- # Use streamablehttp_client context manager with headers if provided
27
- if headers:
28
- streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
31
+ # Use OAuth provider if available, otherwise use regular headers
32
+ if self.oauth_provider:
33
+ streamable_http_cm = streamablehttp_client(
34
+ server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider
35
+ )
29
36
  else:
30
- streamable_http_cm = streamablehttp_client(server_config.server_url)
37
+ # Use streamablehttp_client context manager with headers if provided
38
+ if headers:
39
+ streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
40
+ else:
41
+ streamable_http_cm = streamablehttp_client(server_config.server_url)
42
+
31
43
  read_stream, write_stream, _ = await self.exit_stack.enter_async_context(streamable_http_cm)
32
44
 
33
45
  # Create and enter the ClientSession context manager
@@ -46,3 +46,12 @@ class StdioServerConfig(BaseServerConfig):
46
46
  if self.env is not None:
47
47
  values["env"] = self.env
48
48
  return values
49
+
50
+
51
class OauthStreamEvent(str, Enum):
    """Event types for the MCP OAuth connection-progress stream.

    The ``str`` mixin makes every member equal to its string value
    (for example ``OauthStreamEvent.SUCCESS == "success"``).
    """

    CONNECTION_ATTEMPT = "connection_attempt"
    SUCCESS = "success"
    ERROR = "error"
    OAUTH_REQUIRED = "oauth_required"
    AUTHORIZATION_URL = "authorization_url"
    WAITING_FOR_AUTH = "waiting_for_auth"