letta-nightly 0.9.0.dev20250726104256__py3-none-any.whl → 0.9.1.dev20250727104258__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agents/base_agent.py +1 -1
- letta/agents/letta_agent.py +6 -0
- letta/helpers/datetime_helpers.py +1 -1
- letta/helpers/json_helpers.py +1 -1
- letta/orm/agent.py +2 -3
- letta/orm/agents_tags.py +1 -0
- letta/orm/block.py +2 -2
- letta/orm/group.py +2 -2
- letta/orm/identity.py +3 -4
- letta/orm/mcp_oauth.py +62 -0
- letta/orm/step.py +2 -4
- letta/schemas/agent_file.py +31 -5
- letta/schemas/block.py +3 -0
- letta/schemas/enums.py +4 -0
- letta/schemas/group.py +3 -0
- letta/schemas/mcp.py +70 -0
- letta/schemas/memory.py +35 -0
- letta/schemas/message.py +98 -91
- letta/schemas/providers/openai.py +1 -1
- letta/server/rest_api/app.py +19 -21
- letta/server/rest_api/middleware/__init__.py +4 -0
- letta/server/rest_api/middleware/check_password.py +24 -0
- letta/server/rest_api/middleware/profiler_context.py +25 -0
- letta/server/rest_api/routers/v1/blocks.py +2 -0
- letta/server/rest_api/routers/v1/groups.py +1 -1
- letta/server/rest_api/routers/v1/sources.py +26 -0
- letta/server/rest_api/routers/v1/tools.py +224 -23
- letta/services/agent_manager.py +15 -9
- letta/services/agent_serialization_manager.py +84 -3
- letta/services/block_manager.py +4 -0
- letta/services/file_manager.py +23 -13
- letta/services/file_processor/file_processor.py +12 -10
- letta/services/mcp/base_client.py +20 -28
- letta/services/mcp/oauth_utils.py +433 -0
- letta/services/mcp/sse_client.py +12 -1
- letta/services/mcp/streamable_http_client.py +17 -5
- letta/services/mcp/types.py +9 -0
- letta/services/mcp_manager.py +304 -42
- letta/services/provider_manager.py +2 -2
- letta/services/tool_executor/tool_executor.py +6 -2
- letta/services/tool_manager.py +8 -4
- letta/services/tool_sandbox/base.py +3 -3
- letta/services/tool_sandbox/e2b_sandbox.py +1 -1
- letta/services/tool_sandbox/local_sandbox.py +16 -9
- letta/settings.py +11 -1
- letta/system.py +1 -1
- letta/templates/template_helper.py +25 -1
- letta/utils.py +19 -35
- {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/METADATA +3 -2
- {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/RECORD +54 -49
- {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/LICENSE +0 -0
- {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/WHEEL +0 -0
- {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,6 @@
|
|
1
|
+
import asyncio
|
1
2
|
import json
|
3
|
+
from collections.abc import AsyncGenerator
|
2
4
|
from typing import Any, Dict, List, Optional, Union
|
3
5
|
|
4
6
|
from composio.client import ComposioClientError, HTTPError, NoItemsFound
|
@@ -11,27 +13,38 @@ from composio.exceptions import (
|
|
11
13
|
EnumStringNotFound,
|
12
14
|
)
|
13
15
|
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
16
|
+
from fastapi.responses import HTMLResponse
|
14
17
|
from pydantic import BaseModel, Field
|
18
|
+
from starlette.responses import StreamingResponse
|
15
19
|
|
16
20
|
from letta.errors import LettaToolCreateError
|
17
21
|
from letta.functions.functions import derive_openai_json_schema
|
18
22
|
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
19
|
-
from letta.functions.mcp_client.types import
|
23
|
+
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
20
24
|
from letta.helpers.composio_helpers import get_composio_api_key
|
25
|
+
from letta.helpers.decorators import deprecated
|
21
26
|
from letta.llm_api.llm_client import LLMClient
|
22
27
|
from letta.log import get_logger
|
23
28
|
from letta.orm.errors import UniqueConstraintViolationError
|
29
|
+
from letta.orm.mcp_oauth import OAuthSessionStatus
|
24
30
|
from letta.schemas.enums import MessageRole
|
25
31
|
from letta.schemas.letta_message import ToolReturnMessage
|
26
32
|
from letta.schemas.letta_message_content import TextContent
|
27
|
-
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
33
|
+
from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
28
34
|
from letta.schemas.message import Message
|
29
35
|
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
36
|
+
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
30
37
|
from letta.server.rest_api.utils import get_letta_server
|
31
38
|
from letta.server.server import SyncServer
|
32
|
-
from letta.services.mcp.
|
39
|
+
from letta.services.mcp.oauth_utils import (
|
40
|
+
MCPOAuthSession,
|
41
|
+
create_oauth_provider,
|
42
|
+
drill_down_exception,
|
43
|
+
get_oauth_success_html,
|
44
|
+
oauth_stream_event,
|
45
|
+
)
|
33
46
|
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
34
|
-
from letta.services.mcp.
|
47
|
+
from letta.services.mcp.types import OauthStreamEvent
|
35
48
|
from letta.settings import tool_settings
|
36
49
|
|
37
50
|
router = APIRouter(prefix="/tools", tags=["tools"])
|
@@ -475,7 +488,11 @@ async def add_mcp_tool(
|
|
475
488
|
)
|
476
489
|
|
477
490
|
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
478
|
-
|
491
|
+
# For config-based servers, use the server name as ID since they don't have database IDs
|
492
|
+
mcp_server_id = mcp_server_name
|
493
|
+
return await server.tool_manager.create_mcp_tool_async(
|
494
|
+
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
495
|
+
)
|
479
496
|
|
480
497
|
else:
|
481
498
|
return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor)
|
@@ -608,35 +625,26 @@ async def delete_mcp_server_from_config(
|
|
608
625
|
return [server.to_config() for server in all_servers]
|
609
626
|
|
610
627
|
|
611
|
-
@
|
628
|
+
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
|
629
|
+
@router.post("/mcp/servers/test", operation_id="test_mcp_server")
|
612
630
|
async def test_mcp_server(
|
613
631
|
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
|
632
|
+
server: SyncServer = Depends(get_letta_server),
|
633
|
+
actor_id: Optional[str] = Header(None, alias="user_id"),
|
614
634
|
):
|
615
635
|
"""
|
616
636
|
Test connection to an MCP server without adding it.
|
617
|
-
Returns the list of available tools if successful.
|
637
|
+
Returns the list of available tools if successful, or OAuth information if OAuth is required.
|
618
638
|
"""
|
619
639
|
client = None
|
620
640
|
try:
|
621
|
-
|
622
|
-
|
623
|
-
if not isinstance(request, SSEServerConfig):
|
624
|
-
request = SSEServerConfig(**request.model_dump())
|
625
|
-
client = AsyncSSEMCPClient(request)
|
626
|
-
elif request.type == MCPServerType.STREAMABLE_HTTP:
|
627
|
-
if not isinstance(request, StreamableHTTPServerConfig):
|
628
|
-
request = StreamableHTTPServerConfig(**request.model_dump())
|
629
|
-
client = AsyncStreamableHTTPMCPClient(request)
|
630
|
-
elif request.type == MCPServerType.STDIO:
|
631
|
-
if not isinstance(request, StdioServerConfig):
|
632
|
-
request = StdioServerConfig(**request.model_dump())
|
633
|
-
client = AsyncStdioMCPClient(request)
|
634
|
-
else:
|
635
|
-
raise ValueError(f"Invalid MCP server type: {request.type}")
|
641
|
+
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
642
|
+
client = await server.mcp_manager.get_mcp_client(request, actor)
|
636
643
|
|
637
644
|
await client.connect_to_server()
|
638
645
|
tools = await client.list_tools()
|
639
|
-
|
646
|
+
|
647
|
+
return {"status": "success", "tools": tools}
|
640
648
|
except ConnectionError as e:
|
641
649
|
raise HTTPException(
|
642
650
|
status_code=400,
|
@@ -672,6 +680,160 @@ async def test_mcp_server(
|
|
672
680
|
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
|
673
681
|
|
674
682
|
|
683
|
+
@router.post(
|
684
|
+
"/mcp/servers/connect",
|
685
|
+
response_model=None,
|
686
|
+
responses={
|
687
|
+
200: {
|
688
|
+
"description": "Successful response",
|
689
|
+
"content": {
|
690
|
+
"text/event-stream": {"description": "Server-Sent Events stream"},
|
691
|
+
},
|
692
|
+
}
|
693
|
+
},
|
694
|
+
operation_id="connect_mcp_server",
|
695
|
+
)
|
696
|
+
async def connect_mcp_server(
|
697
|
+
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
|
698
|
+
server: SyncServer = Depends(get_letta_server),
|
699
|
+
actor_id: Optional[str] = Header(None, alias="user_id"),
|
700
|
+
) -> StreamingResponse:
|
701
|
+
"""
|
702
|
+
Connect to an MCP server with support for OAuth via SSE.
|
703
|
+
Returns a stream of events handling authorization state and exchange if OAuth is required.
|
704
|
+
"""
|
705
|
+
|
706
|
+
async def oauth_stream_generator(
|
707
|
+
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
|
708
|
+
) -> AsyncGenerator[str, None]:
|
709
|
+
client = None
|
710
|
+
oauth_provider = None
|
711
|
+
temp_client = None
|
712
|
+
connect_task = None
|
713
|
+
|
714
|
+
try:
|
715
|
+
# Acknowledge connection attempt
|
716
|
+
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
|
717
|
+
|
718
|
+
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
719
|
+
|
720
|
+
# Create MCP client with respective transport type
|
721
|
+
try:
|
722
|
+
client = await server.mcp_manager.get_mcp_client(request, actor)
|
723
|
+
except ValueError as e:
|
724
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
|
725
|
+
return
|
726
|
+
|
727
|
+
# Try normal connection first for flows that don't require OAuth
|
728
|
+
try:
|
729
|
+
await client.connect_to_server()
|
730
|
+
tools = await client.list_tools(serialize=True)
|
731
|
+
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
732
|
+
return
|
733
|
+
except ConnectionError:
|
734
|
+
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
735
|
+
if isinstance(client, AsyncStdioMCPClient):
|
736
|
+
logger.warning(f"OAuth not supported for stdio")
|
737
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"OAuth not supported for stdio")
|
738
|
+
return
|
739
|
+
# Continue to OAuth flow
|
740
|
+
logger.info(f"Attempting OAuth flow for {request}...")
|
741
|
+
except Exception as e:
|
742
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
|
743
|
+
return
|
744
|
+
|
745
|
+
# OAuth required, yield state to client to prepare to handle authorization URL
|
746
|
+
yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required")
|
747
|
+
|
748
|
+
# Create OAuth session to persist the state of the OAuth flow
|
749
|
+
session_create = MCPOAuthSessionCreate(
|
750
|
+
server_url=request.server_url,
|
751
|
+
server_name=request.server_name,
|
752
|
+
user_id=actor.id,
|
753
|
+
organization_id=actor.organization_id,
|
754
|
+
)
|
755
|
+
oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor)
|
756
|
+
session_id = oauth_session.id
|
757
|
+
|
758
|
+
# Create OAuth provider for the instance of the stream connection
|
759
|
+
# Note: Using the correct API path for the callback
|
760
|
+
# Do not edit this; this is the correct URL
|
761
|
+
redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}"
|
762
|
+
oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor)
|
763
|
+
|
764
|
+
# Get authorization URL by triggering OAuth flow
|
765
|
+
temp_client = None
|
766
|
+
try:
|
767
|
+
temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider)
|
768
|
+
|
769
|
+
# Run connect_to_server in background to avoid blocking
|
770
|
+
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
|
771
|
+
connect_task = asyncio.create_task(temp_client.connect_to_server())
|
772
|
+
|
773
|
+
# Give the OAuth flow time to trigger and save the URL
|
774
|
+
await asyncio.sleep(1.0)
|
775
|
+
|
776
|
+
# Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
|
777
|
+
auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor)
|
778
|
+
if auth_session and auth_session.authorization_url:
|
779
|
+
yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
|
780
|
+
|
781
|
+
except Exception as e:
|
782
|
+
logger.error(f"Error triggering OAuth flow: {e}")
|
783
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
|
784
|
+
|
785
|
+
# Clean up active resources
|
786
|
+
if connect_task and not connect_task.done():
|
787
|
+
connect_task.cancel()
|
788
|
+
try:
|
789
|
+
await connect_task
|
790
|
+
except asyncio.CancelledError:
|
791
|
+
pass
|
792
|
+
if temp_client:
|
793
|
+
try:
|
794
|
+
await temp_client.cleanup()
|
795
|
+
except Exception as cleanup_error:
|
796
|
+
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|
797
|
+
return
|
798
|
+
|
799
|
+
# Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit
|
800
|
+
yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
|
801
|
+
|
802
|
+
# Callback handler will poll for authorization code and state and update the OAuth session
|
803
|
+
await connect_task
|
804
|
+
|
805
|
+
tools = await temp_client.list_tools(serialize=True)
|
806
|
+
|
807
|
+
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
808
|
+
return
|
809
|
+
except Exception as e:
|
810
|
+
detailed_error = drill_down_exception(e)
|
811
|
+
logger.error(f"Error in OAuth stream:\n{detailed_error}")
|
812
|
+
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
|
813
|
+
finally:
|
814
|
+
if connect_task and not connect_task.done():
|
815
|
+
connect_task.cancel()
|
816
|
+
try:
|
817
|
+
await connect_task
|
818
|
+
except asyncio.CancelledError:
|
819
|
+
pass
|
820
|
+
if client:
|
821
|
+
try:
|
822
|
+
await client.cleanup()
|
823
|
+
except Exception as cleanup_error:
|
824
|
+
detailed_error = drill_down_exception(cleanup_error)
|
825
|
+
logger.warning(f"Error during MCP client cleanup: {detailed_error}")
|
826
|
+
if temp_client:
|
827
|
+
try:
|
828
|
+
await temp_client.cleanup()
|
829
|
+
except Exception as cleanup_error:
|
830
|
+
# TODO: @jnjpng fix async cancel scope issue
|
831
|
+
# detailed_error = drill_down_exception(cleanup_error)
|
832
|
+
logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}")
|
833
|
+
|
834
|
+
return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream")
|
835
|
+
|
836
|
+
|
675
837
|
class CodeInput(BaseModel):
|
676
838
|
code: str = Field(..., description="Python source code to parse for JSON schema")
|
677
839
|
|
@@ -693,6 +855,45 @@ async def generate_json_schema(
|
|
693
855
|
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
|
694
856
|
|
695
857
|
|
858
|
+
# TODO: @jnjpng need to route this through cloud API for production
|
859
|
+
@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse)
|
860
|
+
async def mcp_oauth_callback(
|
861
|
+
session_id: str,
|
862
|
+
code: Optional[str] = Query(None, description="OAuth authorization code"),
|
863
|
+
state: Optional[str] = Query(None, description="OAuth state parameter"),
|
864
|
+
error: Optional[str] = Query(None, description="OAuth error"),
|
865
|
+
error_description: Optional[str] = Query(None, description="OAuth error description"),
|
866
|
+
):
|
867
|
+
"""
|
868
|
+
Handle OAuth callback for MCP server authentication.
|
869
|
+
"""
|
870
|
+
try:
|
871
|
+
oauth_session = MCPOAuthSession(session_id)
|
872
|
+
|
873
|
+
if error:
|
874
|
+
error_msg = f"OAuth error: {error}"
|
875
|
+
if error_description:
|
876
|
+
error_msg += f" - {error_description}"
|
877
|
+
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
878
|
+
return {"status": "error", "message": error_msg}
|
879
|
+
|
880
|
+
if not code or not state:
|
881
|
+
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
882
|
+
return {"status": "error", "message": "Missing authorization code or state"}
|
883
|
+
|
884
|
+
# Store authorization code
|
885
|
+
success = await oauth_session.store_authorization_code(code, state)
|
886
|
+
if not success:
|
887
|
+
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
888
|
+
return {"status": "error", "message": "Invalid state parameter"}
|
889
|
+
|
890
|
+
return HTMLResponse(content=get_oauth_success_html(), status_code=200)
|
891
|
+
|
892
|
+
except Exception as e:
|
893
|
+
logger.error(f"OAuth callback error: {e}")
|
894
|
+
return {"status": "error", "message": f"OAuth callback failed: {str(e)}"}
|
895
|
+
|
896
|
+
|
696
897
|
class GenerateToolInput(BaseModel):
|
697
898
|
tool_name: str = Field(..., description="Name of the tool to generate code for")
|
698
899
|
prompt: str = Field(..., description="User prompt to generate code")
|
letta/services/agent_manager.py
CHANGED
@@ -15,6 +15,8 @@ from letta.constants import (
|
|
15
15
|
BASE_TOOLS,
|
16
16
|
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
17
17
|
BASE_VOICE_SLEEPTIME_TOOLS,
|
18
|
+
DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
19
|
+
DEFAULT_MAX_FILES_OPEN,
|
18
20
|
DEFAULT_TIMEZONE,
|
19
21
|
DEPRECATED_LETTA_TOOLS,
|
20
22
|
FILES_TOOLS,
|
@@ -1644,7 +1646,7 @@ class AgentManager:
|
|
1644
1646
|
|
1645
1647
|
# note: we only update the system prompt if the core memory is changed
|
1646
1648
|
# this means that the archival/recall memory statistics may be somewhat out of date
|
1647
|
-
curr_memory_str = agent_state.memory.
|
1649
|
+
curr_memory_str = await agent_state.memory.compile_async(
|
1648
1650
|
sources=agent_state.sources,
|
1649
1651
|
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
|
1650
1652
|
max_files_open=agent_state.max_files_open,
|
@@ -1834,14 +1836,12 @@ class AgentManager:
|
|
1834
1836
|
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
|
1835
1837
|
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
1836
1838
|
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
1837
|
-
|
1838
|
-
|
1839
|
-
|
1840
|
-
|
1841
|
-
|
1842
|
-
|
1843
|
-
not in system_message.content[0].text
|
1844
|
-
):
|
1839
|
+
new_memory_str = await new_memory.compile_async(
|
1840
|
+
sources=agent_state.sources,
|
1841
|
+
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
|
1842
|
+
max_files_open=agent_state.max_files_open,
|
1843
|
+
)
|
1844
|
+
if new_memory_str not in system_message.content[0].text:
|
1845
1845
|
# update the blocks (LRW) in the DB
|
1846
1846
|
for label in agent_state.memory.list_block_labels():
|
1847
1847
|
updated_value = new_memory.get_block(label).value
|
@@ -3169,6 +3169,12 @@ class AgentManager:
|
|
3169
3169
|
if max_files is None:
|
3170
3170
|
max_files = default_max_files
|
3171
3171
|
|
3172
|
+
# FINAL fallback: ensure neither is None (should never happen, but just in case)
|
3173
|
+
if per_file_limit is None:
|
3174
|
+
per_file_limit = DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT
|
3175
|
+
if max_files is None:
|
3176
|
+
max_files = DEFAULT_MAX_FILES_OPEN
|
3177
|
+
|
3172
3178
|
return per_file_limit, max_files
|
3173
3179
|
|
3174
3180
|
@enforce_types
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from datetime import datetime, timezone
|
2
2
|
from typing import Dict, List
|
3
3
|
|
4
|
+
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
4
5
|
from letta.errors import AgentFileExportError, AgentFileImportError
|
5
6
|
from letta.helpers.pinecone_utils import should_use_pinecone
|
6
7
|
from letta.log import get_logger
|
@@ -13,6 +14,7 @@ from letta.schemas.agent_file import (
|
|
13
14
|
FileSchema,
|
14
15
|
GroupSchema,
|
15
16
|
ImportResult,
|
17
|
+
MCPServerSchema,
|
16
18
|
MessageSchema,
|
17
19
|
SourceSchema,
|
18
20
|
ToolSchema,
|
@@ -20,6 +22,7 @@ from letta.schemas.agent_file import (
|
|
20
22
|
from letta.schemas.block import Block
|
21
23
|
from letta.schemas.enums import FileProcessingStatus
|
22
24
|
from letta.schemas.file import FileMetadata
|
25
|
+
from letta.schemas.mcp import MCPServer
|
23
26
|
from letta.schemas.message import Message
|
24
27
|
from letta.schemas.source import Source
|
25
28
|
from letta.schemas.tool import Tool
|
@@ -92,7 +95,7 @@ class AgentSerializationManager:
|
|
92
95
|
ToolSchema.__id_prefix__: 0,
|
93
96
|
MessageSchema.__id_prefix__: 0,
|
94
97
|
FileAgentSchema.__id_prefix__: 0,
|
95
|
-
|
98
|
+
MCPServerSchema.__id_prefix__: 0,
|
96
99
|
}
|
97
100
|
|
98
101
|
def _reset_state(self):
|
@@ -258,6 +261,53 @@ class AgentSerializationManager:
|
|
258
261
|
file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False)
|
259
262
|
return file_schema
|
260
263
|
|
264
|
+
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
|
265
|
+
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
|
266
|
+
mcp_server_ids = set()
|
267
|
+
mcp_server_names = set()
|
268
|
+
for tool in tools:
|
269
|
+
# Check if tool has MCP metadata
|
270
|
+
if tool.metadata_ and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
|
271
|
+
mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]
|
272
|
+
# TODO: @jnjpng clean this up once we fully migrate to server_id being the main identifier
|
273
|
+
if "server_id" in mcp_metadata:
|
274
|
+
mcp_server_ids.add(mcp_metadata["server_id"])
|
275
|
+
elif "server_name" in mcp_metadata:
|
276
|
+
mcp_server_names.add(mcp_metadata["server_name"])
|
277
|
+
|
278
|
+
# Fetch MCP servers by ID
|
279
|
+
mcp_servers = []
|
280
|
+
fetched_server_ids = set()
|
281
|
+
if mcp_server_ids:
|
282
|
+
try:
|
283
|
+
mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor)
|
284
|
+
fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers])
|
285
|
+
except Exception as e:
|
286
|
+
logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}")
|
287
|
+
|
288
|
+
# Fetch MCP servers by name if not already fetched by ID
|
289
|
+
if mcp_server_names:
|
290
|
+
for server_name in mcp_server_names:
|
291
|
+
try:
|
292
|
+
mcp_server = await self.mcp_manager.get_mcp_server(server_name, actor)
|
293
|
+
if mcp_server and mcp_server.id not in fetched_server_ids:
|
294
|
+
mcp_servers.append(mcp_server)
|
295
|
+
except Exception as e:
|
296
|
+
logger.warning(f"Failed to fetch MCP server by name {server_name}: {e}")
|
297
|
+
|
298
|
+
return mcp_servers
|
299
|
+
|
300
|
+
def _convert_mcp_server_to_schema(self, mcp_server: MCPServer) -> MCPServerSchema:
|
301
|
+
"""Convert MCPServer to MCPServerSchema with ID remapping and auth scrubbing"""
|
302
|
+
try:
|
303
|
+
mcp_file_id = self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__, allow_new=False)
|
304
|
+
mcp_schema = MCPServerSchema.from_mcp_server(mcp_server)
|
305
|
+
mcp_schema.id = mcp_file_id
|
306
|
+
return mcp_schema
|
307
|
+
except Exception as e:
|
308
|
+
logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
|
309
|
+
raise
|
310
|
+
|
261
311
|
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
|
262
312
|
"""
|
263
313
|
Export agents and their related entities to AgentFileSchema format.
|
@@ -289,6 +339,13 @@ class AgentSerializationManager:
|
|
289
339
|
tool_set = self._extract_unique_tools(agent_states)
|
290
340
|
block_set = self._extract_unique_blocks(agent_states)
|
291
341
|
|
342
|
+
# Extract MCP servers from tools BEFORE conversion (must be done before ID mapping)
|
343
|
+
mcp_server_set = await self._extract_unique_mcp_servers(tool_set, actor)
|
344
|
+
|
345
|
+
# Map MCP server IDs before converting schemas
|
346
|
+
for mcp_server in mcp_server_set:
|
347
|
+
self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__)
|
348
|
+
|
292
349
|
# Extract sources and files from agent states BEFORE conversion (with caching)
|
293
350
|
source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache)
|
294
351
|
|
@@ -301,6 +358,7 @@ class AgentSerializationManager:
|
|
301
358
|
block_schemas = [self._convert_block_to_schema(block) for block in block_set]
|
302
359
|
source_schemas = [self._convert_source_to_schema(source) for source in source_set]
|
303
360
|
file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
|
361
|
+
mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
|
304
362
|
|
305
363
|
logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
|
306
364
|
|
@@ -312,7 +370,7 @@ class AgentSerializationManager:
|
|
312
370
|
files=file_schemas,
|
313
371
|
sources=source_schemas,
|
314
372
|
tools=tool_schemas,
|
315
|
-
|
373
|
+
mcp_servers=mcp_server_schemas,
|
316
374
|
metadata={"revision_id": await get_latest_alembic_revision()},
|
317
375
|
created_at=datetime.now(timezone.utc),
|
318
376
|
)
|
@@ -359,7 +417,20 @@ class AgentSerializationManager:
|
|
359
417
|
# in-memory cache for file metadata to avoid repeated db calls
|
360
418
|
file_metadata_cache = {} # Maps database file ID to FileMetadata
|
361
419
|
|
362
|
-
# 1. Create
|
420
|
+
# 1. Create MCP servers first (tools depend on them)
|
421
|
+
if schema.mcp_servers:
|
422
|
+
for mcp_server_schema in schema.mcp_servers:
|
423
|
+
server_data = mcp_server_schema.model_dump(exclude={"id"})
|
424
|
+
filtered_server_data = self._filter_dict_for_model(server_data, MCPServer)
|
425
|
+
create_schema = MCPServer(**filtered_server_data)
|
426
|
+
|
427
|
+
# Note: We don't have auth info from export, so the user will need to re-configure auth.
|
428
|
+
# TODO: @jnjpng store metadata about obfuscated metadata to surface to the user
|
429
|
+
created_mcp_server = await self.mcp_manager.create_or_update_mcp_server(create_schema, actor)
|
430
|
+
file_to_db_ids[mcp_server_schema.id] = created_mcp_server.id
|
431
|
+
imported_count += 1
|
432
|
+
|
433
|
+
# 2. Create tools (may depend on MCP servers) - using bulk upsert for efficiency
|
363
434
|
if schema.tools:
|
364
435
|
# convert tool schemas to pydantic tools
|
365
436
|
pydantic_tools = []
|
@@ -559,6 +630,7 @@ class AgentSerializationManager:
|
|
559
630
|
(schema.files, FileSchema.__id_prefix__),
|
560
631
|
(schema.sources, SourceSchema.__id_prefix__),
|
561
632
|
(schema.tools, ToolSchema.__id_prefix__),
|
633
|
+
(schema.mcp_servers, MCPServerSchema.__id_prefix__),
|
562
634
|
]
|
563
635
|
|
564
636
|
for entities, expected_prefix in entity_checks:
|
@@ -601,6 +673,7 @@ class AgentSerializationManager:
|
|
601
673
|
("files", schema.files),
|
602
674
|
("sources", schema.sources),
|
603
675
|
("tools", schema.tools),
|
676
|
+
("mcp_servers", schema.mcp_servers),
|
604
677
|
]
|
605
678
|
|
606
679
|
for entity_type, entities in entity_collections:
|
@@ -705,3 +778,11 @@ class AgentSerializationManager:
|
|
705
778
|
raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}")
|
706
779
|
|
707
780
|
logger.info("Schema validation passed")
|
781
|
+
|
782
|
+
def _filter_dict_for_model(self, data: dict, model_cls):
|
783
|
+
"""Filter a dictionary to only include keys that are in the model fields"""
|
784
|
+
try:
|
785
|
+
allowed = model_cls.model_fields.keys() # Pydantic v2
|
786
|
+
except AttributeError:
|
787
|
+
allowed = model_cls.__fields__.keys() # Pydantic v1
|
788
|
+
return {k: v for k, v in data.items() if k in allowed}
|
letta/services/block_manager.py
CHANGED
@@ -178,6 +178,7 @@ class BlockManager:
|
|
178
178
|
template_name: Optional[str] = None,
|
179
179
|
identity_id: Optional[str] = None,
|
180
180
|
identifier_keys: Optional[List[str]] = None,
|
181
|
+
project_id: Optional[str] = None,
|
181
182
|
before: Optional[str] = None,
|
182
183
|
after: Optional[str] = None,
|
183
184
|
limit: Optional[int] = 50,
|
@@ -210,6 +211,9 @@ class BlockManager:
|
|
210
211
|
if template_name:
|
211
212
|
query = query.where(BlockModel.template_name == template_name)
|
212
213
|
|
214
|
+
if project_id:
|
215
|
+
query = query.where(BlockModel.project_id == project_id)
|
216
|
+
|
213
217
|
needs_distinct = False
|
214
218
|
|
215
219
|
if identifier_keys:
|
letta/services/file_manager.py
CHANGED
@@ -151,7 +151,8 @@ class FileManager:
|
|
151
151
|
Enforces state transition rules (when enforce_state_transitions=True):
|
152
152
|
- PENDING -> PARSING -> EMBEDDING -> COMPLETED (normal flow)
|
153
153
|
- Any non-terminal state -> ERROR
|
154
|
-
-
|
154
|
+
- Same-state transitions are allowed (e.g., EMBEDDING -> EMBEDDING)
|
155
|
+
- ERROR and COMPLETED are terminal (no status transitions allowed, metadata updates blocked)
|
155
156
|
|
156
157
|
Args:
|
157
158
|
file_id: ID of the file to update
|
@@ -196,22 +197,31 @@ class FileManager:
|
|
196
197
|
]
|
197
198
|
|
198
199
|
# only add state transition validation if enforce_state_transitions is True
|
199
|
-
if enforce_state_transitions:
|
200
|
-
#
|
200
|
+
if enforce_state_transitions and processing_status is not None:
|
201
|
+
# enforce specific transitions based on target status
|
202
|
+
if processing_status == FileProcessingStatus.PARSING:
|
203
|
+
where_conditions.append(
|
204
|
+
FileMetadataModel.processing_status.in_([FileProcessingStatus.PENDING, FileProcessingStatus.PARSING])
|
205
|
+
)
|
206
|
+
elif processing_status == FileProcessingStatus.EMBEDDING:
|
207
|
+
where_conditions.append(
|
208
|
+
FileMetadataModel.processing_status.in_([FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING])
|
209
|
+
)
|
210
|
+
elif processing_status == FileProcessingStatus.COMPLETED:
|
211
|
+
where_conditions.append(
|
212
|
+
FileMetadataModel.processing_status.in_([FileProcessingStatus.EMBEDDING, FileProcessingStatus.COMPLETED])
|
213
|
+
)
|
214
|
+
elif processing_status == FileProcessingStatus.ERROR:
|
215
|
+
# ERROR can be set from any non-terminal state
|
216
|
+
where_conditions.append(
|
217
|
+
FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
|
218
|
+
)
|
219
|
+
elif enforce_state_transitions and processing_status is None:
|
220
|
+
# If only updating metadata fields (not status), prevent updates to terminal states
|
201
221
|
where_conditions.append(
|
202
222
|
FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
|
203
223
|
)
|
204
224
|
|
205
|
-
if processing_status is not None:
|
206
|
-
# enforce specific transitions based on target status
|
207
|
-
if processing_status == FileProcessingStatus.PARSING:
|
208
|
-
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PENDING)
|
209
|
-
elif processing_status == FileProcessingStatus.EMBEDDING:
|
210
|
-
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PARSING)
|
211
|
-
elif processing_status == FileProcessingStatus.COMPLETED:
|
212
|
-
where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.EMBEDDING)
|
213
|
-
# ERROR can be set from any non-terminal state (already handled by terminal check above)
|
214
|
-
|
215
225
|
# fast in-place update with state validation
|
216
226
|
stmt = (
|
217
227
|
update(FileMetadataModel)
|
@@ -69,6 +69,15 @@ class FileProcessor:
|
|
69
69
|
raise ValueError("No chunks created from text")
|
70
70
|
all_chunks.extend(chunks)
|
71
71
|
|
72
|
+
# Update with chunks length
|
73
|
+
file_metadata = await self.file_manager.update_file_status(
|
74
|
+
file_id=file_metadata.id,
|
75
|
+
actor=self.actor,
|
76
|
+
processing_status=FileProcessingStatus.EMBEDDING,
|
77
|
+
total_chunks=len(all_chunks),
|
78
|
+
chunks_embedded=0,
|
79
|
+
)
|
80
|
+
|
72
81
|
all_passages = await self.embedder.generate_embedded_passages(
|
73
82
|
file_id=file_metadata.id,
|
74
83
|
source_id=source_id,
|
@@ -177,9 +186,7 @@ class FileProcessor:
|
|
177
186
|
"file_processor.ocr_completed",
|
178
187
|
{"filename": filename, "pages_extracted": len(ocr_response.pages), "text_length": len(raw_markdown_text)},
|
179
188
|
)
|
180
|
-
|
181
|
-
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING
|
182
|
-
)
|
189
|
+
|
183
190
|
file_metadata = await self.file_manager.upsert_file_content(file_id=file_metadata.id, text=raw_markdown_text, actor=self.actor)
|
184
191
|
|
185
192
|
await self.agent_manager.insert_file_into_context_windows(
|
@@ -241,13 +248,7 @@ class FileProcessor:
|
|
241
248
|
file_id=file_metadata.id,
|
242
249
|
actor=self.actor,
|
243
250
|
processing_status=FileProcessingStatus.COMPLETED,
|
244
|
-
|
245
|
-
else:
|
246
|
-
await self.file_manager.update_file_status(
|
247
|
-
file_id=file_metadata.id,
|
248
|
-
actor=self.actor,
|
249
|
-
total_chunks=len(all_passages),
|
250
|
-
chunks_embedded=0,
|
251
|
+
chunks_embedded=len(all_passages),
|
251
252
|
)
|
252
253
|
|
253
254
|
return all_passages
|
@@ -286,6 +287,7 @@ class FileProcessor:
|
|
286
287
|
document_annotation=None,
|
287
288
|
)
|
288
289
|
|
290
|
+
# TODO: The file state machine here is kind of out of date, we need to match with the correct one above
|
289
291
|
@trace_method
|
290
292
|
async def process_imported_file(self, file_metadata: FileMetadata, source_id: str) -> List[Passage]:
|
291
293
|
"""Process an imported file that already has content - skip OCR, do chunking/embedding"""
|