letta-nightly 0.12.1.dev20251024104217__py3-none-any.whl → 0.13.0.dev20251025104015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +2 -3
- letta/adapters/letta_llm_adapter.py +1 -0
- letta/adapters/simple_llm_request_adapter.py +8 -5
- letta/adapters/simple_llm_stream_adapter.py +22 -6
- letta/agents/agent_loop.py +10 -3
- letta/agents/base_agent.py +4 -1
- letta/agents/helpers.py +41 -9
- letta/agents/letta_agent.py +11 -10
- letta/agents/letta_agent_v2.py +47 -37
- letta/agents/letta_agent_v3.py +395 -300
- letta/agents/voice_agent.py +8 -6
- letta/agents/voice_sleeptime_agent.py +3 -3
- letta/constants.py +30 -7
- letta/errors.py +20 -0
- letta/functions/function_sets/base.py +55 -3
- letta/functions/mcp_client/types.py +33 -57
- letta/functions/schema_generator.py +135 -23
- letta/groups/sleeptime_multi_agent_v3.py +6 -11
- letta/groups/sleeptime_multi_agent_v4.py +227 -0
- letta/helpers/converters.py +78 -4
- letta/helpers/crypto_utils.py +6 -2
- letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +9 -11
- letta/interfaces/anthropic_streaming_interface.py +3 -4
- letta/interfaces/gemini_streaming_interface.py +4 -6
- letta/interfaces/openai_streaming_interface.py +63 -28
- letta/llm_api/anthropic_client.py +7 -4
- letta/llm_api/deepseek_client.py +6 -4
- letta/llm_api/google_ai_client.py +3 -12
- letta/llm_api/google_vertex_client.py +1 -1
- letta/llm_api/helpers.py +90 -61
- letta/llm_api/llm_api_tools.py +4 -1
- letta/llm_api/openai.py +12 -12
- letta/llm_api/openai_client.py +53 -16
- letta/local_llm/constants.py +4 -3
- letta/local_llm/json_parser.py +5 -2
- letta/local_llm/utils.py +2 -3
- letta/log.py +171 -7
- letta/orm/agent.py +43 -9
- letta/orm/archive.py +4 -0
- letta/orm/custom_columns.py +15 -0
- letta/orm/identity.py +11 -11
- letta/orm/mcp_server.py +9 -0
- letta/orm/message.py +6 -1
- letta/orm/run_metrics.py +7 -2
- letta/orm/sqlalchemy_base.py +2 -2
- letta/orm/tool.py +3 -0
- letta/otel/tracing.py +2 -0
- letta/prompts/prompt_generator.py +7 -2
- letta/schemas/agent.py +41 -10
- letta/schemas/agent_file.py +3 -0
- letta/schemas/archive.py +4 -2
- letta/schemas/block.py +2 -1
- letta/schemas/enums.py +36 -3
- letta/schemas/file.py +3 -3
- letta/schemas/folder.py +2 -1
- letta/schemas/group.py +2 -1
- letta/schemas/identity.py +18 -9
- letta/schemas/job.py +3 -1
- letta/schemas/letta_message.py +71 -12
- letta/schemas/letta_request.py +7 -3
- letta/schemas/letta_stop_reason.py +0 -25
- letta/schemas/llm_config.py +8 -2
- letta/schemas/mcp.py +80 -83
- letta/schemas/mcp_server.py +349 -0
- letta/schemas/memory.py +20 -8
- letta/schemas/message.py +212 -67
- letta/schemas/providers/anthropic.py +13 -6
- letta/schemas/providers/azure.py +6 -4
- letta/schemas/providers/base.py +8 -4
- letta/schemas/providers/bedrock.py +6 -2
- letta/schemas/providers/cerebras.py +7 -3
- letta/schemas/providers/deepseek.py +2 -1
- letta/schemas/providers/google_gemini.py +15 -6
- letta/schemas/providers/groq.py +2 -1
- letta/schemas/providers/lmstudio.py +9 -6
- letta/schemas/providers/mistral.py +2 -1
- letta/schemas/providers/openai.py +7 -2
- letta/schemas/providers/together.py +9 -3
- letta/schemas/providers/xai.py +7 -3
- letta/schemas/run.py +7 -2
- letta/schemas/run_metrics.py +2 -1
- letta/schemas/sandbox_config.py +2 -2
- letta/schemas/secret.py +3 -158
- letta/schemas/source.py +2 -2
- letta/schemas/step.py +2 -2
- letta/schemas/tool.py +24 -1
- letta/schemas/usage.py +0 -1
- letta/server/rest_api/app.py +123 -7
- letta/server/rest_api/dependencies.py +3 -0
- letta/server/rest_api/interface.py +7 -4
- letta/server/rest_api/redis_stream_manager.py +16 -1
- letta/server/rest_api/routers/v1/__init__.py +7 -0
- letta/server/rest_api/routers/v1/agents.py +332 -322
- letta/server/rest_api/routers/v1/archives.py +127 -40
- letta/server/rest_api/routers/v1/blocks.py +54 -6
- letta/server/rest_api/routers/v1/chat_completions.py +146 -0
- letta/server/rest_api/routers/v1/folders.py +27 -35
- letta/server/rest_api/routers/v1/groups.py +23 -35
- letta/server/rest_api/routers/v1/identities.py +24 -10
- letta/server/rest_api/routers/v1/internal_runs.py +107 -0
- letta/server/rest_api/routers/v1/internal_templates.py +162 -179
- letta/server/rest_api/routers/v1/jobs.py +15 -27
- letta/server/rest_api/routers/v1/mcp_servers.py +309 -0
- letta/server/rest_api/routers/v1/messages.py +23 -34
- letta/server/rest_api/routers/v1/organizations.py +6 -27
- letta/server/rest_api/routers/v1/providers.py +35 -62
- letta/server/rest_api/routers/v1/runs.py +30 -43
- letta/server/rest_api/routers/v1/sandbox_configs.py +6 -4
- letta/server/rest_api/routers/v1/sources.py +26 -42
- letta/server/rest_api/routers/v1/steps.py +16 -29
- letta/server/rest_api/routers/v1/tools.py +17 -13
- letta/server/rest_api/routers/v1/users.py +5 -17
- letta/server/rest_api/routers/v1/voice.py +18 -27
- letta/server/rest_api/streaming_response.py +5 -2
- letta/server/rest_api/utils.py +187 -25
- letta/server/server.py +27 -22
- letta/server/ws_api/server.py +5 -4
- letta/services/agent_manager.py +148 -26
- letta/services/agent_serialization_manager.py +6 -1
- letta/services/archive_manager.py +168 -15
- letta/services/block_manager.py +14 -4
- letta/services/file_manager.py +33 -29
- letta/services/group_manager.py +10 -0
- letta/services/helpers/agent_manager_helper.py +65 -11
- letta/services/identity_manager.py +105 -4
- letta/services/job_manager.py +11 -1
- letta/services/mcp/base_client.py +2 -2
- letta/services/mcp/oauth_utils.py +33 -8
- letta/services/mcp_manager.py +174 -78
- letta/services/mcp_server_manager.py +1331 -0
- letta/services/message_manager.py +109 -4
- letta/services/organization_manager.py +4 -4
- letta/services/passage_manager.py +9 -25
- letta/services/provider_manager.py +91 -15
- letta/services/run_manager.py +72 -15
- letta/services/sandbox_config_manager.py +45 -3
- letta/services/source_manager.py +15 -8
- letta/services/step_manager.py +24 -1
- letta/services/streaming_service.py +581 -0
- letta/services/summarizer/summarizer.py +1 -1
- letta/services/tool_executor/core_tool_executor.py +111 -0
- letta/services/tool_executor/files_tool_executor.py +5 -3
- letta/services/tool_executor/sandbox_tool_executor.py +2 -2
- letta/services/tool_executor/tool_execution_manager.py +1 -1
- letta/services/tool_manager.py +10 -3
- letta/services/tool_sandbox/base.py +61 -1
- letta/services/tool_sandbox/local_sandbox.py +1 -3
- letta/services/user_manager.py +2 -2
- letta/settings.py +49 -5
- letta/system.py +14 -5
- letta/utils.py +73 -1
- letta/validators.py +105 -0
- {letta_nightly-0.12.1.dev20251024104217.dist-info → letta_nightly-0.13.0.dev20251025104015.dist-info}/METADATA +4 -2
- {letta_nightly-0.12.1.dev20251024104217.dist-info → letta_nightly-0.13.0.dev20251025104015.dist-info}/RECORD +157 -151
- letta/schemas/letta_ping.py +0 -28
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- {letta_nightly-0.12.1.dev20251024104217.dist-info → letta_nightly-0.13.0.dev20251025104015.dist-info}/WHEEL +0 -0
- {letta_nightly-0.12.1.dev20251024104217.dist-info → letta_nightly-0.13.0.dev20251025104015.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.12.1.dev20251024104217.dist-info → letta_nightly-0.13.0.dev20251025104015.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1331 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import secrets
|
|
4
|
+
import uuid
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
from fastapi import HTTPException
|
|
9
|
+
from sqlalchemy import delete, desc, null, select
|
|
10
|
+
from starlette.requests import Request
|
|
11
|
+
|
|
12
|
+
import letta.constants as constants
|
|
13
|
+
from letta.functions.mcp_client.types import (
|
|
14
|
+
MCPServerType,
|
|
15
|
+
MCPTool,
|
|
16
|
+
MCPToolHealth,
|
|
17
|
+
SSEServerConfig,
|
|
18
|
+
StdioServerConfig,
|
|
19
|
+
StreamableHTTPServerConfig,
|
|
20
|
+
)
|
|
21
|
+
from letta.functions.schema_generator import normalize_mcp_schema
|
|
22
|
+
from letta.functions.schema_validator import validate_complete_json_schema
|
|
23
|
+
from letta.log import get_logger
|
|
24
|
+
from letta.orm.errors import NoResultFound
|
|
25
|
+
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
|
26
|
+
from letta.orm.mcp_server import MCPServer as MCPServerModel, MCPTools as MCPToolsModel
|
|
27
|
+
from letta.orm.tool import Tool as ToolModel
|
|
28
|
+
from letta.schemas.mcp import (
|
|
29
|
+
MCPOAuthSession,
|
|
30
|
+
MCPOAuthSessionCreate,
|
|
31
|
+
MCPOAuthSessionUpdate,
|
|
32
|
+
MCPServer,
|
|
33
|
+
MCPServerResyncResult,
|
|
34
|
+
UpdateMCPServer,
|
|
35
|
+
UpdateSSEMCPServer,
|
|
36
|
+
UpdateStdioMCPServer,
|
|
37
|
+
UpdateStreamableHTTPMCPServer,
|
|
38
|
+
)
|
|
39
|
+
from letta.schemas.secret import Secret
|
|
40
|
+
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
|
41
|
+
from letta.schemas.user import User as PydanticUser
|
|
42
|
+
from letta.server.db import db_registry
|
|
43
|
+
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
|
44
|
+
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
|
45
|
+
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
|
46
|
+
from letta.services.tool_manager import ToolManager
|
|
47
|
+
from letta.settings import settings, tool_settings
|
|
48
|
+
from letta.utils import enforce_types, printd, safe_create_task
|
|
49
|
+
|
|
50
|
+
logger = get_logger(__name__)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class MCPServerManager:
|
|
54
|
+
"""Manager class to handle business logic related to MCP."""
|
|
55
|
+
|
|
56
|
+
def __init__(self):
|
|
57
|
+
# TODO: timeouts?
|
|
58
|
+
self.tool_manager = ToolManager()
|
|
59
|
+
self.cached_mcp_servers = {} # maps id -> async connection
|
|
60
|
+
|
|
61
|
+
# MCPTools mapping table management methods
|
|
62
|
+
@enforce_types
|
|
63
|
+
async def create_mcp_tool_mapping(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> None:
|
|
64
|
+
"""Create a mapping between an MCP server and a tool."""
|
|
65
|
+
async with db_registry.async_session() as session:
|
|
66
|
+
mapping = MCPToolsModel(
|
|
67
|
+
id=f"mcp-tool-mapping-{uuid.uuid4()}",
|
|
68
|
+
mcp_server_id=mcp_server_id,
|
|
69
|
+
tool_id=tool_id,
|
|
70
|
+
organization_id=actor.organization_id,
|
|
71
|
+
)
|
|
72
|
+
await mapping.create_async(session, actor=actor)
|
|
73
|
+
|
|
74
|
+
@enforce_types
|
|
75
|
+
async def delete_mcp_tool_mappings_by_server(self, mcp_server_id: str, actor: PydanticUser) -> None:
|
|
76
|
+
"""Delete all tool mappings for a specific MCP server."""
|
|
77
|
+
async with db_registry.async_session() as session:
|
|
78
|
+
await session.execute(
|
|
79
|
+
delete(MCPToolsModel).where(
|
|
80
|
+
MCPToolsModel.mcp_server_id == mcp_server_id,
|
|
81
|
+
MCPToolsModel.organization_id == actor.organization_id,
|
|
82
|
+
)
|
|
83
|
+
)
|
|
84
|
+
await session.commit()
|
|
85
|
+
|
|
86
|
+
@enforce_types
|
|
87
|
+
async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]:
|
|
88
|
+
"""Get all tool IDs associated with an MCP server."""
|
|
89
|
+
async with db_registry.async_session() as session:
|
|
90
|
+
result = await session.execute(
|
|
91
|
+
select(MCPToolsModel.tool_id).where(
|
|
92
|
+
MCPToolsModel.mcp_server_id == mcp_server_id,
|
|
93
|
+
MCPToolsModel.organization_id == actor.organization_id,
|
|
94
|
+
)
|
|
95
|
+
)
|
|
96
|
+
return [row[0] for row in result.fetchall()]
|
|
97
|
+
|
|
98
|
+
@enforce_types
|
|
99
|
+
async def get_mcp_server_id_by_tool(self, tool_id: str, actor: PydanticUser) -> Optional[str]:
|
|
100
|
+
"""Get the MCP server ID associated with a tool."""
|
|
101
|
+
async with db_registry.async_session() as session:
|
|
102
|
+
result = await session.execute(
|
|
103
|
+
select(MCPToolsModel.mcp_server_id).where(
|
|
104
|
+
MCPToolsModel.tool_id == tool_id,
|
|
105
|
+
MCPToolsModel.organization_id == actor.organization_id,
|
|
106
|
+
)
|
|
107
|
+
)
|
|
108
|
+
row = result.fetchone()
|
|
109
|
+
return row[0] if row else None
|
|
110
|
+
|
|
111
|
+
@enforce_types
|
|
112
|
+
async def list_tools_by_mcp_server_from_db(self, mcp_server_id: str, actor: PydanticUser) -> List[PydanticTool]:
|
|
113
|
+
"""
|
|
114
|
+
Get tools associated with an MCP server from the database using the MCPTools mapping.
|
|
115
|
+
This is more efficient than fetching from the MCP server directly.
|
|
116
|
+
"""
|
|
117
|
+
# First get all tool IDs associated with this MCP server
|
|
118
|
+
tool_ids = await self.get_tool_ids_by_mcp_server(mcp_server_id, actor)
|
|
119
|
+
|
|
120
|
+
if not tool_ids:
|
|
121
|
+
return []
|
|
122
|
+
|
|
123
|
+
# Fetch all tools in a single query
|
|
124
|
+
async with db_registry.async_session() as session:
|
|
125
|
+
result = await session.execute(
|
|
126
|
+
select(ToolModel).where(
|
|
127
|
+
ToolModel.id.in_(tool_ids),
|
|
128
|
+
ToolModel.organization_id == actor.organization_id,
|
|
129
|
+
)
|
|
130
|
+
)
|
|
131
|
+
tools = result.scalars().all()
|
|
132
|
+
return [tool.to_pydantic() for tool in tools]
|
|
133
|
+
|
|
134
|
+
@enforce_types
|
|
135
|
+
async def get_tool_by_mcp_server(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> Optional[PydanticTool]:
|
|
136
|
+
"""
|
|
137
|
+
Get a specific tool that belongs to an MCP server.
|
|
138
|
+
Verifies the tool is associated with the MCP server via the mapping table.
|
|
139
|
+
"""
|
|
140
|
+
async with db_registry.async_session() as session:
|
|
141
|
+
# Check if the tool is associated with this MCP server
|
|
142
|
+
result = await session.execute(
|
|
143
|
+
select(MCPToolsModel).where(
|
|
144
|
+
MCPToolsModel.mcp_server_id == mcp_server_id,
|
|
145
|
+
MCPToolsModel.tool_id == tool_id,
|
|
146
|
+
MCPToolsModel.organization_id == actor.organization_id,
|
|
147
|
+
)
|
|
148
|
+
)
|
|
149
|
+
mapping = result.scalar_one_or_none()
|
|
150
|
+
|
|
151
|
+
if not mapping:
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
# Fetch the tool
|
|
155
|
+
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
|
156
|
+
return tool.to_pydantic()
|
|
157
|
+
|
|
158
|
+
@enforce_types
|
|
159
|
+
async def list_mcp_server_tools(self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
|
|
160
|
+
"""Get a list of all tools for a specific MCP server by server ID."""
|
|
161
|
+
mcp_client = None
|
|
162
|
+
try:
|
|
163
|
+
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
|
164
|
+
server_config = mcp_config.to_config()
|
|
165
|
+
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
|
166
|
+
await mcp_client.connect_to_server()
|
|
167
|
+
|
|
168
|
+
# list tools
|
|
169
|
+
tools = await mcp_client.list_tools()
|
|
170
|
+
# Add health information to each tool
|
|
171
|
+
for tool in tools:
|
|
172
|
+
# Try to normalize the schema and re-validate
|
|
173
|
+
if tool.inputSchema:
|
|
174
|
+
tool.inputSchema = normalize_mcp_schema(tool.inputSchema)
|
|
175
|
+
health_status, reasons = validate_complete_json_schema(tool.inputSchema)
|
|
176
|
+
tool.health = MCPToolHealth(status=health_status.value, reasons=reasons)
|
|
177
|
+
|
|
178
|
+
return tools
|
|
179
|
+
except Exception as e:
|
|
180
|
+
# MCP tool listing errors are often due to connection/configuration issues, not system errors
|
|
181
|
+
# Log at info level to avoid triggering Sentry alerts for expected failures
|
|
182
|
+
logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}")
|
|
183
|
+
raise e
|
|
184
|
+
finally:
|
|
185
|
+
if mcp_client:
|
|
186
|
+
try:
|
|
187
|
+
await mcp_client.cleanup()
|
|
188
|
+
except Exception as e:
|
|
189
|
+
logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}")
|
|
190
|
+
raise e
|
|
191
|
+
|
|
192
|
+
@enforce_types
|
|
193
|
+
async def execute_mcp_server_tool(
|
|
194
|
+
self,
|
|
195
|
+
mcp_server_id: str,
|
|
196
|
+
tool_id: str,
|
|
197
|
+
tool_args: Optional[Dict[str, Any]],
|
|
198
|
+
environment_variables: Dict[str, str],
|
|
199
|
+
actor: PydanticUser,
|
|
200
|
+
agent_id: Optional[str] = None,
|
|
201
|
+
) -> Tuple[str, bool]:
|
|
202
|
+
"""Call a specific tool from a specific MCP server by IDs."""
|
|
203
|
+
mcp_client = None
|
|
204
|
+
try:
|
|
205
|
+
# Get the tool to find its actual name
|
|
206
|
+
async with db_registry.async_session() as session:
|
|
207
|
+
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
|
208
|
+
tool_name = tool.name
|
|
209
|
+
|
|
210
|
+
# Get the MCP server config
|
|
211
|
+
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
|
212
|
+
server_config = mcp_config.to_config(environment_variables)
|
|
213
|
+
|
|
214
|
+
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
|
215
|
+
await mcp_client.connect_to_server()
|
|
216
|
+
|
|
217
|
+
# call tool
|
|
218
|
+
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
|
219
|
+
logger.info(f"MCP Result: {result}, Success: {success}")
|
|
220
|
+
return result, success
|
|
221
|
+
finally:
|
|
222
|
+
if mcp_client:
|
|
223
|
+
await mcp_client.cleanup()
|
|
224
|
+
|
|
225
|
+
@enforce_types
|
|
226
|
+
async def add_tool_from_mcp_server(self, mcp_server_id: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
|
227
|
+
"""Add a tool from an MCP server to the Letta tool registry."""
|
|
228
|
+
# Get the MCP server to get its name
|
|
229
|
+
mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
|
230
|
+
mcp_server_name = mcp_server.server_name
|
|
231
|
+
|
|
232
|
+
mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor)
|
|
233
|
+
for mcp_tool in mcp_tools:
|
|
234
|
+
# TODO: @jnjpng move health check to tool class
|
|
235
|
+
if mcp_tool.name == mcp_tool_name:
|
|
236
|
+
# Check tool health - but try normalization first for INVALID schemas
|
|
237
|
+
if mcp_tool.health and mcp_tool.health.status == "INVALID":
|
|
238
|
+
logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}")
|
|
239
|
+
logger.info(f"Original health reasons: {mcp_tool.health.reasons}")
|
|
240
|
+
|
|
241
|
+
# Try to normalize the schema and re-validate
|
|
242
|
+
try:
|
|
243
|
+
# Normalize the schema to fix common issues
|
|
244
|
+
logger.debug(f"Normalizing schema for {mcp_tool_name}")
|
|
245
|
+
normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema)
|
|
246
|
+
|
|
247
|
+
# Re-validate after normalization
|
|
248
|
+
logger.debug(f"Re-validating schema for {mcp_tool_name}")
|
|
249
|
+
health_status, health_reasons = validate_complete_json_schema(normalized_schema)
|
|
250
|
+
logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}")
|
|
251
|
+
|
|
252
|
+
# Update the tool's schema and health (use inputSchema, not input_schema)
|
|
253
|
+
mcp_tool.inputSchema = normalized_schema
|
|
254
|
+
mcp_tool.health.status = health_status.value
|
|
255
|
+
mcp_tool.health.reasons = health_reasons
|
|
256
|
+
|
|
257
|
+
# Log the normalization result
|
|
258
|
+
if health_status.value != "INVALID":
|
|
259
|
+
logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}")
|
|
260
|
+
else:
|
|
261
|
+
logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}")
|
|
262
|
+
except Exception as e:
|
|
263
|
+
logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True)
|
|
264
|
+
|
|
265
|
+
# After normalization attempt, check if still INVALID
|
|
266
|
+
if mcp_tool.health and mcp_tool.health.status == "INVALID":
|
|
267
|
+
logger.warning(f"Tool {mcp_tool_name} has potentially invalid schema. Reasons: {', '.join(mcp_tool.health.reasons)}")
|
|
268
|
+
|
|
269
|
+
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
|
270
|
+
created_tool = await self.tool_manager.create_mcp_tool_async(
|
|
271
|
+
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
# Create mapping in MCPTools table
|
|
275
|
+
if created_tool:
|
|
276
|
+
await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor)
|
|
277
|
+
|
|
278
|
+
return created_tool
|
|
279
|
+
|
|
280
|
+
# failed to add - handle error?
|
|
281
|
+
return None
|
|
282
|
+
|
|
283
|
+
    @enforce_types
    async def resync_mcp_server_tools(
        self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None
    ) -> MCPServerResyncResult:
        """
        Resync tools for an MCP server by:
        1. Fetching current tools from the MCP server
        2. Deleting tools that no longer exist on the server
        3. Updating schemas for existing tools
        4. Adding new tools from the server

        Returns a result with:
        - deleted: List of deleted tool names
        - updated: List of updated tool names
        - added: List of added tool names

        Raises:
            HTTPException: 404 ("MCPServerUnavailable") when the live server
                cannot be reached to list its tools.
        """
        # Get the MCP server to get its name
        mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
        mcp_server_name = mcp_server.server_name

        # Fetch current tools from MCP server
        try:
            current_mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor, agent_id=agent_id)
        except Exception as e:
            logger.error(f"Failed to fetch tools from MCP server {mcp_server_name}: {e}")
            raise HTTPException(
                status_code=404,
                detail={
                    "code": "MCPServerUnavailable",
                    "message": f"Could not connect to MCP server {mcp_server_name} to resync tools",
                    "error": str(e),
                },
            )

        # Get all persisted tools for this MCP server
        async with db_registry.async_session() as session:
            # Query for tools with MCP metadata matching this server
            # Using JSON path query to filter by metadata
            # NOTE(review): this lists ALL org tools and filters in Python below —
            # acceptable while tool counts are small, but O(org tools) per resync.
            persisted_tools = await ToolModel.list_async(
                db_session=session,
                organization_id=actor.organization_id,
            )

            # Filter tools that belong to this MCP server
            # (membership is recorded in tool.metadata_ under the MCP tag prefix)
            mcp_tools = []
            for tool in persisted_tools:
                if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
                    if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id:
                        mcp_tools.append(tool)

            # Create maps for easier comparison (keyed by tool name on both sides)
            current_tool_map = {tool.name: tool for tool in current_mcp_tools}
            persisted_tool_map = {tool.name: tool for tool in mcp_tools}

            deleted_tools = []
            updated_tools = []
            added_tools = []

            # 1. Delete tools that no longer exist on the server
            for tool_name, persisted_tool in persisted_tool_map.items():
                if tool_name not in current_tool_map:
                    # Delete the tool (cascade will handle agent detachment)
                    await persisted_tool.hard_delete_async(db_session=session, actor=actor)
                    deleted_tools.append(tool_name)
                    logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")

            # Commit deletions
            # (done before updates/adds so a later failure cannot roll them back)
            await session.commit()

            # 2. Update existing tools and add new tools
            for tool_name, current_tool in current_tool_map.items():
                if tool_name in persisted_tool_map:
                    # Update existing tool
                    persisted_tool = persisted_tool_map[tool_name]
                    tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool)

                    # Check if schema has changed (skip the write when identical)
                    if persisted_tool.json_schema != tool_create.json_schema:
                        # Update the tool
                        update_data = ToolUpdate(
                            description=tool_create.description,
                            json_schema=tool_create.json_schema,
                            source_code=tool_create.source_code,
                        )

                        await self.tool_manager.update_tool_by_id_async(tool_id=persisted_tool.id, tool_update=update_data, actor=actor)
                        updated_tools.append(tool_name)
                        logger.info(f"Updated MCP tool {tool_name} with new schema from server {mcp_server_name}")
                else:
                    # Add new tool
                    # Skip INVALID tools
                    if current_tool.health and current_tool.health.status == "INVALID":
                        logger.warning(
                            f"Skipping invalid tool {tool_name} from MCP server {mcp_server_name}: {', '.join(current_tool.health.reasons)}"
                        )
                        continue

                    tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool)
                    created_tool = await self.tool_manager.create_mcp_tool_async(
                        tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
                    )

                    # Create mapping in MCPTools table
                    if created_tool:
                        await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor)
                        added_tools.append(tool_name)
                        logger.info(f"Added new MCP tool {tool_name} from server {mcp_server_name} with mapping")

        return MCPServerResyncResult(
            deleted=deleted_tools,
            updated=updated_tools,
            added=added_tools,
        )
|
|
396
|
+
|
|
397
|
+
@enforce_types
|
|
398
|
+
async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]:
|
|
399
|
+
"""List all MCP servers available"""
|
|
400
|
+
async with db_registry.async_session() as session:
|
|
401
|
+
mcp_servers = await MCPServerModel.list_async(
|
|
402
|
+
db_session=session,
|
|
403
|
+
organization_id=actor.organization_id,
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
|
407
|
+
|
|
408
|
+
    @enforce_types
    async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
        """Create an MCP server, or update the existing one with the same name.

        If a server with this name already exists for the actor, its config is
        updated (the name itself is never changed); otherwise a new server is
        created. Raises ValueError for unsupported server types on update.
        """
        mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name=pydantic_mcp_server.server_name, actor=actor)
        if mcp_server_id:
            # Put to dict and remove fields that should not be reset
            update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True)

            # If there's anything to update (can only update the configs, not the name)
            # TODO: pass in custom headers for update as well?
            if update_data:
                # Build the type-specific update request for the existing server.
                if pydantic_mcp_server.server_type == MCPServerType.SSE:
                    update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token)
                elif pydantic_mcp_server.server_type == MCPServerType.STDIO:
                    update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config)
                elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
                    update_request = UpdateStreamableHTTPMCPServer(
                        server_url=pydantic_mcp_server.server_url, auth_token=pydantic_mcp_server.token
                    )
                else:
                    raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}")
                mcp_server = await self.update_mcp_server_by_id(mcp_server_id, update_request, actor)
            else:
                # Nothing to change — return the stored server as-is.
                printd(
                    f"`create_or_update_mcp_server` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_mcp_server.server_name}, but found existing mcp server with nothing to update."
                )
                mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
        else:
            # No server with this name yet — create a fresh one.
            mcp_server = await self.create_mcp_server(pydantic_mcp_server, actor=actor)

        return mcp_server
|
|
439
|
+
|
|
440
|
+
@enforce_types
|
|
441
|
+
async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
|
442
|
+
"""Create a new MCP server."""
|
|
443
|
+
async with db_registry.async_session() as session:
|
|
444
|
+
try:
|
|
445
|
+
# Set the organization id at the ORM layer
|
|
446
|
+
pydantic_mcp_server.organization_id = actor.organization_id
|
|
447
|
+
|
|
448
|
+
# Explicitly populate encrypted fields
|
|
449
|
+
if pydantic_mcp_server.token is not None:
|
|
450
|
+
pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token)
|
|
451
|
+
if pydantic_mcp_server.custom_headers is not None:
|
|
452
|
+
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
|
|
453
|
+
import json
|
|
454
|
+
|
|
455
|
+
json_str = json.dumps(pydantic_mcp_server.custom_headers)
|
|
456
|
+
pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str)
|
|
457
|
+
|
|
458
|
+
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
|
459
|
+
|
|
460
|
+
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
|
461
|
+
if mcp_server_data.get("custom_headers") is None:
|
|
462
|
+
mcp_server_data.pop("custom_headers", None)
|
|
463
|
+
|
|
464
|
+
mcp_server = MCPServerModel(**mcp_server_data)
|
|
465
|
+
mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True)
|
|
466
|
+
|
|
467
|
+
# Link existing OAuth sessions for the same user and server URL
|
|
468
|
+
# This ensures OAuth sessions created during testing get linked to the server
|
|
469
|
+
server_url = getattr(mcp_server, "server_url", None)
|
|
470
|
+
if server_url:
|
|
471
|
+
result = await session.execute(
|
|
472
|
+
select(MCPOAuth).where(
|
|
473
|
+
MCPOAuth.server_url == server_url,
|
|
474
|
+
MCPOAuth.organization_id == actor.organization_id,
|
|
475
|
+
MCPOAuth.user_id == actor.id, # Only link sessions for the same user
|
|
476
|
+
MCPOAuth.server_id.is_(None), # Only update sessions not already linked
|
|
477
|
+
)
|
|
478
|
+
)
|
|
479
|
+
oauth_sessions = result.scalars().all()
|
|
480
|
+
|
|
481
|
+
# TODO: @jnjpng we should upate sessions in bulk
|
|
482
|
+
for oauth_session in oauth_sessions:
|
|
483
|
+
oauth_session.server_id = mcp_server.id
|
|
484
|
+
await oauth_session.update_async(db_session=session, actor=actor, no_commit=True)
|
|
485
|
+
|
|
486
|
+
if oauth_sessions:
|
|
487
|
+
logger.info(
|
|
488
|
+
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
await session.commit()
|
|
492
|
+
return mcp_server.to_pydantic()
|
|
493
|
+
except Exception as e:
|
|
494
|
+
await session.rollback()
|
|
495
|
+
raise
|
|
496
|
+
|
|
497
|
+
@enforce_types
async def create_mcp_server_from_config(
    self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
) -> MCPServer:
    """
    Create an MCPServer model from a config object, encrypting sensitive fields.

    Sensitive fields (auth token, custom headers) are stored on the model as
    encrypted Secrets. The model is built in memory only — it is NOT persisted
    here; callers persist it (e.g. via create_mcp_server_with_tools).

    Args:
        server_config: Stdio, SSE, or streamable-HTTP server configuration.
        actor: The user performing the operation (kept for API symmetry).

    Returns:
        An MCPServer pydantic model ready for persistence.

    Raises:
        ValueError: If the config type is not one of the supported variants.
    """

    def _encrypt_sensitive_fields(server: MCPServer) -> None:
        # Shared by both URL-based config types (SSE and streamable-HTTP):
        # encrypt the resolved token and the custom headers, when present.
        token = server_config.resolve_token()
        if token:
            server.set_token_secret(Secret.from_plaintext(token))
        if server_config.custom_headers:
            # Convert dict to JSON string, then encrypt as Secret
            headers_json = json.dumps(server_config.custom_headers)
            server.set_custom_headers_secret(Secret.from_plaintext(headers_json))

    if isinstance(server_config, StdioServerConfig):
        # Stdio servers have no URL/token/headers; store the whole config blob.
        mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config)
    elif isinstance(server_config, (SSEServerConfig, StreamableHTTPServerConfig)):
        # Both remote transports share the same shape; the duplicated
        # per-branch encryption code from the original is folded into one path.
        mcp_server = MCPServer(
            server_name=server_config.server_name,
            server_type=server_config.type,
            server_url=server_config.server_url,
        )
        _encrypt_sensitive_fields(mcp_server)
    else:
        raise ValueError(f"Unsupported server config type: {type(server_config)}")

    return mcp_server
|
|
547
|
+
|
|
548
|
+
@enforce_types
async def create_mcp_server_from_config_with_tools(
    self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser
) -> MCPServer:
    """
    Create an MCP server from a config object and optimistically sync its tools.

    Converts the config to an MCPServer model (encrypting sensitive fields),
    then persists it with best-effort tool synchronization.
    """
    # Step 1: config -> model, with token/header encryption applied.
    encrypted_server = await self.create_mcp_server_from_config(server_config, actor)
    # Step 2: persist, attempting to auto-sync tools from the live server.
    return await self.create_mcp_server_with_tools(encrypted_server, actor)
|
|
563
|
+
|
|
564
|
+
@enforce_types
async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
    """
    Create a new MCP server and optimistically sync its tools.

    This method:
    1. Creates the MCP server record
    2. Attempts to connect and fetch tools
    3. Persists valid tools in parallel (best-effort)

    Tool sync is best-effort by design: any failure after the server record is
    created is logged and swallowed, and the created server is returned anyway.
    """
    import asyncio

    # First, create the MCP server
    created_server = await self.create_mcp_server(pydantic_mcp_server, actor)

    # Optimistically try to sync tools
    try:
        logger.info(f"Attempting to auto-sync tools from MCP server: {created_server.server_name}")

        # List all tools from the MCP server
        mcp_tools = await self.list_mcp_server_tools(created_server.id, actor=actor)

        # Filter out invalid tools (health status reported as INVALID)
        valid_tools = [tool for tool in mcp_tools if not (tool.health and tool.health.status == "INVALID")]

        # Register in parallel
        if valid_tools:
            tool_tasks = []
            for mcp_tool in valid_tools:
                tool_create = ToolCreate.from_mcp(mcp_server_name=created_server.server_name, mcp_tool=mcp_tool)
                task = self.tool_manager.create_mcp_tool_async(
                    tool_create=tool_create, mcp_server_name=created_server.server_name, mcp_server_id=created_server.id, actor=actor
                )
                tool_tasks.append(task)

            # return_exceptions=True: one failing tool must not abort the batch.
            results = await asyncio.gather(*tool_tasks, return_exceptions=True)

            # Create mappings in MCPTools table for successful tools
            mapping_tasks = []
            successful_count = 0
            for result in results:
                if not isinstance(result, Exception) and result:
                    # result should be a PydanticTool
                    mapping_task = self.create_mcp_tool_mapping(created_server.id, result.id, actor)
                    mapping_tasks.append(mapping_task)
                    successful_count += 1

            # Execute mapping creation in parallel (also best-effort: mapping
            # failures are not inspected, only swallowed by return_exceptions)
            if mapping_tasks:
                await asyncio.gather(*mapping_tasks, return_exceptions=True)

            failed = len(results) - successful_count
            logger.info(
                f"Auto-sync completed for MCP server {created_server.server_name}: "
                f"{successful_count} tools persisted with mappings, {failed} failed, "
                f"{len(mcp_tools) - len(valid_tools)} invalid tools skipped"
            )
        else:
            logger.info(f"No valid tools found to sync from MCP server {created_server.server_name}")

    except Exception as e:
        # Log the error but don't fail the server creation
        logger.warning(
            f"Failed to auto-sync tools from MCP server {created_server.server_name}: {e}. "
            f"Server was created successfully but tools were not persisted."
        )

    return created_server
|
|
632
|
+
|
|
633
|
+
@enforce_types
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
    """
    Update an MCP server by its ID with the fields set on the update object.

    Sensitive fields (token, custom_headers) are dual-written: an encrypted
    copy goes to the *_enc column while the plaintext column is kept during
    the encryption migration. Re-encryption is skipped when the incoming
    value equals the stored one.
    """
    async with db_registry.async_session() as session:
        # Fetch the server row by ID (raises if not found / not visible to actor)
        mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)

        # Only the fields that were explicitly set on the update object
        update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)

        # If renaming, proactively resolve name collisions within the same organization
        new_name = update_data.get("server_name")
        if new_name and new_name != getattr(mcp_server, "server_name", None):
            # Look for another server with the same name in this org
            existing = await MCPServerModel.list_async(
                db_session=session,
                organization_id=actor.organization_id,
                server_name=new_name,
            )
            # Delete conflicting entries that are not the current server
            # NOTE(review): this silently hard-deletes any same-named server
            # in the org rather than rejecting the rename — confirm intended.
            for other in existing:
                if other.id != mcp_server.id:
                    await session.execute(
                        delete(MCPServerModel).where(
                            MCPServerModel.id == other.id,
                            MCPServerModel.organization_id == actor.organization_id,
                        )
                    )

        # Handle encryption for token if provided
        # Only re-encrypt if the value has actually changed
        if "token" in update_data and update_data["token"] is not None:
            # Decrypt the stored token (preferring the encrypted column) to compare
            existing_token = None
            if mcp_server.token_enc:
                existing_secret = Secret.from_encrypted(mcp_server.token_enc)
                existing_token = existing_secret.get_plaintext()
            elif mcp_server.token:
                existing_token = mcp_server.token

            # Only re-encrypt if different
            if existing_token != update_data["token"]:
                mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted()
                # Keep plaintext for dual-write during migration
                mcp_server.token = update_data["token"]

            # Remove from update_data since we set directly on mcp_server
            update_data.pop("token", None)
            update_data.pop("token_enc", None)

        # Handle encryption for custom_headers if provided
        # Only re-encrypt if the value has actually changed
        if "custom_headers" in update_data:
            if update_data["custom_headers"] is not None:
                # custom_headers is a Dict[str, str], serialize to JSON then encrypt
                import json

                json_str = json.dumps(update_data["custom_headers"])

                # Decrypt the stored headers (preferring the encrypted column) to compare
                existing_headers_json = None
                if mcp_server.custom_headers_enc:
                    existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
                    existing_headers_json = existing_secret.get_plaintext()
                elif mcp_server.custom_headers:
                    existing_headers_json = json.dumps(mcp_server.custom_headers)

                # Only re-encrypt if different
                if existing_headers_json != json_str:
                    mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted()
                    # Keep plaintext for dual-write during migration
                    mcp_server.custom_headers = update_data["custom_headers"]

                # Remove from update_data since we set directly on mcp_server
                update_data.pop("custom_headers", None)
                update_data.pop("custom_headers_enc", None)
            else:
                # Explicit None clears headers: store SQL NULL, not JSON null
                update_data.pop("custom_headers", None)
                setattr(mcp_server, "custom_headers", null())
                setattr(mcp_server, "custom_headers_enc", None)

        # Apply the remaining (non-sensitive) fields directly
        for key, value in update_data.items():
            setattr(mcp_server, key, value)

        # Persist the changes and return the updated server as a pydantic model
        mcp_server = await mcp_server.update_async(db_session=session, actor=actor)

        return mcp_server.to_pydantic()
|
|
722
|
+
|
|
723
|
+
@enforce_types
async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
    """Resolve an MCP server's ID from its name, then apply the update by ID."""
    resolved_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
    if not resolved_id:
        # Mirror the structured 404 payload used elsewhere for missing servers.
        raise HTTPException(
            status_code=404,
            detail={
                "code": "MCPServerNotFoundError",
                "message": f"MCP server {mcp_server_name} not found",
                "mcp_server_name": mcp_server_name,
            },
        )
    return await self.update_mcp_server_by_id(resolved_id, mcp_server_update, actor)
|
|
737
|
+
|
|
738
|
+
@enforce_types
async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]:
    """Look up an MCP server's ID by its name for this user; None when absent."""
    try:
        async with db_registry.async_session() as session:
            record = await MCPServerModel.read_async(db_session=session, server_name=mcp_server_name, actor=actor)
            return record.id
    except NoResultFound:
        # A missing server is an expected outcome here, not an error.
        return None
|
|
747
|
+
|
|
748
|
+
@enforce_types
async def get_mcp_server_by_id_async(self, mcp_server_id: str, actor: PydanticUser) -> MCPServer:
    """Fetch a single MCP server by its ID and return it as a pydantic model."""
    async with db_registry.async_session() as session:
        record = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
        return record.to_pydantic()
|
|
756
|
+
|
|
757
|
+
@enforce_types
async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]:
    """Fetch multiple MCP servers by their IDs in a single query."""
    # Short-circuit: no IDs means no query.
    if not mcp_server_ids:
        return []

    async with db_registry.async_session() as session:
        records = await MCPServerModel.list_async(
            db_session=session,
            organization_id=actor.organization_id,
            id=mcp_server_ids,  # list value is translated to a SQL IN clause
        )
        return [record.to_pydantic() for record in records]
|
|
770
|
+
|
|
771
|
+
@enforce_types
async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> MCPServer:
    """
    Get an MCP server by name.

    Returns:
        The server as a pydantic MCPServer model.

    Raises:
        HTTPException: 404 if no server with this name exists for the actor.
    """
    mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
    # Guard BEFORE the DB read: the original checked `if not mcp_server` after
    # read_async, but a None identifier would already have failed inside
    # read_async, making that guard unreachable. Check the resolved ID instead.
    if not mcp_server_id:
        raise HTTPException(
            status_code=404,  # Not Found
            detail={
                "code": "MCPServerNotFoundError",
                "message": f"MCP server {mcp_server_name} not found",
                "mcp_server_name": mcp_server_name,
            },
        )
    async with db_registry.async_session() as session:
        mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
        return mcp_server.to_pydantic()
|
|
787
|
+
|
|
788
|
+
@enforce_types
async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None:
    """
    Delete an MCP server by its ID and associated tools and OAuth sessions.

    All deletions happen in a single transaction; any failure rolls back
    everything.

    Raises:
        ValueError: If no MCP server with the given ID exists.
    """
    async with db_registry.async_session() as session:
        try:
            mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
            if not mcp_server:
                raise NoResultFound(f"MCP server with id {mcp_server_id} not found.")

            server_url = getattr(mcp_server, "server_url", None)
            # Get all tools with matching metadata
            stmt = select(ToolModel).where(ToolModel.organization_id == actor.organization_id)
            result = await session.execute(stmt)
            all_tools = result.scalars().all()

            # Filter and delete tools that belong to this MCP server
            tools_deleted = 0
            for tool in all_tools:
                if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
                    if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id:
                        await tool.hard_delete_async(db_session=session, actor=actor)
                        # BUG FIX: was `tools_deleted = 1`, which pinned the
                        # counter at 1 and made the summary log undercount.
                        tools_deleted += 1
                        logger.info(f"Deleted MCP tool {tool.name} associated with MCP server {mcp_server_id}")

            if tools_deleted > 0:
                logger.info(f"Deleted {tools_deleted} MCP tools associated with MCP server {mcp_server_id}")

            # Delete all MCPTools mappings for this server
            await session.execute(
                delete(MCPToolsModel).where(
                    MCPToolsModel.mcp_server_id == mcp_server_id,
                    MCPToolsModel.organization_id == actor.organization_id,
                )
            )
            logger.info(f"Deleted MCPTools mappings for MCP server {mcp_server_id}")

            # Delete OAuth sessions for the same user and server URL in the same transaction
            # This handles orphaned sessions that were created during testing/connection
            oauth_count = 0
            if server_url:
                result = await session.execute(
                    delete(MCPOAuth).where(
                        MCPOAuth.server_url == server_url,
                        MCPOAuth.organization_id == actor.organization_id,
                        MCPOAuth.user_id == actor.id,  # Only delete sessions for the same user
                    )
                )
                oauth_count = result.rowcount
                if oauth_count > 0:
                    logger.info(
                        f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}"
                    )

            # Delete the MCP server, will cascade delete to linked OAuth sessions
            await session.execute(
                delete(MCPServerModel).where(
                    MCPServerModel.id == mcp_server_id,
                    MCPServerModel.organization_id == actor.organization_id,
                )
            )

            await session.commit()
        except NoResultFound:
            await session.rollback()
            raise ValueError(f"MCP server with id {mcp_server_id} not found.")
        except Exception as e:
            await session.rollback()
            logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}")
            raise
|
|
857
|
+
|
|
858
|
+
def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]:
    """
    Read MCP server configs from the local config file (~/.letta/mcp_config.json).

    Parsing is best-effort: an unreadable file yields an empty mapping, and
    individual malformed or duplicate server entries are skipped with a
    warning rather than failing the whole read.
    """
    mcp_server_list = {}

    # Attempt to read from ~/.letta/mcp_config.json
    mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
    if os.path.exists(mcp_config_path):
        with open(mcp_config_path, "r") as f:
            try:
                mcp_config = json.load(f)
            except Exception as e:
                # Config parsing errors are user configuration issues, not system errors
                logger.warning(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}")
                return mcp_server_list

            # Proper formatting is "mcpServers" key at the top level,
            # then a dict with the MCP server name as the key,
            # with the value being the schema from StdioServerParameters
            if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
                for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
                    # No support for duplicate server names
                    if server_name in mcp_server_list:
                        # Duplicate server names are configuration issues, not system errors
                        logger.warning(f"Duplicate MCP server name found (skipping): {server_name}")
                        continue

                    if "url" in server_params_raw:
                        # Presence of "url" selects the SSE interpretation;
                        # streamable-HTTP entries are not distinguished here.
                        # Attempt to parse the server params as an SSE server
                        try:
                            server_params = SSEServerConfig(
                                server_name=server_name,
                                server_url=server_params_raw["url"],
                                auth_header=server_params_raw.get("auth_header", None),
                                auth_token=server_params_raw.get("auth_token", None),
                                headers=server_params_raw.get("headers", None),
                            )
                            mcp_server_list[server_name] = server_params
                        except Exception as e:
                            # Config parsing errors are user configuration issues, not system errors
                            logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
                            continue
                    else:
                        # Attempt to parse the server params as a StdioServerParameters
                        try:
                            server_params = StdioServerConfig(
                                server_name=server_name,
                                command=server_params_raw["command"],
                                args=server_params_raw.get("args", []),
                                env=server_params_raw.get("env", {}),
                            )
                            mcp_server_list[server_name] = server_params
                        except Exception as e:
                            # Config parsing errors are user configuration issues, not system errors
                            logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
                            continue
    return mcp_server_list
|
|
913
|
+
|
|
914
|
+
async def get_mcp_client(
    self,
    server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
    actor: PydanticUser,
    oauth_provider: Optional[Any] = None,
    agent_id: Optional[str] = None,
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
    """
    Helper function to create the appropriate MCP client based on server configuration.

    Args:
        server_config: The server configuration object
        actor: The user making the request
        oauth_provider: Optional OAuth provider for authentication
        agent_id: Optional agent ID passed through to the created client

    Returns:
        The appropriate MCP client instance

    Raises:
        ValueError: If server config type is not supported
    """
    # If no OAuth provider is provided, check if we have stored OAuth credentials
    if oauth_provider is None and hasattr(server_config, "server_url"):
        oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
        # Check if access token exists by attempting to decrypt it
        if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
            # Create OAuth provider from stored credentials
            # (local import, presumably to avoid a circular dependency — verify)
            from letta.services.mcp.oauth_utils import create_oauth_provider

            oauth_provider = await create_oauth_provider(
                session_id=oauth_session.id,
                server_url=oauth_session.server_url,
                redirect_uri=oauth_session.redirect_uri,
                mcp_manager=self,
                actor=actor,
            )

    # Re-instantiate the config as its concrete class before handing it to the
    # client; dispatch is on the `type` enum field, not on isinstance.
    if server_config.type == MCPServerType.SSE:
        server_config = SSEServerConfig(**server_config.model_dump())
        return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
    elif server_config.type == MCPServerType.STDIO:
        server_config = StdioServerConfig(**server_config.model_dump())
        return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
    elif server_config.type == MCPServerType.STREAMABLE_HTTP:
        server_config = StreamableHTTPServerConfig(**server_config.model_dump())
        return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
    else:
        raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
|
962
|
+
|
|
963
|
+
# OAuth-related methods
|
|
964
|
+
def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession:
    """
    Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.

    Each sensitive field is dual-read: when an encryption key is configured,
    Secret.from_db decides between the encrypted and plaintext columns;
    otherwise the legacy plaintext column is used as-is.
    """

    def _decrypt(encrypted: Optional[str], plaintext: Optional[str]) -> Optional[str]:
        # Dual-read one sensitive field. Replaces four identical copies of
        # this logic in the original (access/refresh tokens, client secret,
        # authorization code).
        if not (encrypted or plaintext):
            return None
        if settings.encryption_key:
            return Secret.from_db(encrypted, plaintext).get_plaintext()
        # No encryption key, use plaintext if available
        return plaintext

    def _wrap(encrypted: Optional[str]):
        # Encrypted DB strings are surfaced on the pydantic model as Secrets.
        return Secret.from_encrypted(encrypted) if encrypted else None

    access_token = _decrypt(oauth_session.access_token_enc, oauth_session.access_token)
    refresh_token = _decrypt(oauth_session.refresh_token_enc, oauth_session.refresh_token)
    client_secret = _decrypt(oauth_session.client_secret_enc, oauth_session.client_secret)
    authorization_code = _decrypt(oauth_session.authorization_code_enc, oauth_session.authorization_code)

    return MCPOAuthSession(
        id=oauth_session.id,
        state=oauth_session.state,
        server_id=oauth_session.server_id,
        server_url=oauth_session.server_url,
        server_name=oauth_session.server_name,
        user_id=oauth_session.user_id,
        organization_id=oauth_session.organization_id,
        authorization_url=oauth_session.authorization_url,
        authorization_code=authorization_code,
        access_token=access_token,
        refresh_token=refresh_token,
        token_type=oauth_session.token_type,
        expires_at=oauth_session.expires_at,
        scope=oauth_session.scope,
        client_id=oauth_session.client_id,
        client_secret=client_secret,
        redirect_uri=oauth_session.redirect_uri,
        status=oauth_session.status,
        created_at=oauth_session.created_at,
        updated_at=oauth_session.updated_at,
        # Encrypted fields as Secret objects (converted from encrypted strings in DB)
        authorization_code_enc=_wrap(oauth_session.authorization_code_enc),
        access_token_enc=_wrap(oauth_session.access_token_enc),
        refresh_token_enc=_wrap(oauth_session.refresh_token_enc),
        client_secret_enc=_wrap(oauth_session.client_secret_enc),
    )
|
|
1037
|
+
|
|
1038
|
+
@enforce_types
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
    """Create and persist a fresh OAuth session (PENDING, no tokens yet)."""
    async with db_registry.async_session() as session:
        # The random URL-safe `state` value uniquely identifies this flow.
        record = MCPOAuth(
            id="mcp-oauth-" + str(uuid.uuid4())[:8],
            state=secrets.token_urlsafe(32),
            server_url=session_create.server_url,
            server_name=session_create.server_name,
            user_id=session_create.user_id,
            organization_id=session_create.organization_id,
            status=OAuthSessionStatus.PENDING,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )
        record = await record.create_async(session, actor=actor)

        # New sessions carry no tokens yet, so conversion performs no decryption.
        return self._oauth_orm_to_pydantic(record)
|
|
1058
|
+
|
|
1059
|
+
@enforce_types
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
    """Fetch an OAuth session by its ID, or None when no such session exists."""
    async with db_registry.async_session() as session:
        try:
            record = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
        except NoResultFound:
            # Absence is an expected outcome for callers probing by ID.
            return None
        return self._oauth_orm_to_pydantic(record)
|
|
1068
|
+
|
|
1069
|
+
@enforce_types
async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
    """Return the most recently updated AUTHORIZED OAuth session for this
    user, organization, and server URL — or None when there is none."""
    async with db_registry.async_session() as session:
        # Build the query separately for readability: filter on org/user/URL/
        # status, newest first, single row.
        stmt = (
            select(MCPOAuth)
            .where(
                MCPOAuth.organization_id == actor.organization_id,
                MCPOAuth.user_id == actor.id,
                MCPOAuth.server_url == server_url,
                MCPOAuth.status == OAuthSessionStatus.AUTHORIZED,
            )
            .order_by(desc(MCPOAuth.updated_at))
            .limit(1)
        )
        record = (await session.execute(stmt)).scalar_one_or_none()

        if record is None:
            return None

        return self._oauth_orm_to_pydantic(record)
|
|
1092
|
+
|
|
1093
|
+
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
    """Update an existing OAuth session.

    Secret-bearing fields (authorization_code, access_token, refresh_token,
    client_secret) are stored encrypted and are only re-encrypted when the
    incoming value actually differs from the stored one. The plaintext
    column is still written alongside ("dual-write") during the encryption
    migration.

    Args:
        session_id: ID of the OAuth session to update.
        session_update: Partial update; only non-None fields are applied.
        actor: User performing the update (used for access control).

    Returns:
        The updated session as a Pydantic model.
    """
    async with db_registry.async_session() as session:
        oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)

        def _apply_secret_field(field: str, new_value: Optional[str]) -> None:
            """Encrypt-and-set one secret column pair, skipping the work when unchanged.

            The previous implementation repeated this compare-then-re-encrypt
            logic verbatim for four different fields; the shared helper keeps
            the behavior identical while removing the duplication.
            """
            if new_value is None:
                return
            enc_attr = f"{field}_enc"
            # Prefer the encrypted column; fall back to legacy plaintext.
            current = None
            encrypted = getattr(oauth_session, enc_attr)
            if encrypted:
                current = Secret.from_encrypted(encrypted).get_plaintext()
            elif getattr(oauth_session, field):
                current = getattr(oauth_session, field)
            # Only re-encrypt if the value actually changed.
            if current != new_value:
                setattr(oauth_session, enc_attr, Secret.from_plaintext(new_value).get_encrypted())
                # Keep plaintext for dual-write during migration.
                setattr(oauth_session, field, new_value)

        if session_update.authorization_url is not None:
            oauth_session.authorization_url = session_update.authorization_url

        _apply_secret_field("authorization_code", session_update.authorization_code)
        _apply_secret_field("access_token", session_update.access_token)
        _apply_secret_field("refresh_token", session_update.refresh_token)

        if session_update.token_type is not None:
            oauth_session.token_type = session_update.token_type
        if session_update.expires_at is not None:
            oauth_session.expires_at = session_update.expires_at
        if session_update.scope is not None:
            oauth_session.scope = session_update.scope
        if session_update.client_id is not None:
            oauth_session.client_id = session_update.client_id

        _apply_secret_field("client_secret", session_update.client_secret)

        if session_update.redirect_uri is not None:
            oauth_session.redirect_uri = session_update.redirect_uri
        if session_update.status is not None:
            oauth_session.status = session_update.status

        # Always update the updated_at timestamp.
        # NOTE(review): datetime.now() is naive — confirm the column expects naive/local time.
        oauth_session.updated_at = datetime.now()

        oauth_session = await oauth_session.update_async(db_session=session, actor=actor)

        return self._oauth_orm_to_pydantic(oauth_session)
|
|
1191
|
+
|
|
1192
|
+
@enforce_types
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
    """Permanently remove the OAuth session with the given id.

    Raises:
        ValueError: If no session with that id is visible to the actor.
    """
    async with db_registry.async_session() as session:
        try:
            record = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
            await record.hard_delete_async(db_session=session, actor=actor)
        except NoResultFound:
            # Surface a friendlier error to callers than the raw ORM exception.
            raise ValueError(f"OAuth session with id {session_id} not found.")
|
|
1201
|
+
|
|
1202
|
+
@enforce_types
async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int:
    """Delete OAuth sessions created before the cutoff and return the count removed."""
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    async with db_registry.async_session() as session:
        # Collect every session older than the cutoff.
        query_result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff))
        stale_rows = query_result.scalars().all()

        # Remove each one via the async ORM delete helper.
        for row in stale_rows:
            await row.hard_delete_async(db_session=session, actor=None)

        deleted = len(stale_rows)
        if deleted:
            logger.info(f"Cleaned up {deleted} expired OAuth sessions")

        return deleted
|
|
1220
|
+
|
|
1221
|
+
@enforce_types
|
|
1222
|
+
async def handle_oauth_flow(
|
|
1223
|
+
self,
|
|
1224
|
+
request: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
|
|
1225
|
+
actor: PydanticUser,
|
|
1226
|
+
http_request: Optional[Request] = None,
|
|
1227
|
+
):
|
|
1228
|
+
"""
|
|
1229
|
+
Handle OAuth flow for MCP server connection and yield SSE events.
|
|
1230
|
+
|
|
1231
|
+
Args:
|
|
1232
|
+
request: The server configuration
|
|
1233
|
+
actor: The user making the request
|
|
1234
|
+
http_request: The HTTP request object
|
|
1235
|
+
|
|
1236
|
+
Yields:
|
|
1237
|
+
SSE events during OAuth flow
|
|
1238
|
+
|
|
1239
|
+
Returns:
|
|
1240
|
+
Tuple of (temp_client, connect_task) after yielding events
|
|
1241
|
+
"""
|
|
1242
|
+
import asyncio
|
|
1243
|
+
|
|
1244
|
+
from letta.services.mcp.oauth_utils import create_oauth_provider, oauth_stream_event
|
|
1245
|
+
from letta.services.mcp.types import OauthStreamEvent
|
|
1246
|
+
|
|
1247
|
+
# OAuth required, yield state to client to prepare to handle authorization URL
|
|
1248
|
+
yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required")
|
|
1249
|
+
|
|
1250
|
+
# Create OAuth session to persist the state of the OAuth flow
|
|
1251
|
+
session_create = MCPOAuthSessionCreate(
|
|
1252
|
+
server_url=request.server_url,
|
|
1253
|
+
server_name=request.server_name,
|
|
1254
|
+
user_id=actor.id,
|
|
1255
|
+
organization_id=actor.organization_id,
|
|
1256
|
+
)
|
|
1257
|
+
oauth_session = await self.create_oauth_session(session_create, actor)
|
|
1258
|
+
session_id = oauth_session.id
|
|
1259
|
+
|
|
1260
|
+
# TODO: @jnjpng make this check more robust and remove direct os.getenv
|
|
1261
|
+
# Check if request is from web frontend to determine redirect URI
|
|
1262
|
+
is_web_request = (
|
|
1263
|
+
http_request
|
|
1264
|
+
and http_request.headers
|
|
1265
|
+
and http_request.headers.get("user-agent", "") == "Next.js Middleware"
|
|
1266
|
+
and http_request.headers.__contains__("x-organization-id")
|
|
1267
|
+
)
|
|
1268
|
+
|
|
1269
|
+
logo_uri = None
|
|
1270
|
+
NEXT_PUBLIC_CURRENT_HOST = os.getenv("NEXT_PUBLIC_CURRENT_HOST")
|
|
1271
|
+
LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT")
|
|
1272
|
+
|
|
1273
|
+
if is_web_request and NEXT_PUBLIC_CURRENT_HOST:
|
|
1274
|
+
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}"
|
|
1275
|
+
logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg"
|
|
1276
|
+
elif LETTA_AGENTS_ENDPOINT:
|
|
1277
|
+
# API and SDK usage should call core server directly
|
|
1278
|
+
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}"
|
|
1279
|
+
else:
|
|
1280
|
+
logger.error(
|
|
1281
|
+
f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"
|
|
1282
|
+
)
|
|
1283
|
+
raise HTTPException(status_code=400, detail="No redirect URI found")
|
|
1284
|
+
|
|
1285
|
+
# Create OAuth provider for the instance of the stream connection
|
|
1286
|
+
oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, self, actor, logo_uri=logo_uri)
|
|
1287
|
+
|
|
1288
|
+
# Get authorization URL by triggering OAuth flow
|
|
1289
|
+
temp_client = None
|
|
1290
|
+
connect_task = None
|
|
1291
|
+
try:
|
|
1292
|
+
temp_client = await self.get_mcp_client(request, actor, oauth_provider)
|
|
1293
|
+
|
|
1294
|
+
# Run connect_to_server in background to avoid blocking
|
|
1295
|
+
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
|
|
1296
|
+
connect_task = safe_create_task(temp_client.connect_to_server(), label="mcp_oauth_connect")
|
|
1297
|
+
|
|
1298
|
+
# Give the OAuth flow time to trigger and save the URL
|
|
1299
|
+
await asyncio.sleep(1.0)
|
|
1300
|
+
|
|
1301
|
+
# Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
|
|
1302
|
+
auth_session = await self.get_oauth_session_by_id(session_id, actor)
|
|
1303
|
+
if auth_session and auth_session.authorization_url:
|
|
1304
|
+
yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
|
|
1305
|
+
|
|
1306
|
+
# Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit
|
|
1307
|
+
yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
|
|
1308
|
+
|
|
1309
|
+
# Callback handler will poll for authorization code and state and update the OAuth session
|
|
1310
|
+
await connect_task
|
|
1311
|
+
|
|
1312
|
+
tools = await temp_client.list_tools(serialize=True)
|
|
1313
|
+
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
|
1314
|
+
|
|
1315
|
+
except Exception as e:
|
|
1316
|
+
logger.error(f"Error triggering OAuth flow: {e}")
|
|
1317
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
|
|
1318
|
+
raise e
|
|
1319
|
+
finally:
|
|
1320
|
+
# Clean up resources
|
|
1321
|
+
if connect_task and not connect_task.done():
|
|
1322
|
+
connect_task.cancel()
|
|
1323
|
+
try:
|
|
1324
|
+
await connect_task
|
|
1325
|
+
except asyncio.CancelledError:
|
|
1326
|
+
pass
|
|
1327
|
+
if temp_client:
|
|
1328
|
+
try:
|
|
1329
|
+
await temp_client.cleanup()
|
|
1330
|
+
except Exception as cleanup_error:
|
|
1331
|
+
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|