letta-nightly 0.11.3.dev20250820104219__py3-none-any.whl → 0.11.4.dev20250820213507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. letta/__init__.py +1 -1
  2. letta/agents/helpers.py +4 -0
  3. letta/agents/letta_agent.py +142 -5
  4. letta/constants.py +10 -7
  5. letta/data_sources/connectors.py +70 -53
  6. letta/embeddings.py +3 -240
  7. letta/errors.py +28 -0
  8. letta/functions/function_sets/base.py +4 -4
  9. letta/functions/functions.py +287 -32
  10. letta/functions/mcp_client/types.py +11 -0
  11. letta/functions/schema_validator.py +187 -0
  12. letta/functions/typescript_parser.py +196 -0
  13. letta/helpers/datetime_helpers.py +8 -4
  14. letta/helpers/tool_execution_helper.py +25 -2
  15. letta/llm_api/anthropic_client.py +23 -18
  16. letta/llm_api/azure_client.py +73 -0
  17. letta/llm_api/bedrock_client.py +8 -4
  18. letta/llm_api/google_vertex_client.py +14 -5
  19. letta/llm_api/llm_api_tools.py +2 -217
  20. letta/llm_api/llm_client.py +15 -1
  21. letta/llm_api/llm_client_base.py +32 -1
  22. letta/llm_api/openai.py +1 -0
  23. letta/llm_api/openai_client.py +18 -28
  24. letta/llm_api/together_client.py +55 -0
  25. letta/orm/provider.py +1 -0
  26. letta/orm/step_metrics.py +40 -1
  27. letta/otel/db_pool_monitoring.py +1 -1
  28. letta/schemas/agent.py +3 -4
  29. letta/schemas/agent_file.py +2 -0
  30. letta/schemas/block.py +11 -5
  31. letta/schemas/embedding_config.py +4 -5
  32. letta/schemas/enums.py +1 -1
  33. letta/schemas/job.py +2 -3
  34. letta/schemas/llm_config.py +79 -7
  35. letta/schemas/mcp.py +0 -24
  36. letta/schemas/message.py +0 -108
  37. letta/schemas/openai/chat_completion_request.py +1 -0
  38. letta/schemas/providers/__init__.py +0 -2
  39. letta/schemas/providers/anthropic.py +106 -8
  40. letta/schemas/providers/azure.py +102 -8
  41. letta/schemas/providers/base.py +10 -3
  42. letta/schemas/providers/bedrock.py +28 -16
  43. letta/schemas/providers/letta.py +3 -3
  44. letta/schemas/providers/ollama.py +2 -12
  45. letta/schemas/providers/openai.py +4 -4
  46. letta/schemas/providers/together.py +14 -2
  47. letta/schemas/sandbox_config.py +2 -1
  48. letta/schemas/tool.py +46 -22
  49. letta/server/rest_api/routers/v1/agents.py +179 -38
  50. letta/server/rest_api/routers/v1/folders.py +13 -8
  51. letta/server/rest_api/routers/v1/providers.py +10 -3
  52. letta/server/rest_api/routers/v1/sources.py +14 -8
  53. letta/server/rest_api/routers/v1/steps.py +17 -1
  54. letta/server/rest_api/routers/v1/tools.py +96 -5
  55. letta/server/rest_api/streaming_response.py +91 -45
  56. letta/server/server.py +27 -38
  57. letta/services/agent_manager.py +92 -20
  58. letta/services/agent_serialization_manager.py +11 -7
  59. letta/services/context_window_calculator/context_window_calculator.py +40 -2
  60. letta/services/helpers/agent_manager_helper.py +73 -12
  61. letta/services/mcp_manager.py +109 -15
  62. letta/services/passage_manager.py +28 -109
  63. letta/services/provider_manager.py +24 -0
  64. letta/services/step_manager.py +68 -0
  65. letta/services/summarizer/summarizer.py +1 -4
  66. letta/services/tool_executor/core_tool_executor.py +1 -1
  67. letta/services/tool_executor/sandbox_tool_executor.py +26 -9
  68. letta/services/tool_manager.py +82 -5
  69. letta/services/tool_sandbox/base.py +3 -11
  70. letta/services/tool_sandbox/modal_constants.py +17 -0
  71. letta/services/tool_sandbox/modal_deployment_manager.py +242 -0
  72. letta/services/tool_sandbox/modal_sandbox.py +218 -3
  73. letta/services/tool_sandbox/modal_sandbox_v2.py +429 -0
  74. letta/services/tool_sandbox/modal_version_manager.py +273 -0
  75. letta/services/tool_sandbox/safe_pickle.py +193 -0
  76. letta/settings.py +5 -3
  77. letta/templates/sandbox_code_file.py.j2 +2 -4
  78. letta/templates/sandbox_code_file_async.py.j2 +2 -4
  79. letta/utils.py +1 -1
  80. {letta_nightly-0.11.3.dev20250820104219.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/METADATA +2 -2
  81. {letta_nightly-0.11.3.dev20250820104219.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/RECORD +84 -81
  82. letta/llm_api/anthropic.py +0 -1206
  83. letta/llm_api/aws_bedrock.py +0 -104
  84. letta/llm_api/azure_openai.py +0 -118
  85. letta/llm_api/azure_openai_constants.py +0 -11
  86. letta/llm_api/cohere.py +0 -391
  87. letta/schemas/providers/cohere.py +0 -18
  88. {letta_nightly-0.11.3.dev20250820104219.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/LICENSE +0 -0
  89. {letta_nightly-0.11.3.dev20250820104219.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/WHEEL +0 -0
  90. {letta_nightly-0.11.3.dev20250820104219.dist-info → letta_nightly-0.11.4.dev20250820213507.dist-info}/entry_points.txt +0 -0
@@ -21,7 +21,26 @@ class ContextWindowCalculator:
21
21
 
22
22
  @staticmethod
23
23
  def extract_system_components(system_message: str) -> Tuple[str, str, str]:
24
- """Extract system prompt, core memory, and external memory summary from system message"""
24
+ """
25
+ Extract structured components from a formatted system message.
26
+
27
+ Parses the system message to extract three distinct sections marked by XML-style tags:
28
+ - base_instructions: The core system prompt and agent instructions
29
+ - memory_blocks: The agent's core memory (persistent context)
30
+ - memory_metadata: Metadata about external memory systems
31
+
32
+ Args:
33
+ system_message: A formatted system message containing XML-style section markers
34
+
35
+ Returns:
36
+ A tuple of (system_prompt, core_memory, external_memory_summary)
37
+ Each component will be an empty string if its section is not found
38
+
39
+ Note:
40
+ This method assumes a specific format with sections delimited by:
41
+ <base_instructions>, <memory_blocks>, and <memory_metadata> tags.
42
+ The extraction is position-based and expects sections in this order.
43
+ """
25
44
  base_start = system_message.find("<base_instructions>")
26
45
  memory_blocks_start = system_message.find("<memory_blocks>")
27
46
  metadata_start = system_message.find("<memory_metadata>")
@@ -43,7 +62,26 @@ class ContextWindowCalculator:
43
62
 
44
63
  @staticmethod
45
64
  def extract_summary_memory(messages: List[Any]) -> Tuple[Optional[str], int]:
46
- """Extract summary memory if present and return starting index for real messages"""
65
+ """
66
+ Extract summary memory from the message list if present.
67
+
68
+ Summary memory is a special message injected at position 1 (after system message)
69
+ that contains a condensed summary of previous conversation history. This is used
70
+ when the full conversation history doesn't fit in the context window.
71
+
72
+ Args:
73
+ messages: List of message objects to search for summary memory
74
+
75
+ Returns:
76
+ A tuple of (summary_text, start_index) where:
77
+ - summary_text: The extracted summary content, or None if not found
78
+ - start_index: Index where actual conversation messages begin (1 or 2)
79
+
80
+ Detection Logic:
81
+ Looks for a user message at index 1 containing the phrase
82
+ "The following is a summary of the previous" which indicates
83
+ it's a summarized conversation history rather than a real user message.
84
+ """
47
85
  if (
48
86
  len(messages) > 1
49
87
  and messages[1].role == MessageRole.user
@@ -20,9 +20,9 @@ from letta.constants import (
20
20
  MULTI_AGENT_TOOLS,
21
21
  STRUCTURED_OUTPUT_MODELS,
22
22
  )
23
- from letta.embeddings import embedding_model
24
23
  from letta.helpers import ToolRulesSolver
25
24
  from letta.helpers.datetime_helpers import format_datetime, get_local_time, get_local_time_fast
25
+ from letta.llm_api.llm_client import LLMClient
26
26
  from letta.orm.agent import Agent as AgentModel
27
27
  from letta.orm.agents_tags import AgentsTags
28
28
  from letta.orm.archives_agents import ArchivesAgents
@@ -156,7 +156,25 @@ def _process_tags(agent: "AgentModel", tags: List[str], replace=True):
156
156
  agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
157
157
 
158
158
 
159
- def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool] = None, system: Optional[str] = None):
159
+ def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool] = None, system: Optional[str] = None) -> str:
160
+ """
161
+ Derive the appropriate system message based on agent type and configuration.
162
+
163
+ This function determines which system prompt template to use based on the
164
+ agent's type and whether sleeptime functionality is enabled. If a custom
165
+ system message is provided, it returns that instead.
166
+
167
+ Args:
168
+ agent_type: The type of agent (e.g., memgpt_agent, sleeptime_agent, react_agent)
169
+ enable_sleeptime: Whether sleeptime tools should be available (affects prompt choice)
170
+ system: Optional custom system message to use instead of defaults
171
+
172
+ Returns:
173
+ The system message string appropriate for the agent configuration
174
+
175
+ Raises:
176
+ ValueError: If an invalid or unsupported agent type is provided
177
+ """
160
178
  if system is None:
161
179
  # TODO: don't hardcode
162
180
 
@@ -204,8 +222,33 @@ def compile_memory_metadata_block(
204
222
  memory_edit_timestamp: datetime,
205
223
  timezone: str,
206
224
  previous_message_count: int = 0,
207
- archival_memory_size: int = 0,
225
+ archival_memory_size: Optional[int] = 0,
208
226
  ) -> str:
227
+ """
228
+ Generate a memory metadata block for the agent's system prompt.
229
+
230
+ This creates a structured metadata section that informs the agent about
231
+ the current state of its memory systems, including timing information
232
+ and memory counts. This helps the agent understand what information
233
+ is available through its tools.
234
+
235
+ Args:
236
+ memory_edit_timestamp: When memory blocks were last modified
237
+ timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
238
+ previous_message_count: Number of messages in recall memory (conversation history)
239
+ archival_memory_size: Number of items in archival memory (long-term storage)
240
+
241
+ Returns:
242
+ A formatted string containing the memory metadata block with XML-style tags
243
+
244
+ Example Output:
245
+ <memory_metadata>
246
+ - The current time is: 2024-01-15 10:30 AM PST
247
+ - Memory blocks were last modified: 2024-01-15 09:00 AM PST
248
+ - 42 previous messages between you and the user are stored in recall memory (use tools to access them)
249
+ - 156 total memories you created are stored in archival memory (use tools to access them)
250
+ </memory_metadata>
251
+ """
209
252
  # Put the timestamp in the local timezone (mimicking get_local_time())
210
253
  timestamp_str = format_datetime(memory_edit_timestamp, timezone)
211
254
 
@@ -939,7 +982,7 @@ def _apply_relationship_filters(query, include_relationships: Optional[List[str]
939
982
  return query
940
983
 
941
984
 
942
- def build_passage_query(
985
+ async def build_passage_query(
943
986
  actor: User,
944
987
  agent_id: Optional[str] = None,
945
988
  file_id: Optional[str] = None,
@@ -963,8 +1006,14 @@ def build_passage_query(
963
1006
  if embed_query:
964
1007
  assert embedding_config is not None, "embedding_config must be specified for vector search"
965
1008
  assert query_text is not None, "query_text must be specified for vector search"
966
- embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
967
- embedded_text = np.array(embedded_text)
1009
+
1010
+ # Use the new LLMClient for embeddings
1011
+ embedding_client = LLMClient.create(
1012
+ provider_type=embedding_config.embedding_endpoint_type,
1013
+ actor=actor,
1014
+ )
1015
+ embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
1016
+ embedded_text = np.array(embeddings[0])
968
1017
  embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
969
1018
 
970
1019
  # Start with base query for source passages
@@ -1150,7 +1199,7 @@ def build_passage_query(
1150
1199
  return main_query
1151
1200
 
1152
1201
 
1153
- def build_source_passage_query(
1202
+ async def build_source_passage_query(
1154
1203
  actor: User,
1155
1204
  agent_id: Optional[str] = None,
1156
1205
  file_id: Optional[str] = None,
@@ -1171,8 +1220,14 @@ def build_source_passage_query(
1171
1220
  if embed_query:
1172
1221
  assert embedding_config is not None, "embedding_config must be specified for vector search"
1173
1222
  assert query_text is not None, "query_text must be specified for vector search"
1174
- embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
1175
- embedded_text = np.array(embedded_text)
1223
+
1224
+ # Use the new LLMClient for embeddings
1225
+ embedding_client = LLMClient.create(
1226
+ provider_type=embedding_config.embedding_endpoint_type,
1227
+ actor=actor,
1228
+ )
1229
+ embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
1230
+ embedded_text = np.array(embeddings[0])
1176
1231
  embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
1177
1232
 
1178
1233
  # Base query for source passages
@@ -1248,7 +1303,7 @@ def build_source_passage_query(
1248
1303
  return query
1249
1304
 
1250
1305
 
1251
- def build_agent_passage_query(
1306
+ async def build_agent_passage_query(
1252
1307
  actor: User,
1253
1308
  agent_id: str, # Required for agent passages
1254
1309
  query_text: Optional[str] = None,
@@ -1267,8 +1322,14 @@ def build_agent_passage_query(
1267
1322
  if embed_query:
1268
1323
  assert embedding_config is not None, "embedding_config must be specified for vector search"
1269
1324
  assert query_text is not None, "query_text must be specified for vector search"
1270
- embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
1271
- embedded_text = np.array(embedded_text)
1325
+
1326
+ # Use the new LLMClient for embeddings
1327
+ embedding_client = LLMClient.create(
1328
+ provider_type=embedding_config.embedding_endpoint_type,
1329
+ actor=actor,
1330
+ )
1331
+ embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
1332
+ embedded_text = np.array(embeddings[0])
1272
1333
  embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
1273
1334
 
1274
1335
  # Base query for agent passages - join through archives_agents
@@ -6,11 +6,19 @@ from datetime import datetime, timedelta
6
6
  from typing import Any, Dict, List, Optional, Tuple, Union
7
7
 
8
8
  from fastapi import HTTPException
9
- from sqlalchemy import null
9
+ from sqlalchemy import delete, null
10
10
  from starlette.requests import Request
11
11
 
12
12
  import letta.constants as constants
13
- from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
13
+ from letta.functions.mcp_client.types import (
14
+ MCPServerType,
15
+ MCPTool,
16
+ MCPToolHealth,
17
+ SSEServerConfig,
18
+ StdioServerConfig,
19
+ StreamableHTTPServerConfig,
20
+ )
21
+ from letta.functions.schema_validator import validate_complete_json_schema
14
22
  from letta.log import get_logger
15
23
  from letta.orm.errors import NoResultFound
16
24
  from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
@@ -49,6 +57,7 @@ class MCPManager:
49
57
  @enforce_types
50
58
  async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
51
59
  """Get a list of all tools for a specific MCP server."""
60
+ mcp_client = None
52
61
  try:
53
62
  mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
54
63
  mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
@@ -58,6 +67,13 @@ class MCPManager:
58
67
 
59
68
  # list tools
60
69
  tools = await mcp_client.list_tools()
70
+
71
+ # Add health information to each tool
72
+ for tool in tools:
73
+ if tool.inputSchema:
74
+ health_status, reasons = validate_complete_json_schema(tool.inputSchema)
75
+ tool.health = MCPToolHealth(status=health_status.value, reasons=reasons)
76
+
61
77
  return tools
62
78
  except Exception as e:
63
79
  # MCP tool listing errors are often due to connection/configuration issues, not system errors
@@ -65,7 +81,8 @@ class MCPManager:
65
81
  logger.info(f"Error listing tools for MCP server {mcp_server_name}: {e}")
66
82
  return []
67
83
  finally:
68
- await mcp_client.cleanup()
84
+ if mcp_client:
85
+ await mcp_client.cleanup()
69
86
 
70
87
  @enforce_types
71
88
  async def execute_mcp_server_tool(
@@ -114,7 +131,16 @@ class MCPManager:
114
131
  mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
115
132
 
116
133
  for mcp_tool in mcp_tools:
134
+ # TODO: @jnjpng move health check to tool class
117
135
  if mcp_tool.name == mcp_tool_name:
136
+ # Check tool health - reject only INVALID tools
137
+ if mcp_tool.health:
138
+ if mcp_tool.health.status == "INVALID":
139
+ raise ValueError(
140
+ f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid."
141
+ f"Reasons: {', '.join(mcp_tool.health.reasons)}"
142
+ )
143
+
118
144
  tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
119
145
  return await self.tool_manager.create_mcp_tool_async(
120
146
  tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
@@ -169,17 +195,50 @@ class MCPManager:
169
195
  async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
170
196
  """Create a new MCP server."""
171
197
  async with db_registry.async_session() as session:
172
- # Set the organization id at the ORM layer
173
- pydantic_mcp_server.organization_id = actor.organization_id
174
- mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
198
+ try:
199
+ # Set the organization id at the ORM layer
200
+ pydantic_mcp_server.organization_id = actor.organization_id
201
+ mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
202
+
203
+ # Ensure custom_headers None is stored as SQL NULL, not JSON null
204
+ if mcp_server_data.get("custom_headers") is None:
205
+ mcp_server_data.pop("custom_headers", None)
206
+
207
+ mcp_server = MCPServerModel(**mcp_server_data)
208
+ mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True)
209
+
210
+ # Link existing OAuth sessions for the same user and server URL
211
+ # This ensures OAuth sessions created during testing get linked to the server
212
+ server_url = getattr(mcp_server, "server_url", None)
213
+ if server_url:
214
+ from sqlalchemy import select
215
+
216
+ result = await session.execute(
217
+ select(MCPOAuth).where(
218
+ MCPOAuth.server_url == server_url,
219
+ MCPOAuth.organization_id == actor.organization_id,
220
+ MCPOAuth.user_id == actor.id, # Only link sessions for the same user
221
+ MCPOAuth.server_id.is_(None), # Only update sessions not already linked
222
+ )
223
+ )
224
+ oauth_sessions = result.scalars().all()
175
225
 
176
- # Ensure custom_headers None is stored as SQL NULL, not JSON null
177
- if mcp_server_data.get("custom_headers") is None:
178
- mcp_server_data.pop("custom_headers", None)
226
+ # TODO: @jnjpng we should upate sessions in bulk
227
+ for oauth_session in oauth_sessions:
228
+ oauth_session.server_id = mcp_server.id
229
+ await oauth_session.update_async(db_session=session, actor=actor, no_commit=True)
179
230
 
180
- mcp_server = MCPServerModel(**mcp_server_data)
181
- mcp_server = await mcp_server.create_async(session, actor=actor)
182
- return mcp_server.to_pydantic()
231
+ if oauth_sessions:
232
+ logger.info(
233
+ f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
234
+ )
235
+
236
+ await session.commit()
237
+ return mcp_server.to_pydantic()
238
+ except Exception as e:
239
+ await session.rollback()
240
+ logger.error(f"Failed to create MCP server: {e}")
241
+ raise
183
242
 
184
243
  @enforce_types
185
244
  async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
@@ -252,7 +311,7 @@ class MCPManager:
252
311
 
253
312
  @enforce_types
254
313
  async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
255
- """Get a tool by name."""
314
+ """Get a MCP server by name."""
256
315
  async with db_registry.async_session() as session:
257
316
  mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
258
317
  mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
@@ -286,13 +345,48 @@ class MCPManager:
286
345
 
287
346
  @enforce_types
288
347
  async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None:
289
- """Delete a tool by its ID."""
348
+ """Delete a MCP server by its ID."""
290
349
  async with db_registry.async_session() as session:
291
350
  try:
292
351
  mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
293
- await mcp_server.hard_delete_async(db_session=session, actor=actor)
352
+ if not mcp_server:
353
+ raise NoResultFound(f"MCP server with id {mcp_server_id} not found.")
354
+
355
+ server_url = getattr(mcp_server, "server_url", None)
356
+
357
+ # Delete OAuth sessions for the same user and server URL in the same transaction
358
+ # This handles orphaned sessions that were created during testing/connection
359
+ oauth_count = 0
360
+ if server_url:
361
+ result = await session.execute(
362
+ delete(MCPOAuth).where(
363
+ MCPOAuth.server_url == server_url,
364
+ MCPOAuth.organization_id == actor.organization_id,
365
+ MCPOAuth.user_id == actor.id, # Only delete sessions for the same user
366
+ )
367
+ )
368
+ oauth_count = result.rowcount
369
+ if oauth_count > 0:
370
+ logger.info(
371
+ f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}"
372
+ )
373
+
374
+ # Delete the MCP server, will cascade delete to linked OAuth sessions
375
+ await session.execute(
376
+ delete(MCPServerModel).where(
377
+ MCPServerModel.id == mcp_server_id,
378
+ MCPServerModel.organization_id == actor.organization_id,
379
+ )
380
+ )
381
+
382
+ await session.commit()
294
383
  except NoResultFound:
384
+ await session.rollback()
295
385
  raise ValueError(f"MCP server with id {mcp_server_id} not found.")
386
+ except Exception as e:
387
+ await session.rollback()
388
+ logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}")
389
+ raise
296
390
 
297
391
  def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]:
298
392
  mcp_server_list = {}
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  from datetime import datetime, timezone
3
2
  from functools import lru_cache
4
3
  from typing import List, Optional
@@ -7,8 +6,9 @@ from openai import AsyncOpenAI, OpenAI
7
6
  from sqlalchemy import select
8
7
 
9
8
  from letta.constants import MAX_EMBEDDING_DIM
10
- from letta.embeddings import embedding_model, parse_and_chunk_text
9
+ from letta.embeddings import parse_and_chunk_text
11
10
  from letta.helpers.decorators import async_redis_cache
11
+ from letta.llm_api.llm_client import LLMClient
12
12
  from letta.orm import ArchivesAgents
13
13
  from letta.orm.errors import NoResultFound
14
14
  from letta.orm.passage import ArchivalPassage, SourcePassage
@@ -460,7 +460,7 @@ class PassageManager:
460
460
 
461
461
  @enforce_types
462
462
  @trace_method
463
- def insert_passage(
463
+ async def insert_passage(
464
464
  self,
465
465
  agent_state: AgentState,
466
466
  text: str,
@@ -469,45 +469,32 @@ class PassageManager:
469
469
  """Insert passage(s) into archival memory"""
470
470
 
471
471
  embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
472
+ embedding_client = LLMClient.create(
473
+ provider_type=agent_state.embedding_config.embedding_endpoint_type,
474
+ actor=actor,
475
+ )
472
476
 
473
- # TODO eventually migrate off of llama-index for embeddings?
474
- # Already causing pain for OpenAI proxy endpoints like LM Studio...
475
- if agent_state.embedding_config.embedding_endpoint_type != "openai":
476
- embed_model = embedding_model(agent_state.embedding_config)
477
+ # Get or create the default archive for the agent
478
+ archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
479
+ agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
480
+ )
477
481
 
478
- passages = []
482
+ text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
479
483
 
480
- try:
481
- # breakup string into passages
482
- for text in parse_and_chunk_text(text, embedding_chunk_size):
483
- if agent_state.embedding_config.embedding_endpoint_type != "openai":
484
- embedding = embed_model.get_text_embedding(text)
485
- else:
486
- # TODO should have the settings passed in via the server call
487
- embedding = get_openai_embedding(
488
- text,
489
- agent_state.embedding_config.embedding_model,
490
- agent_state.embedding_config.embedding_endpoint,
491
- )
484
+ if not text_chunks:
485
+ return []
492
486
 
493
- if isinstance(embedding, dict):
494
- try:
495
- embedding = embedding["data"][0]["embedding"]
496
- except (KeyError, IndexError):
497
- # TODO as a fallback, see if we can find any lists in the payload
498
- raise TypeError(
499
- f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
500
- )
501
- # Get or create the default archive for the agent
502
- archive = self.archive_manager.get_or_create_default_archive_for_agent(
503
- agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
504
- )
487
+ try:
488
+ # Generate embeddings for all chunks using the new async API
489
+ embeddings = await embedding_client.request_embeddings(text_chunks, agent_state.embedding_config)
505
490
 
506
- passage = self.create_agent_passage(
491
+ passages = []
492
+ for chunk_text, embedding in zip(text_chunks, embeddings):
493
+ passage = await self.create_agent_passage_async(
507
494
  PydanticPassage(
508
495
  organization_id=actor.organization_id,
509
496
  archive_id=archive.id,
510
- text=text,
497
+ text=chunk_text,
511
498
  embedding=embedding,
512
499
  embedding_config=agent_state.embedding_config,
513
500
  ),
@@ -520,84 +507,16 @@ class PassageManager:
520
507
  except Exception as e:
521
508
  raise e
522
509
 
523
- @enforce_types
524
- @trace_method
525
- async def insert_passage_async(
526
- self,
527
- agent_state: AgentState,
528
- text: str,
529
- actor: PydanticUser,
530
- image_ids: Optional[List[str]] = None,
531
- ) -> List[PydanticPassage]:
532
- """Insert passage(s) into archival memory"""
533
- # Get or create default archive for the agent
534
- archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
535
- agent_id=agent_state.id,
536
- agent_name=agent_state.name,
510
+ async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config, actor: PydanticUser) -> List[List[float]]:
511
+ """Generate embeddings for all text chunks concurrently using LLMClient"""
512
+
513
+ embedding_client = LLMClient.create(
514
+ provider_type=embedding_config.embedding_endpoint_type,
537
515
  actor=actor,
538
516
  )
539
- archive_id = archive.id
540
-
541
- embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
542
- text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
543
-
544
- if not text_chunks:
545
- return []
546
-
547
- try:
548
- embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config)
549
-
550
- passages = [
551
- PydanticPassage(
552
- organization_id=actor.organization_id,
553
- archive_id=archive_id,
554
- text=chunk_text,
555
- embedding=embedding,
556
- embedding_config=agent_state.embedding_config,
557
- )
558
- for chunk_text, embedding in zip(text_chunks, embeddings)
559
- ]
560
-
561
- passages = await self.create_many_archival_passages_async(passages=passages, actor=actor)
562
-
563
- return passages
564
-
565
- except Exception as e:
566
- raise e
567
-
568
- async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]:
569
- """Generate embeddings for all text chunks concurrently"""
570
-
571
- if embedding_config.embedding_endpoint_type != "openai":
572
- embed_model = embedding_model(embedding_config)
573
- loop = asyncio.get_event_loop()
574
-
575
- tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks]
576
- embeddings = await asyncio.gather(*tasks)
577
- else:
578
- tasks = [
579
- get_openai_embedding_async(
580
- text,
581
- embedding_config.embedding_model,
582
- embedding_config.embedding_endpoint,
583
- )
584
- for text in text_chunks
585
- ]
586
- embeddings = await asyncio.gather(*tasks)
587
-
588
- processed_embeddings = []
589
- for embedding in embeddings:
590
- if isinstance(embedding, dict):
591
- try:
592
- processed_embeddings.append(embedding["data"][0]["embedding"])
593
- except (KeyError, IndexError):
594
- raise TypeError(
595
- f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
596
- )
597
- else:
598
- processed_embeddings.append(embedding)
599
517
 
600
- return processed_embeddings
518
+ embeddings = await embedding_client.request_embeddings(text_chunks, embedding_config)
519
+ return embeddings
601
520
 
602
521
  @enforce_types
603
522
  @trace_method
@@ -205,6 +205,28 @@ class ProviderManager:
205
205
  region = providers[0].region if providers else None
206
206
  return access_key, secret_key, region
207
207
 
208
+ @enforce_types
209
+ @trace_method
210
+ def get_azure_credentials(
211
+ self, provider_name: Union[str, None], actor: PydanticUser
212
+ ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
213
+ providers = self.list_providers(name=provider_name, actor=actor)
214
+ api_key = providers[0].api_key if providers else None
215
+ base_url = providers[0].base_url if providers else None
216
+ api_version = providers[0].api_version if providers else None
217
+ return api_key, base_url, api_version
218
+
219
+ @enforce_types
220
+ @trace_method
221
+ async def get_azure_credentials_async(
222
+ self, provider_name: Union[str, None], actor: PydanticUser
223
+ ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
224
+ providers = await self.list_providers_async(name=provider_name, actor=actor)
225
+ api_key = providers[0].api_key if providers else None
226
+ base_url = providers[0].base_url if providers else None
227
+ api_version = providers[0].api_version if providers else None
228
+ return api_key, base_url, api_version
229
+
208
230
  @enforce_types
209
231
  @trace_method
210
232
  async def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
@@ -215,6 +237,8 @@ class ProviderManager:
215
237
  provider_category=ProviderCategory.byok,
216
238
  access_key=provider_check.access_key, # This contains the access key ID for Bedrock
217
239
  region=provider_check.region,
240
+ base_url=provider_check.base_url,
241
+ api_version=provider_check.api_version,
218
242
  ).cast_to_subtype()
219
243
 
220
244
  # TODO: add more string sanity checks here before we hit actual endpoints