letta-nightly 0.12.1.dev20251023104211__py3-none-any.whl → 0.13.0.dev20251024223017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +2 -3
- letta/adapters/letta_llm_adapter.py +1 -0
- letta/adapters/simple_llm_request_adapter.py +8 -5
- letta/adapters/simple_llm_stream_adapter.py +22 -6
- letta/agents/agent_loop.py +10 -3
- letta/agents/base_agent.py +4 -1
- letta/agents/helpers.py +41 -9
- letta/agents/letta_agent.py +11 -10
- letta/agents/letta_agent_v2.py +47 -37
- letta/agents/letta_agent_v3.py +395 -300
- letta/agents/voice_agent.py +8 -6
- letta/agents/voice_sleeptime_agent.py +3 -3
- letta/constants.py +30 -7
- letta/errors.py +20 -0
- letta/functions/function_sets/base.py +55 -3
- letta/functions/mcp_client/types.py +33 -57
- letta/functions/schema_generator.py +135 -23
- letta/groups/sleeptime_multi_agent_v3.py +6 -11
- letta/groups/sleeptime_multi_agent_v4.py +227 -0
- letta/helpers/converters.py +78 -4
- letta/helpers/crypto_utils.py +6 -2
- letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +9 -11
- letta/interfaces/anthropic_streaming_interface.py +3 -4
- letta/interfaces/gemini_streaming_interface.py +4 -6
- letta/interfaces/openai_streaming_interface.py +63 -28
- letta/llm_api/anthropic_client.py +7 -4
- letta/llm_api/deepseek_client.py +6 -4
- letta/llm_api/google_ai_client.py +3 -12
- letta/llm_api/google_vertex_client.py +1 -1
- letta/llm_api/helpers.py +90 -61
- letta/llm_api/llm_api_tools.py +4 -1
- letta/llm_api/openai.py +12 -12
- letta/llm_api/openai_client.py +53 -16
- letta/local_llm/constants.py +4 -3
- letta/local_llm/json_parser.py +5 -2
- letta/local_llm/utils.py +2 -3
- letta/log.py +171 -7
- letta/orm/agent.py +43 -9
- letta/orm/archive.py +4 -0
- letta/orm/custom_columns.py +15 -0
- letta/orm/identity.py +11 -11
- letta/orm/mcp_server.py +9 -0
- letta/orm/message.py +6 -1
- letta/orm/run_metrics.py +7 -2
- letta/orm/sqlalchemy_base.py +2 -2
- letta/orm/tool.py +3 -0
- letta/otel/tracing.py +2 -0
- letta/prompts/prompt_generator.py +7 -2
- letta/schemas/agent.py +41 -10
- letta/schemas/agent_file.py +3 -0
- letta/schemas/archive.py +4 -2
- letta/schemas/block.py +2 -1
- letta/schemas/enums.py +36 -3
- letta/schemas/file.py +3 -3
- letta/schemas/folder.py +2 -1
- letta/schemas/group.py +2 -1
- letta/schemas/identity.py +18 -9
- letta/schemas/job.py +3 -1
- letta/schemas/letta_message.py +71 -12
- letta/schemas/letta_request.py +7 -3
- letta/schemas/letta_stop_reason.py +0 -25
- letta/schemas/llm_config.py +8 -2
- letta/schemas/mcp.py +80 -83
- letta/schemas/mcp_server.py +349 -0
- letta/schemas/memory.py +20 -8
- letta/schemas/message.py +212 -67
- letta/schemas/providers/anthropic.py +13 -6
- letta/schemas/providers/azure.py +6 -4
- letta/schemas/providers/base.py +8 -4
- letta/schemas/providers/bedrock.py +6 -2
- letta/schemas/providers/cerebras.py +7 -3
- letta/schemas/providers/deepseek.py +2 -1
- letta/schemas/providers/google_gemini.py +15 -6
- letta/schemas/providers/groq.py +2 -1
- letta/schemas/providers/lmstudio.py +9 -6
- letta/schemas/providers/mistral.py +2 -1
- letta/schemas/providers/openai.py +7 -2
- letta/schemas/providers/together.py +9 -3
- letta/schemas/providers/xai.py +7 -3
- letta/schemas/run.py +7 -2
- letta/schemas/run_metrics.py +2 -1
- letta/schemas/sandbox_config.py +2 -2
- letta/schemas/secret.py +3 -158
- letta/schemas/source.py +2 -2
- letta/schemas/step.py +2 -2
- letta/schemas/tool.py +24 -1
- letta/schemas/usage.py +0 -1
- letta/server/rest_api/app.py +123 -7
- letta/server/rest_api/dependencies.py +3 -0
- letta/server/rest_api/interface.py +7 -4
- letta/server/rest_api/redis_stream_manager.py +16 -1
- letta/server/rest_api/routers/v1/__init__.py +7 -0
- letta/server/rest_api/routers/v1/agents.py +332 -322
- letta/server/rest_api/routers/v1/archives.py +127 -40
- letta/server/rest_api/routers/v1/blocks.py +54 -6
- letta/server/rest_api/routers/v1/chat_completions.py +146 -0
- letta/server/rest_api/routers/v1/folders.py +27 -35
- letta/server/rest_api/routers/v1/groups.py +23 -35
- letta/server/rest_api/routers/v1/identities.py +24 -10
- letta/server/rest_api/routers/v1/internal_runs.py +107 -0
- letta/server/rest_api/routers/v1/internal_templates.py +162 -179
- letta/server/rest_api/routers/v1/jobs.py +15 -27
- letta/server/rest_api/routers/v1/mcp_servers.py +309 -0
- letta/server/rest_api/routers/v1/messages.py +23 -34
- letta/server/rest_api/routers/v1/organizations.py +6 -27
- letta/server/rest_api/routers/v1/providers.py +35 -62
- letta/server/rest_api/routers/v1/runs.py +30 -43
- letta/server/rest_api/routers/v1/sandbox_configs.py +6 -4
- letta/server/rest_api/routers/v1/sources.py +26 -42
- letta/server/rest_api/routers/v1/steps.py +16 -29
- letta/server/rest_api/routers/v1/tools.py +17 -13
- letta/server/rest_api/routers/v1/users.py +5 -17
- letta/server/rest_api/routers/v1/voice.py +18 -27
- letta/server/rest_api/streaming_response.py +5 -2
- letta/server/rest_api/utils.py +187 -25
- letta/server/server.py +27 -22
- letta/server/ws_api/server.py +5 -4
- letta/services/agent_manager.py +148 -26
- letta/services/agent_serialization_manager.py +6 -1
- letta/services/archive_manager.py +168 -15
- letta/services/block_manager.py +14 -4
- letta/services/file_manager.py +33 -29
- letta/services/group_manager.py +10 -0
- letta/services/helpers/agent_manager_helper.py +65 -11
- letta/services/identity_manager.py +105 -4
- letta/services/job_manager.py +11 -1
- letta/services/mcp/base_client.py +2 -2
- letta/services/mcp/oauth_utils.py +33 -8
- letta/services/mcp_manager.py +174 -78
- letta/services/mcp_server_manager.py +1331 -0
- letta/services/message_manager.py +109 -4
- letta/services/organization_manager.py +4 -4
- letta/services/passage_manager.py +9 -25
- letta/services/provider_manager.py +91 -15
- letta/services/run_manager.py +72 -15
- letta/services/sandbox_config_manager.py +45 -3
- letta/services/source_manager.py +15 -8
- letta/services/step_manager.py +24 -1
- letta/services/streaming_service.py +581 -0
- letta/services/summarizer/summarizer.py +1 -1
- letta/services/tool_executor/core_tool_executor.py +111 -0
- letta/services/tool_executor/files_tool_executor.py +5 -3
- letta/services/tool_executor/sandbox_tool_executor.py +2 -2
- letta/services/tool_executor/tool_execution_manager.py +1 -1
- letta/services/tool_manager.py +10 -3
- letta/services/tool_sandbox/base.py +61 -1
- letta/services/tool_sandbox/local_sandbox.py +1 -3
- letta/services/user_manager.py +2 -2
- letta/settings.py +49 -5
- letta/system.py +14 -5
- letta/utils.py +73 -1
- letta/validators.py +105 -0
- {letta_nightly-0.12.1.dev20251023104211.dist-info → letta_nightly-0.13.0.dev20251024223017.dist-info}/METADATA +4 -2
- {letta_nightly-0.12.1.dev20251023104211.dist-info → letta_nightly-0.13.0.dev20251024223017.dist-info}/RECORD +157 -151
- letta/schemas/letta_ping.py +0 -28
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- {letta_nightly-0.12.1.dev20251023104211.dist-info → letta_nightly-0.13.0.dev20251024223017.dist-info}/WHEEL +0 -0
- {letta_nightly-0.12.1.dev20251023104211.dist-info → letta_nightly-0.13.0.dev20251024223017.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.12.1.dev20251023104211.dist-info → letta_nightly-0.13.0.dev20251024223017.dist-info}/licenses/LICENSE +0 -0
letta/services/mcp_manager.py
CHANGED
|
@@ -25,6 +25,7 @@ from letta.orm.errors import NoResultFound
|
|
|
25
25
|
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
|
26
26
|
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
|
27
27
|
from letta.orm.tool import Tool as ToolModel
|
|
28
|
+
from letta.schemas.enums import PrimitiveType
|
|
28
29
|
from letta.schemas.mcp import (
|
|
29
30
|
MCPOAuthSession,
|
|
30
31
|
MCPOAuthSessionCreate,
|
|
@@ -36,16 +37,18 @@ from letta.schemas.mcp import (
|
|
|
36
37
|
UpdateStdioMCPServer,
|
|
37
38
|
UpdateStreamableHTTPMCPServer,
|
|
38
39
|
)
|
|
39
|
-
from letta.schemas.secret import Secret
|
|
40
|
+
from letta.schemas.secret import Secret
|
|
40
41
|
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
|
41
42
|
from letta.schemas.user import User as PydanticUser
|
|
42
43
|
from letta.server.db import db_registry
|
|
44
|
+
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
|
43
45
|
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
|
44
46
|
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
|
45
47
|
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
|
46
48
|
from letta.services.tool_manager import ToolManager
|
|
47
|
-
from letta.settings import tool_settings
|
|
48
|
-
from letta.utils import enforce_types, printd,
|
|
49
|
+
from letta.settings import settings, tool_settings
|
|
50
|
+
from letta.utils import enforce_types, printd, safe_create_task_with_return
|
|
51
|
+
from letta.validators import raise_on_invalid_id
|
|
49
52
|
|
|
50
53
|
logger = get_logger(__name__)
|
|
51
54
|
|
|
@@ -59,6 +62,7 @@ class MCPManager:
|
|
|
59
62
|
self.cached_mcp_servers = {} # maps id -> async connection
|
|
60
63
|
|
|
61
64
|
@enforce_types
|
|
65
|
+
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
|
62
66
|
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
|
|
63
67
|
"""Get a list of all tools for a specific MCP server."""
|
|
64
68
|
mcp_client = None
|
|
@@ -73,7 +77,7 @@ class MCPManager:
|
|
|
73
77
|
tools = await mcp_client.list_tools()
|
|
74
78
|
# Add health information to each tool
|
|
75
79
|
for tool in tools:
|
|
76
|
-
# Try to normalize the schema and re-validate
|
|
80
|
+
# Try to normalize the schema and re-validate
|
|
77
81
|
if tool.inputSchema:
|
|
78
82
|
tool.inputSchema = normalize_mcp_schema(tool.inputSchema)
|
|
79
83
|
health_status, reasons = validate_complete_json_schema(tool.inputSchema)
|
|
@@ -174,10 +178,7 @@ class MCPManager:
|
|
|
174
178
|
|
|
175
179
|
# After normalization attempt, check if still INVALID
|
|
176
180
|
if mcp_tool.health and mcp_tool.health.status == "INVALID":
|
|
177
|
-
|
|
178
|
-
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid even after normalization. "
|
|
179
|
-
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
|
|
180
|
-
)
|
|
181
|
+
logger.warning(f"Tool {mcp_tool_name} has potentially invalid schema. Reasons: {', '.join(mcp_tool.health.reasons)}")
|
|
181
182
|
|
|
182
183
|
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
|
183
184
|
return await self.tool_manager.create_mcp_tool_async(
|
|
@@ -318,6 +319,7 @@ class MCPManager:
|
|
|
318
319
|
update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True)
|
|
319
320
|
|
|
320
321
|
# If there's anything to update (can only update the configs, not the name)
|
|
322
|
+
# TODO: pass in custom headers for update as well?
|
|
321
323
|
if update_data:
|
|
322
324
|
if pydantic_mcp_server.server_type == MCPServerType.SSE:
|
|
323
325
|
update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token)
|
|
@@ -325,7 +327,7 @@ class MCPManager:
|
|
|
325
327
|
update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config)
|
|
326
328
|
elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
|
|
327
329
|
update_request = UpdateStreamableHTTPMCPServer(
|
|
328
|
-
server_url=pydantic_mcp_server.server_url,
|
|
330
|
+
server_url=pydantic_mcp_server.server_url, auth_token=pydantic_mcp_server.token
|
|
329
331
|
)
|
|
330
332
|
else:
|
|
331
333
|
raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}")
|
|
@@ -347,6 +349,17 @@ class MCPManager:
|
|
|
347
349
|
try:
|
|
348
350
|
# Set the organization id at the ORM layer
|
|
349
351
|
pydantic_mcp_server.organization_id = actor.organization_id
|
|
352
|
+
|
|
353
|
+
# Explicitly populate encrypted fields
|
|
354
|
+
if pydantic_mcp_server.token is not None:
|
|
355
|
+
pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token)
|
|
356
|
+
if pydantic_mcp_server.custom_headers is not None:
|
|
357
|
+
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
|
|
358
|
+
import json
|
|
359
|
+
|
|
360
|
+
json_str = json.dumps(pydantic_mcp_server.custom_headers)
|
|
361
|
+
pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str)
|
|
362
|
+
|
|
350
363
|
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
|
351
364
|
|
|
352
365
|
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
|
@@ -384,7 +397,6 @@ class MCPManager:
|
|
|
384
397
|
return mcp_server.to_pydantic()
|
|
385
398
|
except Exception as e:
|
|
386
399
|
await session.rollback()
|
|
387
|
-
logger.error(f"Failed to create MCP server: {e}")
|
|
388
400
|
raise
|
|
389
401
|
|
|
390
402
|
@enforce_types
|
|
@@ -412,7 +424,9 @@ class MCPManager:
|
|
|
412
424
|
token_secret = Secret.from_plaintext(token)
|
|
413
425
|
mcp_server.set_token_secret(token_secret)
|
|
414
426
|
if server_config.custom_headers:
|
|
415
|
-
|
|
427
|
+
# Convert dict to JSON string, then encrypt as Secret
|
|
428
|
+
headers_json = json.dumps(server_config.custom_headers)
|
|
429
|
+
headers_secret = Secret.from_plaintext(headers_json)
|
|
416
430
|
mcp_server.set_custom_headers_secret(headers_secret)
|
|
417
431
|
|
|
418
432
|
elif isinstance(server_config, StreamableHTTPServerConfig):
|
|
@@ -427,7 +441,9 @@ class MCPManager:
|
|
|
427
441
|
token_secret = Secret.from_plaintext(token)
|
|
428
442
|
mcp_server.set_token_secret(token_secret)
|
|
429
443
|
if server_config.custom_headers:
|
|
430
|
-
|
|
444
|
+
# Convert dict to JSON string, then encrypt as Secret
|
|
445
|
+
headers_json = json.dumps(server_config.custom_headers)
|
|
446
|
+
headers_secret = Secret.from_plaintext(headers_json)
|
|
431
447
|
mcp_server.set_custom_headers_secret(headers_secret)
|
|
432
448
|
else:
|
|
433
449
|
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
|
@@ -517,27 +533,52 @@ class MCPManager:
|
|
|
517
533
|
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
|
|
518
534
|
|
|
519
535
|
# Handle encryption for token if provided
|
|
536
|
+
# Only re-encrypt if the value has actually changed
|
|
520
537
|
if "token" in update_data and update_data["token"] is not None:
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
538
|
+
# Check if value changed
|
|
539
|
+
existing_token = None
|
|
540
|
+
if mcp_server.token_enc:
|
|
541
|
+
existing_secret = Secret.from_encrypted(mcp_server.token_enc)
|
|
542
|
+
existing_token = existing_secret.get_plaintext()
|
|
543
|
+
elif mcp_server.token:
|
|
544
|
+
existing_token = mcp_server.token
|
|
545
|
+
|
|
546
|
+
# Only re-encrypt if different
|
|
547
|
+
if existing_token != update_data["token"]:
|
|
548
|
+
mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted()
|
|
549
|
+
# Keep plaintext for dual-write during migration
|
|
550
|
+
mcp_server.token = update_data["token"]
|
|
551
|
+
|
|
552
|
+
# Remove from update_data since we set directly on mcp_server
|
|
553
|
+
update_data.pop("token", None)
|
|
554
|
+
update_data.pop("token_enc", None)
|
|
529
555
|
|
|
530
556
|
# Handle encryption for custom_headers if provided
|
|
557
|
+
# Only re-encrypt if the value has actually changed
|
|
531
558
|
if "custom_headers" in update_data:
|
|
532
559
|
if update_data["custom_headers"] is not None:
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
560
|
+
# custom_headers is a Dict[str, str], serialize to JSON then encrypt
|
|
561
|
+
import json
|
|
562
|
+
|
|
563
|
+
json_str = json.dumps(update_data["custom_headers"])
|
|
564
|
+
|
|
565
|
+
# Check if value changed
|
|
566
|
+
existing_headers_json = None
|
|
567
|
+
if mcp_server.custom_headers_enc:
|
|
568
|
+
existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc)
|
|
569
|
+
existing_headers_json = existing_secret.get_plaintext()
|
|
570
|
+
elif mcp_server.custom_headers:
|
|
571
|
+
existing_headers_json = json.dumps(mcp_server.custom_headers)
|
|
572
|
+
|
|
573
|
+
# Only re-encrypt if different
|
|
574
|
+
if existing_headers_json != json_str:
|
|
575
|
+
mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted()
|
|
576
|
+
# Keep plaintext for dual-write during migration
|
|
577
|
+
mcp_server.custom_headers = update_data["custom_headers"]
|
|
578
|
+
|
|
579
|
+
# Remove from update_data since we set directly on mcp_server
|
|
580
|
+
update_data.pop("custom_headers", None)
|
|
581
|
+
update_data.pop("custom_headers_enc", None)
|
|
541
582
|
else:
|
|
542
583
|
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
|
543
584
|
update_data.pop("custom_headers", None)
|
|
@@ -758,7 +799,8 @@ class MCPManager:
|
|
|
758
799
|
# If no OAuth provider is provided, check if we have stored OAuth credentials
|
|
759
800
|
if oauth_provider is None and hasattr(server_config, "server_url"):
|
|
760
801
|
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
|
761
|
-
if
|
|
802
|
+
# Check if access token exists by attempting to decrypt it
|
|
803
|
+
if oauth_session and oauth_session.get_access_token_secret().get_plaintext():
|
|
762
804
|
# Create OAuth provider from stored credentials
|
|
763
805
|
from letta.services.mcp.oauth_utils import create_oauth_provider
|
|
764
806
|
|
|
@@ -787,8 +829,6 @@ class MCPManager:
|
|
|
787
829
|
"""
|
|
788
830
|
Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields.
|
|
789
831
|
"""
|
|
790
|
-
from letta.settings import settings
|
|
791
|
-
|
|
792
832
|
# Get decrypted values using the dual-read approach
|
|
793
833
|
# Secret.from_db() will automatically use settings.encryption_key if available
|
|
794
834
|
access_token = None
|
|
@@ -818,7 +858,17 @@ class MCPManager:
|
|
|
818
858
|
# No encryption key, use plaintext if available
|
|
819
859
|
client_secret = oauth_session.client_secret
|
|
820
860
|
|
|
821
|
-
|
|
861
|
+
authorization_code = None
|
|
862
|
+
if oauth_session.authorization_code_enc or oauth_session.authorization_code:
|
|
863
|
+
if settings.encryption_key:
|
|
864
|
+
secret = Secret.from_db(oauth_session.authorization_code_enc, oauth_session.authorization_code)
|
|
865
|
+
authorization_code = secret.get_plaintext()
|
|
866
|
+
else:
|
|
867
|
+
# No encryption key, use plaintext if available
|
|
868
|
+
authorization_code = oauth_session.authorization_code
|
|
869
|
+
|
|
870
|
+
# Create the Pydantic object with encrypted fields as Secret objects
|
|
871
|
+
pydantic_session = MCPOAuthSession(
|
|
822
872
|
id=oauth_session.id,
|
|
823
873
|
state=oauth_session.state,
|
|
824
874
|
server_id=oauth_session.server_id,
|
|
@@ -827,7 +877,7 @@ class MCPManager:
|
|
|
827
877
|
user_id=oauth_session.user_id,
|
|
828
878
|
organization_id=oauth_session.organization_id,
|
|
829
879
|
authorization_url=oauth_session.authorization_url,
|
|
830
|
-
authorization_code=
|
|
880
|
+
authorization_code=authorization_code,
|
|
831
881
|
access_token=access_token,
|
|
832
882
|
refresh_token=refresh_token,
|
|
833
883
|
token_type=oauth_session.token_type,
|
|
@@ -839,7 +889,15 @@ class MCPManager:
|
|
|
839
889
|
status=oauth_session.status,
|
|
840
890
|
created_at=oauth_session.created_at,
|
|
841
891
|
updated_at=oauth_session.updated_at,
|
|
892
|
+
# Encrypted fields as Secret objects (converted from encrypted strings in DB)
|
|
893
|
+
authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc)
|
|
894
|
+
if oauth_session.authorization_code_enc
|
|
895
|
+
else None,
|
|
896
|
+
access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None,
|
|
897
|
+
refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None,
|
|
898
|
+
client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None,
|
|
842
899
|
)
|
|
900
|
+
return pydantic_session
|
|
843
901
|
|
|
844
902
|
@enforce_types
|
|
845
903
|
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
|
|
@@ -905,38 +963,57 @@ class MCPManager:
|
|
|
905
963
|
# Update fields that are provided
|
|
906
964
|
if session_update.authorization_url is not None:
|
|
907
965
|
oauth_session.authorization_url = session_update.authorization_url
|
|
966
|
+
|
|
967
|
+
# Handle encryption for authorization_code
|
|
968
|
+
# Only re-encrypt if the value has actually changed
|
|
908
969
|
if session_update.authorization_code is not None:
|
|
909
|
-
|
|
970
|
+
# Check if value changed
|
|
971
|
+
existing_code = None
|
|
972
|
+
if oauth_session.authorization_code_enc:
|
|
973
|
+
existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc)
|
|
974
|
+
existing_code = existing_secret.get_plaintext()
|
|
975
|
+
elif oauth_session.authorization_code:
|
|
976
|
+
existing_code = oauth_session.authorization_code
|
|
977
|
+
|
|
978
|
+
# Only re-encrypt if different
|
|
979
|
+
if existing_code != session_update.authorization_code:
|
|
980
|
+
oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted()
|
|
981
|
+
# Keep plaintext for dual-write during migration
|
|
982
|
+
oauth_session.authorization_code = session_update.authorization_code
|
|
910
983
|
|
|
911
984
|
# Handle encryption for access_token
|
|
985
|
+
# Only re-encrypt if the value has actually changed
|
|
912
986
|
if session_update.access_token is not None:
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
if
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
987
|
+
# Check if value changed
|
|
988
|
+
existing_token = None
|
|
989
|
+
if oauth_session.access_token_enc:
|
|
990
|
+
existing_secret = Secret.from_encrypted(oauth_session.access_token_enc)
|
|
991
|
+
existing_token = existing_secret.get_plaintext()
|
|
992
|
+
elif oauth_session.access_token:
|
|
993
|
+
existing_token = oauth_session.access_token
|
|
994
|
+
|
|
995
|
+
# Only re-encrypt if different
|
|
996
|
+
if existing_token != session_update.access_token:
|
|
997
|
+
oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted()
|
|
998
|
+
# Keep plaintext for dual-write during migration
|
|
923
999
|
oauth_session.access_token = session_update.access_token
|
|
924
|
-
oauth_session.access_token_enc = None
|
|
925
1000
|
|
|
926
1001
|
# Handle encryption for refresh_token
|
|
1002
|
+
# Only re-encrypt if the value has actually changed
|
|
927
1003
|
if session_update.refresh_token is not None:
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
if
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
1004
|
+
# Check if value changed
|
|
1005
|
+
existing_refresh = None
|
|
1006
|
+
if oauth_session.refresh_token_enc:
|
|
1007
|
+
existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc)
|
|
1008
|
+
existing_refresh = existing_secret.get_plaintext()
|
|
1009
|
+
elif oauth_session.refresh_token:
|
|
1010
|
+
existing_refresh = oauth_session.refresh_token
|
|
1011
|
+
|
|
1012
|
+
# Only re-encrypt if different
|
|
1013
|
+
if existing_refresh != session_update.refresh_token:
|
|
1014
|
+
oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted()
|
|
1015
|
+
# Keep plaintext for dual-write during migration
|
|
938
1016
|
oauth_session.refresh_token = session_update.refresh_token
|
|
939
|
-
oauth_session.refresh_token_enc = None
|
|
940
1017
|
|
|
941
1018
|
if session_update.token_type is not None:
|
|
942
1019
|
oauth_session.token_type = session_update.token_type
|
|
@@ -948,19 +1025,21 @@ class MCPManager:
|
|
|
948
1025
|
oauth_session.client_id = session_update.client_id
|
|
949
1026
|
|
|
950
1027
|
# Handle encryption for client_secret
|
|
1028
|
+
# Only re-encrypt if the value has actually changed
|
|
951
1029
|
if session_update.client_secret is not None:
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
if
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
1030
|
+
# Check if value changed
|
|
1031
|
+
existing_secret_val = None
|
|
1032
|
+
if oauth_session.client_secret_enc:
|
|
1033
|
+
existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc)
|
|
1034
|
+
existing_secret_val = existing_secret.get_plaintext()
|
|
1035
|
+
elif oauth_session.client_secret:
|
|
1036
|
+
existing_secret_val = oauth_session.client_secret
|
|
1037
|
+
|
|
1038
|
+
# Only re-encrypt if different
|
|
1039
|
+
if existing_secret_val != session_update.client_secret:
|
|
1040
|
+
oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted()
|
|
1041
|
+
# Keep plaintext for dual-write during migration
|
|
962
1042
|
oauth_session.client_secret = session_update.client_secret
|
|
963
|
-
oauth_session.client_secret_enc = None
|
|
964
1043
|
|
|
965
1044
|
if session_update.redirect_uri is not None:
|
|
966
1045
|
oauth_session.redirect_uri = session_update.redirect_uri
|
|
@@ -1073,18 +1152,37 @@ class MCPManager:
|
|
|
1073
1152
|
# Get authorization URL by triggering OAuth flow
|
|
1074
1153
|
temp_client = None
|
|
1075
1154
|
connect_task = None
|
|
1155
|
+
|
|
1156
|
+
async def connect_and_cleanup(client: AsyncBaseMCPClient, ready_queue: asyncio.Queue):
|
|
1157
|
+
"""Wrap connection and cleanup in the same task to share cancel scope"""
|
|
1158
|
+
try:
|
|
1159
|
+
await client.connect_to_server()
|
|
1160
|
+
# Send client to main task without finishing the task
|
|
1161
|
+
await ready_queue.put(client)
|
|
1162
|
+
# Now wait for signal to cleanup
|
|
1163
|
+
await client._cleanup_event.wait()
|
|
1164
|
+
finally:
|
|
1165
|
+
await client.cleanup()
|
|
1166
|
+
|
|
1076
1167
|
try:
|
|
1168
|
+
ready_queue = asyncio.Queue()
|
|
1077
1169
|
temp_client = await self.get_mcp_client(request, actor, oauth_provider)
|
|
1170
|
+
temp_client._cleanup_event = asyncio.Event()
|
|
1078
1171
|
|
|
1079
1172
|
# Run connect_to_server in background to avoid blocking
|
|
1080
1173
|
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
|
|
1081
|
-
connect_task =
|
|
1082
|
-
|
|
1083
|
-
# Give the OAuth flow time to trigger and save the URL
|
|
1084
|
-
await asyncio.sleep(1.0)
|
|
1174
|
+
connect_task = safe_create_task_with_return(connect_and_cleanup(temp_client, ready_queue), label="mcp_oauth_connect")
|
|
1085
1175
|
|
|
1086
1176
|
# Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
|
|
1087
1177
|
auth_session = await self.get_oauth_session_by_id(session_id, actor)
|
|
1178
|
+
|
|
1179
|
+
# Give the OAuth flow time to connect to the MCP server and store the authorization URL
|
|
1180
|
+
timeout = 0
|
|
1181
|
+
while not auth_session or not auth_session.authorization_url and not connect_task.done() and timeout < 10:
|
|
1182
|
+
timeout += 1
|
|
1183
|
+
auth_session = await self.get_oauth_session_by_id(session_id, actor)
|
|
1184
|
+
await asyncio.sleep(1.0)
|
|
1185
|
+
|
|
1088
1186
|
if auth_session and auth_session.authorization_url:
|
|
1089
1187
|
yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
|
|
1090
1188
|
|
|
@@ -1092,11 +1190,14 @@ class MCPManager:
|
|
|
1092
1190
|
yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
|
|
1093
1191
|
|
|
1094
1192
|
# Callback handler will poll for authorization code and state and update the OAuth session
|
|
1095
|
-
|
|
1096
|
-
|
|
1193
|
+
# Get the client from the queue
|
|
1194
|
+
temp_client = await ready_queue.get()
|
|
1097
1195
|
tools = await temp_client.list_tools(serialize=True)
|
|
1098
1196
|
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
|
1099
1197
|
|
|
1198
|
+
# Signal the background task to cleanup in its own task
|
|
1199
|
+
temp_client._cleanup_event.set()
|
|
1200
|
+
await connect_task # now it finishes safely
|
|
1100
1201
|
except Exception as e:
|
|
1101
1202
|
logger.error(f"Error triggering OAuth flow: {e}")
|
|
1102
1203
|
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
|
|
@@ -1109,8 +1210,3 @@ class MCPManager:
|
|
|
1109
1210
|
await connect_task
|
|
1110
1211
|
except asyncio.CancelledError:
|
|
1111
1212
|
pass
|
|
1112
|
-
if temp_client:
|
|
1113
|
-
try:
|
|
1114
|
-
await temp_client.cleanup()
|
|
1115
|
-
except Exception as cleanup_error:
|
|
1116
|
-
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|