letta-nightly 0.9.0.dev20250726104256__py3-none-any.whl → 0.9.1.dev20250727104258__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. letta/__init__.py +1 -1
  2. letta/agents/base_agent.py +1 -1
  3. letta/agents/letta_agent.py +6 -0
  4. letta/helpers/datetime_helpers.py +1 -1
  5. letta/helpers/json_helpers.py +1 -1
  6. letta/orm/agent.py +2 -3
  7. letta/orm/agents_tags.py +1 -0
  8. letta/orm/block.py +2 -2
  9. letta/orm/group.py +2 -2
  10. letta/orm/identity.py +3 -4
  11. letta/orm/mcp_oauth.py +62 -0
  12. letta/orm/step.py +2 -4
  13. letta/schemas/agent_file.py +31 -5
  14. letta/schemas/block.py +3 -0
  15. letta/schemas/enums.py +4 -0
  16. letta/schemas/group.py +3 -0
  17. letta/schemas/mcp.py +70 -0
  18. letta/schemas/memory.py +35 -0
  19. letta/schemas/message.py +98 -91
  20. letta/schemas/providers/openai.py +1 -1
  21. letta/server/rest_api/app.py +19 -21
  22. letta/server/rest_api/middleware/__init__.py +4 -0
  23. letta/server/rest_api/middleware/check_password.py +24 -0
  24. letta/server/rest_api/middleware/profiler_context.py +25 -0
  25. letta/server/rest_api/routers/v1/blocks.py +2 -0
  26. letta/server/rest_api/routers/v1/groups.py +1 -1
  27. letta/server/rest_api/routers/v1/sources.py +26 -0
  28. letta/server/rest_api/routers/v1/tools.py +224 -23
  29. letta/services/agent_manager.py +15 -9
  30. letta/services/agent_serialization_manager.py +84 -3
  31. letta/services/block_manager.py +4 -0
  32. letta/services/file_manager.py +23 -13
  33. letta/services/file_processor/file_processor.py +12 -10
  34. letta/services/mcp/base_client.py +20 -28
  35. letta/services/mcp/oauth_utils.py +433 -0
  36. letta/services/mcp/sse_client.py +12 -1
  37. letta/services/mcp/streamable_http_client.py +17 -5
  38. letta/services/mcp/types.py +9 -0
  39. letta/services/mcp_manager.py +304 -42
  40. letta/services/provider_manager.py +2 -2
  41. letta/services/tool_executor/tool_executor.py +6 -2
  42. letta/services/tool_manager.py +8 -4
  43. letta/services/tool_sandbox/base.py +3 -3
  44. letta/services/tool_sandbox/e2b_sandbox.py +1 -1
  45. letta/services/tool_sandbox/local_sandbox.py +16 -9
  46. letta/settings.py +11 -1
  47. letta/system.py +1 -1
  48. letta/templates/template_helper.py +25 -1
  49. letta/utils.py +19 -35
  50. {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/METADATA +3 -2
  51. {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/RECORD +54 -49
  52. {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/LICENSE +0 -0
  53. {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/WHEEL +0 -0
  54. {letta_nightly-0.9.0.dev20250726104256.dist-info → letta_nightly-0.9.1.dev20250727104258.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,6 @@
1
+ import asyncio
1
2
  import json
3
+ from collections.abc import AsyncGenerator
2
4
  from typing import Any, Dict, List, Optional, Union
3
5
 
4
6
  from composio.client import ComposioClientError, HTTPError, NoItemsFound
@@ -11,27 +13,38 @@ from composio.exceptions import (
11
13
  EnumStringNotFound,
12
14
  )
13
15
  from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
16
+ from fastapi.responses import HTMLResponse
14
17
  from pydantic import BaseModel, Field
18
+ from starlette.responses import StreamingResponse
15
19
 
16
20
  from letta.errors import LettaToolCreateError
17
21
  from letta.functions.functions import derive_openai_json_schema
18
22
  from letta.functions.mcp_client.exceptions import MCPTimeoutError
19
- from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
23
+ from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
20
24
  from letta.helpers.composio_helpers import get_composio_api_key
25
+ from letta.helpers.decorators import deprecated
21
26
  from letta.llm_api.llm_client import LLMClient
22
27
  from letta.log import get_logger
23
28
  from letta.orm.errors import UniqueConstraintViolationError
29
+ from letta.orm.mcp_oauth import OAuthSessionStatus
24
30
  from letta.schemas.enums import MessageRole
25
31
  from letta.schemas.letta_message import ToolReturnMessage
26
32
  from letta.schemas.letta_message_content import TextContent
27
- from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
33
+ from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
28
34
  from letta.schemas.message import Message
29
35
  from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
36
+ from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
30
37
  from letta.server.rest_api.utils import get_letta_server
31
38
  from letta.server.server import SyncServer
32
- from letta.services.mcp.sse_client import AsyncSSEMCPClient
39
+ from letta.services.mcp.oauth_utils import (
40
+ MCPOAuthSession,
41
+ create_oauth_provider,
42
+ drill_down_exception,
43
+ get_oauth_success_html,
44
+ oauth_stream_event,
45
+ )
33
46
  from letta.services.mcp.stdio_client import AsyncStdioMCPClient
34
- from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
47
+ from letta.services.mcp.types import OauthStreamEvent
35
48
  from letta.settings import tool_settings
36
49
 
37
50
  router = APIRouter(prefix="/tools", tags=["tools"])
@@ -475,7 +488,11 @@ async def add_mcp_tool(
475
488
  )
476
489
 
477
490
  tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
478
- return await server.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
491
+ # For config-based servers, use the server name as ID since they don't have database IDs
492
+ mcp_server_id = mcp_server_name
493
+ return await server.tool_manager.create_mcp_tool_async(
494
+ tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
495
+ )
479
496
 
480
497
  else:
481
498
  return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor)
@@ -608,35 +625,26 @@ async def delete_mcp_server_from_config(
608
625
  return [server.to_config() for server in all_servers]
609
626
 
610
627
 
611
- @router.post("/mcp/servers/test", response_model=List[MCPTool], operation_id="test_mcp_server")
628
+ @deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
629
+ @router.post("/mcp/servers/test", operation_id="test_mcp_server")
612
630
  async def test_mcp_server(
613
631
  request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
632
+ server: SyncServer = Depends(get_letta_server),
633
+ actor_id: Optional[str] = Header(None, alias="user_id"),
614
634
  ):
615
635
  """
616
636
  Test connection to an MCP server without adding it.
617
- Returns the list of available tools if successful.
637
+ Returns the list of available tools if successful, or OAuth information if OAuth is required.
618
638
  """
619
639
  client = None
620
640
  try:
621
- # create a temporary MCP client based on the server type
622
- if request.type == MCPServerType.SSE:
623
- if not isinstance(request, SSEServerConfig):
624
- request = SSEServerConfig(**request.model_dump())
625
- client = AsyncSSEMCPClient(request)
626
- elif request.type == MCPServerType.STREAMABLE_HTTP:
627
- if not isinstance(request, StreamableHTTPServerConfig):
628
- request = StreamableHTTPServerConfig(**request.model_dump())
629
- client = AsyncStreamableHTTPMCPClient(request)
630
- elif request.type == MCPServerType.STDIO:
631
- if not isinstance(request, StdioServerConfig):
632
- request = StdioServerConfig(**request.model_dump())
633
- client = AsyncStdioMCPClient(request)
634
- else:
635
- raise ValueError(f"Invalid MCP server type: {request.type}")
641
+ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
642
+ client = await server.mcp_manager.get_mcp_client(request, actor)
636
643
 
637
644
  await client.connect_to_server()
638
645
  tools = await client.list_tools()
639
- return tools
646
+
647
+ return {"status": "success", "tools": tools}
640
648
  except ConnectionError as e:
641
649
  raise HTTPException(
642
650
  status_code=400,
@@ -672,6 +680,160 @@ async def test_mcp_server(
672
680
  logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
673
681
 
674
682
 
683
+ @router.post(
684
+ "/mcp/servers/connect",
685
+ response_model=None,
686
+ responses={
687
+ 200: {
688
+ "description": "Successful response",
689
+ "content": {
690
+ "text/event-stream": {"description": "Server-Sent Events stream"},
691
+ },
692
+ }
693
+ },
694
+ operation_id="connect_mcp_server",
695
+ )
696
+ async def connect_mcp_server(
697
+ request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
698
+ server: SyncServer = Depends(get_letta_server),
699
+ actor_id: Optional[str] = Header(None, alias="user_id"),
700
+ ) -> StreamingResponse:
701
+ """
702
+ Connect to an MCP server with support for OAuth via SSE.
703
+ Returns a stream of events handling authorization state and exchange if OAuth is required.
704
+ """
705
+
706
+ async def oauth_stream_generator(
707
+ request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
708
+ ) -> AsyncGenerator[str, None]:
709
+ client = None
710
+ oauth_provider = None
711
+ temp_client = None
712
+ connect_task = None
713
+
714
+ try:
715
+ # Acknowledge connection attempt
716
+ yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
717
+
718
+ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
719
+
720
+ # Create MCP client with respective transport type
721
+ try:
722
+ client = await server.mcp_manager.get_mcp_client(request, actor)
723
+ except ValueError as e:
724
+ yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
725
+ return
726
+
727
+ # Try normal connection first for flows that don't require OAuth
728
+ try:
729
+ await client.connect_to_server()
730
+ tools = await client.list_tools(serialize=True)
731
+ yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
732
+ return
733
+ except ConnectionError:
734
+ # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
735
+ if isinstance(client, AsyncStdioMCPClient):
736
+ logger.warning(f"OAuth not supported for stdio")
737
+ yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"OAuth not supported for stdio")
738
+ return
739
+ # Continue to OAuth flow
740
+ logger.info(f"Attempting OAuth flow for {request}...")
741
+ except Exception as e:
742
+ yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
743
+ return
744
+
745
+ # OAuth required, yield state to client to prepare to handle authorization URL
746
+ yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required")
747
+
748
+ # Create OAuth session to persist the state of the OAuth flow
749
+ session_create = MCPOAuthSessionCreate(
750
+ server_url=request.server_url,
751
+ server_name=request.server_name,
752
+ user_id=actor.id,
753
+ organization_id=actor.organization_id,
754
+ )
755
+ oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor)
756
+ session_id = oauth_session.id
757
+
758
+ # Create OAuth provider for the instance of the stream connection
759
+ # Note: Using the correct API path for the callback
760
+ # do not edit this; this is the correct url
761
+ redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}"
762
+ oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor)
763
+
764
+ # Get authorization URL by triggering OAuth flow
765
+ temp_client = None
766
+ try:
767
+ temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider)
768
+
769
+ # Run connect_to_server in background to avoid blocking
770
+ # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
771
+ connect_task = asyncio.create_task(temp_client.connect_to_server())
772
+
773
+ # Give the OAuth flow time to trigger and save the URL
774
+ await asyncio.sleep(1.0)
775
+
776
+ # Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
777
+ auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor)
778
+ if auth_session and auth_session.authorization_url:
779
+ yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
780
+
781
+ except Exception as e:
782
+ logger.error(f"Error triggering OAuth flow: {e}")
783
+ yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
784
+
785
+ # Clean up active resources
786
+ if connect_task and not connect_task.done():
787
+ connect_task.cancel()
788
+ try:
789
+ await connect_task
790
+ except asyncio.CancelledError:
791
+ pass
792
+ if temp_client:
793
+ try:
794
+ await temp_client.cleanup()
795
+ except Exception as cleanup_error:
796
+ logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
797
+ return
798
+
799
+ # Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit
800
+ yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
801
+
802
+ # Callback handler will poll for authorization code and state and update the OAuth session
803
+ await connect_task
804
+
805
+ tools = await temp_client.list_tools(serialize=True)
806
+
807
+ yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
808
+ return
809
+ except Exception as e:
810
+ detailed_error = drill_down_exception(e)
811
+ logger.error(f"Error in OAuth stream:\n{detailed_error}")
812
+ yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
813
+ finally:
814
+ if connect_task and not connect_task.done():
815
+ connect_task.cancel()
816
+ try:
817
+ await connect_task
818
+ except asyncio.CancelledError:
819
+ pass
820
+ if client:
821
+ try:
822
+ await client.cleanup()
823
+ except Exception as cleanup_error:
824
+ detailed_error = drill_down_exception(cleanup_error)
825
+ logger.warning(f"Error during MCP client cleanup: {detailed_error}")
826
+ if temp_client:
827
+ try:
828
+ await temp_client.cleanup()
829
+ except Exception as cleanup_error:
830
+ # TODO: @jnjpng fix async cancel scope issue
831
+ # detailed_error = drill_down_exception(cleanup_error)
832
+ logger.warning(f"Async cleanup conflict during temp MCP client cleanup: {cleanup_error}")
833
+
834
+ return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream")
835
+
836
+
675
837
  class CodeInput(BaseModel):
676
838
  code: str = Field(..., description="Python source code to parse for JSON schema")
677
839
 
@@ -693,6 +855,45 @@ async def generate_json_schema(
693
855
  raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
694
856
 
695
857
 
858
+ # TODO: @jnjpng need to route this through cloud API for production
859
+ @router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse)
860
+ async def mcp_oauth_callback(
861
+ session_id: str,
862
+ code: Optional[str] = Query(None, description="OAuth authorization code"),
863
+ state: Optional[str] = Query(None, description="OAuth state parameter"),
864
+ error: Optional[str] = Query(None, description="OAuth error"),
865
+ error_description: Optional[str] = Query(None, description="OAuth error description"),
866
+ ):
867
+ """
868
+ Handle OAuth callback for MCP server authentication.
869
+ """
870
+ try:
871
+ oauth_session = MCPOAuthSession(session_id)
872
+
873
+ if error:
874
+ error_msg = f"OAuth error: {error}"
875
+ if error_description:
876
+ error_msg += f" - {error_description}"
877
+ await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
878
+ return {"status": "error", "message": error_msg}
879
+
880
+ if not code or not state:
881
+ await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
882
+ return {"status": "error", "message": "Missing authorization code or state"}
883
+
884
+ # Store authorization code
885
+ success = await oauth_session.store_authorization_code(code, state)
886
+ if not success:
887
+ await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
888
+ return {"status": "error", "message": "Invalid state parameter"}
889
+
890
+ return HTMLResponse(content=get_oauth_success_html(), status_code=200)
891
+
892
+ except Exception as e:
893
+ logger.error(f"OAuth callback error: {e}")
894
+ return {"status": "error", "message": f"OAuth callback failed: {str(e)}"}
895
+
896
+
696
897
  class GenerateToolInput(BaseModel):
697
898
  tool_name: str = Field(..., description="Name of the tool to generate code for")
698
899
  prompt: str = Field(..., description="User prompt to generate code")
@@ -15,6 +15,8 @@ from letta.constants import (
15
15
  BASE_TOOLS,
16
16
  BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
17
17
  BASE_VOICE_SLEEPTIME_TOOLS,
18
+ DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT,
19
+ DEFAULT_MAX_FILES_OPEN,
18
20
  DEFAULT_TIMEZONE,
19
21
  DEPRECATED_LETTA_TOOLS,
20
22
  FILES_TOOLS,
@@ -1644,7 +1646,7 @@ class AgentManager:
1644
1646
 
1645
1647
  # note: we only update the system prompt if the core memory is changed
1646
1648
  # this means that the archival/recall memory statistics may be somewhat out of date
1647
- curr_memory_str = agent_state.memory.compile(
1649
+ curr_memory_str = await agent_state.memory.compile_async(
1648
1650
  sources=agent_state.sources,
1649
1651
  tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
1650
1652
  max_files_open=agent_state.max_files_open,
@@ -1834,14 +1836,12 @@ class AgentManager:
1834
1836
  agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
1835
1837
  system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
1836
1838
  temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
1837
- if (
1838
- new_memory.compile(
1839
- sources=agent_state.sources,
1840
- tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
1841
- max_files_open=agent_state.max_files_open,
1842
- )
1843
- not in system_message.content[0].text
1844
- ):
1839
+ new_memory_str = await new_memory.compile_async(
1840
+ sources=agent_state.sources,
1841
+ tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
1842
+ max_files_open=agent_state.max_files_open,
1843
+ )
1844
+ if new_memory_str not in system_message.content[0].text:
1845
1845
  # update the blocks (LRW) in the DB
1846
1846
  for label in agent_state.memory.list_block_labels():
1847
1847
  updated_value = new_memory.get_block(label).value
@@ -3169,6 +3169,12 @@ class AgentManager:
3169
3169
  if max_files is None:
3170
3170
  max_files = default_max_files
3171
3171
 
3172
+ # FINAL fallback: ensure neither is None (should never happen, but just in case)
3173
+ if per_file_limit is None:
3174
+ per_file_limit = DEFAULT_CORE_MEMORY_SOURCE_CHAR_LIMIT
3175
+ if max_files is None:
3176
+ max_files = DEFAULT_MAX_FILES_OPEN
3177
+
3172
3178
  return per_file_limit, max_files
3173
3179
 
3174
3180
  @enforce_types
@@ -1,6 +1,7 @@
1
1
  from datetime import datetime, timezone
2
2
  from typing import Dict, List
3
3
 
4
+ from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
4
5
  from letta.errors import AgentFileExportError, AgentFileImportError
5
6
  from letta.helpers.pinecone_utils import should_use_pinecone
6
7
  from letta.log import get_logger
@@ -13,6 +14,7 @@ from letta.schemas.agent_file import (
13
14
  FileSchema,
14
15
  GroupSchema,
15
16
  ImportResult,
17
+ MCPServerSchema,
16
18
  MessageSchema,
17
19
  SourceSchema,
18
20
  ToolSchema,
@@ -20,6 +22,7 @@ from letta.schemas.agent_file import (
20
22
  from letta.schemas.block import Block
21
23
  from letta.schemas.enums import FileProcessingStatus
22
24
  from letta.schemas.file import FileMetadata
25
+ from letta.schemas.mcp import MCPServer
23
26
  from letta.schemas.message import Message
24
27
  from letta.schemas.source import Source
25
28
  from letta.schemas.tool import Tool
@@ -92,7 +95,7 @@ class AgentSerializationManager:
92
95
  ToolSchema.__id_prefix__: 0,
93
96
  MessageSchema.__id_prefix__: 0,
94
97
  FileAgentSchema.__id_prefix__: 0,
95
- # MCPServerSchema.__id_prefix__: 0,
98
+ MCPServerSchema.__id_prefix__: 0,
96
99
  }
97
100
 
98
101
  def _reset_state(self):
@@ -258,6 +261,53 @@ class AgentSerializationManager:
258
261
  file_schema.source_id = self._map_db_to_file_id(file_metadata.source_id, SourceSchema.__id_prefix__, allow_new=False)
259
262
  return file_schema
260
263
 
264
+ async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
265
+ """Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
266
+ mcp_server_ids = set()
267
+ mcp_server_names = set()
268
+ for tool in tools:
269
+ # Check if tool has MCP metadata
270
+ if tool.metadata_ and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
271
+ mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]
272
+ # TODO: @jnjpng clean this up once we fully migrate to server_id being the main identifier
273
+ if "server_id" in mcp_metadata:
274
+ mcp_server_ids.add(mcp_metadata["server_id"])
275
+ elif "server_name" in mcp_metadata:
276
+ mcp_server_names.add(mcp_metadata["server_name"])
277
+
278
+ # Fetch MCP servers by ID
279
+ mcp_servers = []
280
+ fetched_server_ids = set()
281
+ if mcp_server_ids:
282
+ try:
283
+ mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor)
284
+ fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers])
285
+ except Exception as e:
286
+ logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}")
287
+
288
+ # Fetch MCP servers by name if not already fetched by ID
289
+ if mcp_server_names:
290
+ for server_name in mcp_server_names:
291
+ try:
292
+ mcp_server = await self.mcp_manager.get_mcp_server(server_name, actor)
293
+ if mcp_server and mcp_server.id not in fetched_server_ids:
294
+ mcp_servers.append(mcp_server)
295
+ except Exception as e:
296
+ logger.warning(f"Failed to fetch MCP server by name {server_name}: {e}")
297
+
298
+ return mcp_servers
299
+
300
+ def _convert_mcp_server_to_schema(self, mcp_server: MCPServer) -> MCPServerSchema:
301
+ """Convert MCPServer to MCPServerSchema with ID remapping and auth scrubbing"""
302
+ try:
303
+ mcp_file_id = self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__, allow_new=False)
304
+ mcp_schema = MCPServerSchema.from_mcp_server(mcp_server)
305
+ mcp_schema.id = mcp_file_id
306
+ return mcp_schema
307
+ except Exception as e:
308
+ logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
309
+ raise
310
+
261
311
  async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
262
312
  """
263
313
  Export agents and their related entities to AgentFileSchema format.
@@ -289,6 +339,13 @@ class AgentSerializationManager:
289
339
  tool_set = self._extract_unique_tools(agent_states)
290
340
  block_set = self._extract_unique_blocks(agent_states)
291
341
 
342
+ # Extract MCP servers from tools BEFORE conversion (must be done before ID mapping)
343
+ mcp_server_set = await self._extract_unique_mcp_servers(tool_set, actor)
344
+
345
+ # Map MCP server IDs before converting schemas
346
+ for mcp_server in mcp_server_set:
347
+ self._map_db_to_file_id(mcp_server.id, MCPServerSchema.__id_prefix__)
348
+
292
349
  # Extract sources and files from agent states BEFORE conversion (with caching)
293
350
  source_set, file_set = await self._extract_unique_sources_and_files_from_agents(agent_states, actor, files_agents_cache)
294
351
 
@@ -301,6 +358,7 @@ class AgentSerializationManager:
301
358
  block_schemas = [self._convert_block_to_schema(block) for block in block_set]
302
359
  source_schemas = [self._convert_source_to_schema(source) for source in source_set]
303
360
  file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
361
+ mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
304
362
 
305
363
  logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
306
364
 
@@ -312,7 +370,7 @@ class AgentSerializationManager:
312
370
  files=file_schemas,
313
371
  sources=source_schemas,
314
372
  tools=tool_schemas,
315
- # mcp_servers=[], # TODO: Extract and convert MCP servers
373
+ mcp_servers=mcp_server_schemas,
316
374
  metadata={"revision_id": await get_latest_alembic_revision()},
317
375
  created_at=datetime.now(timezone.utc),
318
376
  )
@@ -359,7 +417,20 @@ class AgentSerializationManager:
359
417
  # in-memory cache for file metadata to avoid repeated db calls
360
418
  file_metadata_cache = {} # Maps database file ID to FileMetadata
361
419
 
362
- # 1. Create tools first (no dependencies) - using bulk upsert for efficiency
420
+ # 1. Create MCP servers first (tools depend on them)
421
+ if schema.mcp_servers:
422
+ for mcp_server_schema in schema.mcp_servers:
423
+ server_data = mcp_server_schema.model_dump(exclude={"id"})
424
+ filtered_server_data = self._filter_dict_for_model(server_data, MCPServer)
425
+ create_schema = MCPServer(**filtered_server_data)
426
+
427
+ # Note: We don't have auth info from export, so the user will need to re-configure auth.
428
+ # TODO: @jnjpng store metadata about obfuscated metadata to surface to the user
429
+ created_mcp_server = await self.mcp_manager.create_or_update_mcp_server(create_schema, actor)
430
+ file_to_db_ids[mcp_server_schema.id] = created_mcp_server.id
431
+ imported_count += 1
432
+
433
+ # 2. Create tools (may depend on MCP servers) - using bulk upsert for efficiency
363
434
  if schema.tools:
364
435
  # convert tool schemas to pydantic tools
365
436
  pydantic_tools = []
@@ -559,6 +630,7 @@ class AgentSerializationManager:
559
630
  (schema.files, FileSchema.__id_prefix__),
560
631
  (schema.sources, SourceSchema.__id_prefix__),
561
632
  (schema.tools, ToolSchema.__id_prefix__),
633
+ (schema.mcp_servers, MCPServerSchema.__id_prefix__),
562
634
  ]
563
635
 
564
636
  for entities, expected_prefix in entity_checks:
@@ -601,6 +673,7 @@ class AgentSerializationManager:
601
673
  ("files", schema.files),
602
674
  ("sources", schema.sources),
603
675
  ("tools", schema.tools),
676
+ ("mcp_servers", schema.mcp_servers),
604
677
  ]
605
678
 
606
679
  for entity_type, entities in entity_collections:
@@ -705,3 +778,11 @@ class AgentSerializationManager:
705
778
  raise AgentFileImportError(f"Schema validation failed: {'; '.join(errors)}")
706
779
 
707
780
  logger.info("Schema validation passed")
781
+
782
+ def _filter_dict_for_model(self, data: dict, model_cls):
783
+ """Filter a dictionary to only include keys that are in the model fields"""
784
+ try:
785
+ allowed = model_cls.model_fields.keys() # Pydantic v2
786
+ except AttributeError:
787
+ allowed = model_cls.__fields__.keys() # Pydantic v1
788
+ return {k: v for k, v in data.items() if k in allowed}
@@ -178,6 +178,7 @@ class BlockManager:
178
178
  template_name: Optional[str] = None,
179
179
  identity_id: Optional[str] = None,
180
180
  identifier_keys: Optional[List[str]] = None,
181
+ project_id: Optional[str] = None,
181
182
  before: Optional[str] = None,
182
183
  after: Optional[str] = None,
183
184
  limit: Optional[int] = 50,
@@ -210,6 +211,9 @@ class BlockManager:
210
211
  if template_name:
211
212
  query = query.where(BlockModel.template_name == template_name)
212
213
 
214
+ if project_id:
215
+ query = query.where(BlockModel.project_id == project_id)
216
+
213
217
  needs_distinct = False
214
218
 
215
219
  if identifier_keys:
@@ -151,7 +151,8 @@ class FileManager:
151
151
  Enforces state transition rules (when enforce_state_transitions=True):
152
152
  - PENDING -> PARSING -> EMBEDDING -> COMPLETED (normal flow)
153
153
  - Any non-terminal state -> ERROR
154
- - ERROR and COMPLETED are terminal (no transitions allowed)
154
+ - Same-state transitions are allowed (e.g., EMBEDDING -> EMBEDDING)
155
+ - ERROR and COMPLETED are terminal (no status transitions allowed, metadata updates blocked)
155
156
 
156
157
  Args:
157
158
  file_id: ID of the file to update
@@ -196,22 +197,31 @@ class FileManager:
196
197
  ]
197
198
 
198
199
  # only add state transition validation if enforce_state_transitions is True
199
- if enforce_state_transitions:
200
- # prevent updates to terminal states (ERROR, COMPLETED)
200
+ if enforce_state_transitions and processing_status is not None:
201
+ # enforce specific transitions based on target status
202
+ if processing_status == FileProcessingStatus.PARSING:
203
+ where_conditions.append(
204
+ FileMetadataModel.processing_status.in_([FileProcessingStatus.PENDING, FileProcessingStatus.PARSING])
205
+ )
206
+ elif processing_status == FileProcessingStatus.EMBEDDING:
207
+ where_conditions.append(
208
+ FileMetadataModel.processing_status.in_([FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING])
209
+ )
210
+ elif processing_status == FileProcessingStatus.COMPLETED:
211
+ where_conditions.append(
212
+ FileMetadataModel.processing_status.in_([FileProcessingStatus.EMBEDDING, FileProcessingStatus.COMPLETED])
213
+ )
214
+ elif processing_status == FileProcessingStatus.ERROR:
215
+ # ERROR can be set from any non-terminal state
216
+ where_conditions.append(
217
+ FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
218
+ )
219
+ elif enforce_state_transitions and processing_status is None:
220
+ # If only updating metadata fields (not status), prevent updates to terminal states
201
221
  where_conditions.append(
202
222
  FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
203
223
  )
204
224
 
205
- if processing_status is not None:
206
- # enforce specific transitions based on target status
207
- if processing_status == FileProcessingStatus.PARSING:
208
- where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PENDING)
209
- elif processing_status == FileProcessingStatus.EMBEDDING:
210
- where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PARSING)
211
- elif processing_status == FileProcessingStatus.COMPLETED:
212
- where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.EMBEDDING)
213
- # ERROR can be set from any non-terminal state (already handled by terminal check above)
214
-
215
225
  # fast in-place update with state validation
216
226
  stmt = (
217
227
  update(FileMetadataModel)
@@ -69,6 +69,15 @@ class FileProcessor:
69
69
  raise ValueError("No chunks created from text")
70
70
  all_chunks.extend(chunks)
71
71
 
72
+ # Update with chunks length
73
+ file_metadata = await self.file_manager.update_file_status(
74
+ file_id=file_metadata.id,
75
+ actor=self.actor,
76
+ processing_status=FileProcessingStatus.EMBEDDING,
77
+ total_chunks=len(all_chunks),
78
+ chunks_embedded=0,
79
+ )
80
+
72
81
  all_passages = await self.embedder.generate_embedded_passages(
73
82
  file_id=file_metadata.id,
74
83
  source_id=source_id,
@@ -177,9 +186,7 @@ class FileProcessor:
177
186
  "file_processor.ocr_completed",
178
187
  {"filename": filename, "pages_extracted": len(ocr_response.pages), "text_length": len(raw_markdown_text)},
179
188
  )
180
- file_metadata = await self.file_manager.update_file_status(
181
- file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.EMBEDDING
182
- )
189
+
183
190
  file_metadata = await self.file_manager.upsert_file_content(file_id=file_metadata.id, text=raw_markdown_text, actor=self.actor)
184
191
 
185
192
  await self.agent_manager.insert_file_into_context_windows(
@@ -241,13 +248,7 @@ class FileProcessor:
241
248
  file_id=file_metadata.id,
242
249
  actor=self.actor,
243
250
  processing_status=FileProcessingStatus.COMPLETED,
244
- )
245
- else:
246
- await self.file_manager.update_file_status(
247
- file_id=file_metadata.id,
248
- actor=self.actor,
249
- total_chunks=len(all_passages),
250
- chunks_embedded=0,
251
+ chunks_embedded=len(all_passages),
251
252
  )
252
253
 
253
254
  return all_passages
@@ -286,6 +287,7 @@ class FileProcessor:
286
287
  document_annotation=None,
287
288
  )
288
289
 
290
+ # TODO: The file state machine here is kind of out of date, we need to match with the correct one above
289
291
  @trace_method
290
292
  async def process_imported_file(self, file_metadata: FileMetadata, source_id: str) -> List[Passage]:
291
293
  """Process an imported file that already has content - skip OCR, do chunking/embedding"""