letta-nightly 0.6.43.dev20250320104204__py3-none-any.whl → 0.6.43.dev20250322104133__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/agent.py +2 -2
- letta/agents/ephemeral_memory_agent.py +114 -0
- letta/agents/{low_latency_agent.py → voice_agent.py} +133 -79
- letta/client/client.py +1 -1
- letta/embeddings.py +3 -14
- letta/functions/function_sets/multi_agent.py +46 -1
- letta/functions/helpers.py +10 -57
- letta/functions/mcp_client/base_client.py +7 -9
- letta/functions/mcp_client/exceptions.py +6 -0
- letta/helpers/tool_execution_helper.py +9 -7
- letta/llm_api/anthropic.py +1 -19
- letta/llm_api/aws_bedrock.py +2 -2
- letta/llm_api/azure_openai.py +22 -46
- letta/llm_api/llm_api_tools.py +15 -4
- letta/orm/sqlalchemy_base.py +106 -7
- letta/schemas/openai/chat_completion_request.py +20 -1
- letta/schemas/providers.py +251 -0
- letta/schemas/tool.py +4 -1
- letta/server/rest_api/app.py +1 -11
- letta/server/rest_api/optimistic_json_parser.py +5 -5
- letta/server/rest_api/routers/v1/tools.py +34 -2
- letta/server/rest_api/routers/v1/voice.py +5 -5
- letta/server/server.py +6 -0
- letta/services/agent_manager.py +1 -1
- letta/services/block_manager.py +8 -6
- letta/services/message_manager.py +65 -2
- letta/settings.py +3 -3
- {letta_nightly-0.6.43.dev20250320104204.dist-info → letta_nightly-0.6.43.dev20250322104133.dist-info}/METADATA +4 -4
- {letta_nightly-0.6.43.dev20250320104204.dist-info → letta_nightly-0.6.43.dev20250322104133.dist-info}/RECORD +32 -30
- {letta_nightly-0.6.43.dev20250320104204.dist-info → letta_nightly-0.6.43.dev20250322104133.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.43.dev20250320104204.dist-info → letta_nightly-0.6.43.dev20250322104133.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.43.dev20250320104204.dist-info → letta_nightly-0.6.43.dev20250322104133.dist-info}/entry_points.txt +0 -0
letta/functions/helpers.py
CHANGED
@@ -93,7 +93,7 @@ def execute_composio_action(

     entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
     try:
-        composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id)
+        composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
         response = composio_toolset.execute_action(action=action_name, params=args)
     except ApiKeyNotProvidedError:
         raise RuntimeError(

@@ -533,57 +533,17 @@ def fire_and_forget_send_to_agent(


 async def _send_message_to_agents_matching_tags_async(
-    sender_agent: "Agent",
+    sender_agent: "Agent", server: "SyncServer", messages: List[MessageCreate], matching_agents: List["AgentState"]
 ) -> List[str]:
-    log_telemetry(
-        sender_agent.logger,
-        "_send_message_to_agents_matching_tags_async start",
-        message=message,
-        match_all=match_all,
-        match_some=match_some,
-    )
-    server = get_letta_server()
-
-    augmented_message = (
-        f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
-        f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
-        f"{message}"
-    )
-
-    # Retrieve up to 100 matching agents
-    log_telemetry(
-        sender_agent.logger,
-        "_send_message_to_agents_matching_tags_async listing agents start",
-        message=message,
-        match_all=match_all,
-        match_some=match_some,
-    )
-    matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some)
-
-    log_telemetry(
-        sender_agent.logger,
-        "_send_message_to_agents_matching_tags_async listing agents finish",
-        message=message,
-        match_all=match_all,
-        match_some=match_some,
-    )
-
-    # Create a system message
-    messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
-
-    # Possibly limit concurrency to avoid meltdown:
-    sem = asyncio.Semaphore(settings.multi_agent_concurrent_sends)
-
     async def _send_single(agent_state):
-
-
-
-
-
-
-
-
-        )
+        return await async_send_message_with_retries(
+            server=server,
+            sender_agent=sender_agent,
+            target_agent_id=agent_state.id,
+            messages=messages,
+            max_retries=3,
+            timeout=settings.multi_agent_send_message_timeout,
+        )

     tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents]
     results = await asyncio.gather(*tasks, return_exceptions=True)

@@ -594,13 +554,6 @@ async def _send_message_to_agents_matching_tags_async(
         else:
             final.append(r)

-    log_telemetry(
-        sender_agent.logger,
-        "_send_message_to_agents_matching_tags_async finish",
-        message=message,
-        match_all=match_all,
-        match_some=match_some,
-    )
     return final

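The refactor narrows `_send_message_to_agents_matching_tags_async` to pure fan-out: agent lookup, message construction, telemetry, and the concurrency semaphore all move to the caller. A minimal sketch of the new calling convention, where `notify_matching_agents` is a hypothetical wrapper and the `letta.schemas.*` import paths are assumptions, not names confirmed by this diff:

# Hedged sketch of the new calling convention; `notify_matching_agents` is
# a made-up wrapper, not a function from the diff.
from letta.functions.helpers import _send_message_to_agents_matching_tags_async
from letta.schemas.enums import MessageRole  # assumed import path
from letta.schemas.message import MessageCreate  # assumed import path


async def notify_matching_agents(sender_agent, server, match_all, match_some):
    # The caller now looks up the matching agents itself...
    matching_agents = server.agent_manager.list_agents_matching_tags(
        actor=sender_agent.user, match_all=match_all, match_some=match_some
    )
    # ...and builds the system message the helper used to assemble inline.
    messages = [
        MessageCreate(
            role=MessageRole.system,
            content=f"[Incoming message from agent with ID '{sender_agent.agent_state.id}'] hello",
            name=sender_agent.agent_state.name,
        )
    ]
    return await _send_message_to_agents_matching_tags_async(
        sender_agent=sender_agent, server=server, messages=messages, matching_agents=matching_agents
    )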
letta/functions/mcp_client/base_client.py
CHANGED

@@ -1,9 +1,10 @@
 import asyncio
 from typing import List, Optional, Tuple

-from mcp import ClientSession
+from mcp import ClientSession

-from letta.functions.mcp_client.
+from letta.functions.mcp_client.exceptions import MCPTimeoutError
+from letta.functions.mcp_client.types import BaseServerConfig, MCPTool
 from letta.log import get_logger
 from letta.settings import tool_settings

@@ -31,9 +32,7 @@ class BaseMCPClient:
                 )
                 self.initialized = True
             except asyncio.TimeoutError:
-                raise
-                    f"Timed out while initializing session for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_connect_to_server_timeout}s)."
-                )
+                raise MCPTimeoutError("initializing session", self.server_config.server_name, tool_settings.mcp_connect_to_server_timeout)
         else:
             raise RuntimeError(
                 f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"

@@ -42,7 +41,7 @@ class BaseMCPClient:
     def _initialize_connection(self, server_config: BaseServerConfig, timeout: float) -> bool:
         raise NotImplementedError("Subclasses must implement _initialize_connection")

-    def list_tools(self) -> List[
+    def list_tools(self) -> List[MCPTool]:
         self._check_initialized()
         try:
             response = self.loop.run_until_complete(

@@ -50,11 +49,10 @@ class BaseMCPClient:
             )
             return response.tools
         except asyncio.TimeoutError:
-            # Could log, throw a custom exception, etc.
             logger.error(
                 f"Timed out while listing tools for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_list_tools_timeout}s)."
             )
-
+            raise MCPTimeoutError("listing tools", self.server_config.server_name, tool_settings.mcp_list_tools_timeout)

     def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
         self._check_initialized()

@@ -67,7 +65,7 @@ class BaseMCPClient:
             logger.error(
                 f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
             )
-
+            raise MCPTimeoutError(f"executing tool '{tool_name}'", self.server_config.server_name, tool_settings.mcp_execute_tool_timeout)

     def _check_initialized(self):
         if not self.initialized:

letta/functions/mcp_client/exceptions.py
ADDED

@@ -0,0 +1,6 @@
+class MCPTimeoutError(RuntimeError):
+    """Custom exception raised when an MCP operation times out."""
+
+    def __init__(self, operation: str, server_name: str, timeout: float):
+        message = f"Timed out while {operation} for MCP server {server_name} (timeout={timeout}s)."
+        super().__init__(message)
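Because `MCPTimeoutError` subclasses `RuntimeError`, callers that already catch `RuntimeError` keep working, while new code can match timeouts specifically. A small sketch, assuming `client` is an already-initialized `BaseMCPClient` subclass instance:

from letta.functions.mcp_client.exceptions import MCPTimeoutError

try:
    tools = client.list_tools()
except MCPTimeoutError as e:
    # The exception message already names the operation, server, and timeout.
    print(f"MCP server did not respond in time: {e}")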
letta/helpers/tool_execution_helper.py
CHANGED

@@ -36,11 +36,10 @@ def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
     # Set additionalProperties to False
     parameters["additionalProperties"] = False
     schema["parameters"] = parameters
-
     return schema


-def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
+def add_pre_execution_message(tool_schema: Dict[str, Any], description: Optional[str] = None) -> Dict[str, Any]:
     """Adds a `pre_execution_message` parameter to a tool schema to prompt a natural, human-like message before executing the tool.

     Args:

@@ -58,14 +57,17 @@ def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
     properties = parameters.get("properties", {})
     required = parameters.get("required", [])

-    # Define the new `pre_execution_message` field
-
-
-
+    # Define the new `pre_execution_message` field
+    if not description:
+        # Default description
+        description = (
             "A concise message to be uttered before executing this tool. "
             "This should sound natural, as if a person is casually announcing their next action."
             "You MUST also include punctuation at the end of this message."
-    )
+        )
+    pre_execution_message_field = {
+        "type": "string",
+        "description": description,
     }

     # Ensure the pre-execution message is the first field in properties
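The new optional `description` parameter lets callers override the default pre-execution prompt per tool. A hedged sketch, assuming this hunk lives in `letta/helpers/tool_execution_helper.py` as the file list suggests, and that the schema is an OpenAI-style tool schema with a top-level `parameters` block (as `enable_strict_mode` above also expects); `play_song` is a made-up example tool:

from letta.helpers.tool_execution_helper import add_pre_execution_message

# Example tool schema; the name and fields are illustrative.
schema = {
    "name": "play_song",
    "description": "Play a song by title.",
    "parameters": {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
    },
}

# Override the default spoken-announcement prompt for this one tool.
schema = add_pre_execution_message(
    schema,
    description="One short spoken sentence announcing which song is about to play.",
)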
letta/llm_api/anthropic.py
CHANGED
@@ -606,25 +606,6 @@ def _prepare_anthropic_request(
     # TODO eventually enable parallel tool use
     data["tools"] = anthropic_tools

-    # tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
-    if put_inner_thoughts_in_kwargs:
-        if len(anthropic_tools) == 1:
-            data["tool_choice"] = {
-                "type": "tool",
-                "name": anthropic_tools[0]["name"],
-                "disable_parallel_tool_use": True,
-            }
-        else:
-            data["tool_choice"] = {
-                "type": "any",
-                "disable_parallel_tool_use": True,
-            }
-    else:
-        data["tool_choice"] = {
-            "type": "auto",
-            "disable_parallel_tool_use": True,
-        }
-
     # Move 'system' to the top level
     assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
     data["system"] = data["messages"][0]["content"]

@@ -720,6 +701,7 @@ def anthropic_bedrock_chat_completions_request(
     # Make the request
     try:
         # bedrock does not support certain args
+        print("Warning: Tool rules not supported with Anthropic Bedrock")
         data["tool_choice"] = {"type": "any"}
         log_event(name="llm_request_sent", attributes=data)
         response = client.messages.create(**data)
letta/llm_api/aws_bedrock.py
CHANGED
@@ -13,7 +13,7 @@ def has_valid_aws_credentials() -> bool:
     """
     Check if AWS credentials are properly configured.
     """
-    valid_aws_credentials = os.getenv("AWS_ACCESS_KEY") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_REGION")
+    valid_aws_credentials = os.getenv("AWS_ACCESS_KEY") is not None and os.getenv("AWS_SECRET_ACCESS_KEY") is not None and os.getenv("AWS_REGION") is not None
     return valid_aws_credentials


@@ -78,7 +78,7 @@ def bedrock_get_model_details(region_name: str, model_id: str) -> Dict[str, Any]:
         response = bedrock.get_foundation_model(modelIdentifier=model_id)
         return response["modelDetails"]
     except ClientError as e:
-        logger.exception(f"Error getting model details: {str(e)}"
+        logger.exception(f"Error getting model details: {str(e)}")
         raise e

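The `is not None` change is more than style: `os.getenv` returns `Optional[str]`, so the old `and` chain evaluated to `None` or to the last environment-variable string rather than to a `bool`. A quick self-contained illustration (dummy values only):

import os

os.environ["AWS_ACCESS_KEY"] = "AKIAEXAMPLE"  # dummy values for illustration
os.environ["AWS_SECRET_ACCESS_KEY"] = "secret"
os.environ["AWS_REGION"] = "us-east-1"

# Old form: truthy when all are set, but a str, not a bool; None if any is unset.
legacy = os.getenv("AWS_ACCESS_KEY") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_REGION")
print(type(legacy))  # <class 'str'>

# New form: each clause is a real boolean, so the conjunction is too.
fixed = (
    os.getenv("AWS_ACCESS_KEY") is not None
    and os.getenv("AWS_SECRET_ACCESS_KEY") is not None
    and os.getenv("AWS_REGION") is not None
)
assert isinstance(fixed, bool)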
letta/llm_api/azure_openai.py
CHANGED
@@ -1,14 +1,14 @@
 from collections import defaultdict

 import requests
+from openai import AzureOpenAI

-
+
+from letta.llm_api.openai import prepare_openai_payload
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completions import ChatCompletionRequest
-from letta.schemas.openai.embedding_response import EmbeddingResponse
 from letta.settings import ModelSettings
-from letta.tracing import log_event


 def get_azure_chat_completions_endpoint(base_url: str, model: str, api_version: str):

@@ -33,19 +33,19 @@ def get_azure_deployment_list_endpoint(base_url: str):
 def azure_openai_get_deployed_model_list(base_url: str, api_key: str, api_version: str) -> list:
     """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""

-
-    headers = {"Content-Type": "application/json"}
-    if api_key is not None:
-        headers["api-key"] = f"{api_key}"
+    client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=base_url)

-    # 1. Get all available models
-    url = get_azure_model_list_endpoint(base_url, api_version)
     try:
-
-        response.raise_for_status()
+        models_list = client.models.list()
     except requests.RequestException as e:
         raise RuntimeError(f"Failed to retrieve model list: {e}")
-
+
+    all_available_models = [model.to_dict() for model in models_list.data]
+
+    # https://xxx.openai.azure.com/openai/models?api-version=xxx
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["api-key"] = f"{api_key}"

     # 2. Get all the deployed models
     url = get_azure_deployment_list_endpoint(base_url)

@@ -102,42 +102,18 @@ def azure_openai_get_embeddings_model_list(base_url: str, api_key: str, api_version: str)


 def azure_openai_chat_completions_request(
-    model_settings: ModelSettings, llm_config: LLMConfig,
+    model_settings: ModelSettings, llm_config: LLMConfig, chat_completion_request: ChatCompletionRequest
 ) -> ChatCompletionResponse:
     """https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""

-    assert
-
-
-    data = chat_completion_request.model_dump(exclude_none=True)
-
-    # If functions == None, strip from the payload
-    if "functions" in data and data["functions"] is None:
-        data.pop("functions")
-        data.pop("function_call", None)  # extra safe, should exist always (default="auto")
-
-    if "tools" in data and data["tools"] is None:
-        data.pop("tools")
-        data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")
-
-    url = get_azure_chat_completions_endpoint(model_settings.azure_base_url, llm_config.model, model_settings.azure_api_version)
-    log_event(name="llm_request_sent", attributes=data)
-    response_json = make_post_request(url, headers, data)
-    # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
-    if "content" not in response_json["choices"][0].get("message"):
-        response_json["choices"][0]["message"]["content"] = None
-    log_event(name="llm_response_received", attributes=response_json)
-    response = ChatCompletionResponse(**response_json)  # convert to 'dot-dict' style which is the openai python client default
-    return response
-
-
-def azure_openai_embeddings_request(
-    resource_name: str, deployment_id: str, api_version: str, api_key: str, data: dict
-) -> EmbeddingResponse:
-    """https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings"""
+    assert model_settings.azure_api_key is not None, "Missing required api key field when calling Azure OpenAI"
+    assert model_settings.azure_api_version is not None, "Missing required api version field when calling Azure OpenAI"
+    assert model_settings.azure_base_url is not None, "Missing required base url field when calling Azure OpenAI"

-
-
+    data = prepare_openai_payload(chat_completion_request)
+    client = AzureOpenAI(
+        api_key=model_settings.azure_api_key, api_version=model_settings.azure_api_version, azure_endpoint=model_settings.azure_base_url
+    )
+    chat_completion = client.chat.completions.create(**data)

-
-    return EmbeddingResponse(**response_json)
+    return ChatCompletionResponse(**chat_completion.model_dump())
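The Azure request path moves from hand-rolled REST calls to the official `AzureOpenAI` client. A condensed sketch of the new flow with dummy credentials; the client calls shown are the ones appearing in the diff, everything else is placeholder:

from openai import AzureOpenAI

# Dummy values for illustration; the real values come from ModelSettings.
client = AzureOpenAI(
    api_key="azure-key",
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
)

# Model discovery now goes through the SDK instead of a raw GET...
models = [m.to_dict() for m in client.models.list().data]

# ...and chat completions go through client.chat.completions.create, with
# payload cleanup delegated to prepare_openai_payload (see the hunk above).
chat_completion = client.chat.completions.create(
    model="gpt-4o-mini", messages=[{"role": "user", "content": "hi"}]
)
print(chat_completion.model_dump()["choices"][0]["message"]["content"])

One thing worth noting: the surviving `except requests.RequestException` in `azure_openai_get_deployed_model_list` now guards an SDK call that raises `openai` exceptions rather than `requests` ones, so it is unlikely to fire.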
letta/llm_api/llm_api_tools.py
CHANGED
@@ -306,7 +306,6 @@ def create(
         response = azure_openai_chat_completions_request(
             model_settings=model_settings,
             llm_config=llm_config,
-            api_key=model_settings.azure_api_key,
             chat_completion_request=chat_completion_request,
         )

@@ -374,14 +373,26 @@ def create(
         # Force tool calling
         tool_call = None
         if force_tool_call is not None:
-            tool_call = {"type": "function", "function": {"name": force_tool_call}}
+            # tool_call = {"type": "function", "function": {"name": force_tool_call}}
+            tool_choice = {"type": "tool", "name": force_tool_call}
+            tools = [{"type": "function", "function": f} for f in functions if f["name"] == force_tool_call]
             assert functions is not None

+            # need to have this setting to be able to put inner thoughts in kwargs
+            llm_config.put_inner_thoughts_in_kwargs = True
+        else:
+            if llm_config.put_inner_thoughts_in_kwargs:
+                # tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
+                tool_choice = {"type": "any", "disable_parallel_tool_use": True}
+            else:
+                tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
+            tools = [{"type": "function", "function": f} for f in functions]
+
         chat_completion_request = ChatCompletionRequest(
             model=llm_config.model,
             messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
-            tools=
-            tool_choice=
+            tools=tools,
+            tool_choice=tool_choice,
             max_tokens=llm_config.max_tokens,  # Note: max_tokens is required for Anthropic API
             temperature=llm_config.temperature,
             stream=stream,
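Taken together with the `anthropic.py` hunk above, Anthropic tool-choice selection now lives at the call site instead of inside `_prepare_anthropic_request`. A distilled sketch of the three shapes the new branches produce; the helper function name is illustrative, not from the diff:

# Distilled from the hunk above; `force_tool_call` is a tool name or None.
def pick_anthropic_tool_choice(force_tool_call, put_inner_thoughts_in_kwargs):
    if force_tool_call is not None:
        # Forcing a tool also forces inner thoughts into kwargs upstream.
        return {"type": "tool", "name": force_tool_call}
    if put_inner_thoughts_in_kwargs:
        # Anything but "auto" only plays nice when thinking rides inside tool calls.
        return {"type": "any", "disable_parallel_tool_use": True}
    return {"type": "auto", "disable_parallel_tool_use": True}


assert pick_anthropic_tool_choice("send_message", True) == {"type": "tool", "name": "send_message"}
assert pick_anthropic_tool_choice(None, False)["type"] == "auto"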
letta/orm/sqlalchemy_base.py
CHANGED
@@ -286,7 +286,45 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         Raises:
             NoResultFound: if the object is not found
         """
-
+        # this is ok because read_multiple will check if the
+        identifiers = [] if identifier is None else [identifier]
+        found = cls.read_multiple(db_session, identifiers, actor, access, access_type, **kwargs)
+        if len(found) == 0:
+            # for backwards compatibility.
+            conditions = []
+            if identifier:
+                conditions.append(f"id={identifier}")
+            if actor:
+                conditions.append(f"access level in {access} for {actor}")
+            if hasattr(cls, "is_deleted"):
+                conditions.append("is_deleted=False")
+            raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
+        return found[0]
+
+    @classmethod
+    @handle_db_timeout
+    def read_multiple(
+        cls,
+        db_session: "Session",
+        identifiers: List[str] = [],
+        actor: Optional["User"] = None,
+        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
+        access_type: AccessType = AccessType.ORGANIZATION,
+        **kwargs,
+    ) -> List["SqlalchemyBase"]:
+        """The primary accessor for ORM record(s)
+        Args:
+            db_session: the database session to use when retrieving the record
+            identifiers: a list of identifiers of the records to read, can be the id string or the UUID object for backwards compatibility
+            actor: if specified, results will be scoped only to records the user is able to access
+            access: if actor is specified, records will be filtered to the minimum permission level for the actor
+            kwargs: additional arguments to pass to the read, used for more complex objects
+        Returns:
+            The matching object
+        Raises:
+            NoResultFound: if the object is not found
+        """
+        logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")

         # Start the query
         query = select(cls)

@@ -294,9 +332,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         query_conditions = []

         # If an identifier is provided, add it to the query conditions
-        if
-            query = query.where(cls.id
-            query_conditions.append(f"id='{
+        if len(identifiers) > 0:
+            query = query.where(cls.id.in_(identifiers))
+            query_conditions.append(f"id='{identifiers}'")

         if kwargs:
             query = query.filter_by(**kwargs)

@@ -309,12 +347,29 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         if hasattr(cls, "is_deleted"):
             query = query.where(cls.is_deleted == False)
             query_conditions.append("is_deleted=False")
-
-
+
+        results = db_session.execute(query).scalars().all()
+        if results:  # if empty list a.k.a. no results
+            if len(identifiers) > 0:
+                # find which identifiers were not found
+                # only when identifier length is greater than 0 (so it was used in the actual query)
+                identifier_set = set(identifiers)
+                results_set = set(map(lambda obj: obj.id, results))
+
+                # we log a warning message if any of the queried IDs were not found.
+                # TODO: should we error out instead?
+                if identifier_set != results_set:
+                    # Construct a detailed error message based on query conditions
+                    conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
+                    logger.warning(
+                        f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}"
+                    )
+            return results

         # Construct a detailed error message based on query conditions
         conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
-
+        logger.warning(f"{cls.__name__} not found with {conditions_str}")
+        return []

     @handle_db_timeout
     def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":

@@ -331,6 +386,50 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         except (DBAPIError, IntegrityError) as e:
             self._handle_dbapi_error(e)

+    @classmethod
+    @handle_db_timeout
+    def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]:
+        """
+        Create multiple records in a single transaction for better performance.
+
+        Args:
+            items: List of model instances to create
+            db_session: SQLAlchemy session
+            actor: Optional user performing the action
+
+        Returns:
+            List of created model instances
+        """
+        logger.debug(f"Batch creating {len(items)} {cls.__name__} items with actor={actor}")
+
+        if not items:
+            return []
+
+        # Set created/updated by fields if actor is provided
+        if actor:
+            for item in items:
+                item._set_created_and_updated_by_fields(actor.id)
+
+        try:
+            with db_session as session:
+                session.add_all(items)
+                session.flush()  # Flush to generate IDs but don't commit yet
+
+                # Collect IDs to fetch the complete objects after commit
+                item_ids = [item.id for item in items]
+
+                session.commit()
+
+                # Re-query the objects to get them with relationships loaded
+                query = select(cls).where(cls.id.in_(item_ids))
+                if hasattr(cls, "created_at"):
+                    query = query.order_by(cls.created_at)
+
+                return list(session.execute(query).scalars())
+
+        except (DBAPIError, IntegrityError) as e:
+            cls._handle_dbapi_error(e)
+
     @handle_db_timeout
     def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
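A hedged usage sketch for the two new bulk helpers; `Block` is used purely as a placeholder model name, and the `session` and `actor` objects are assumed to exist:

# Placeholder model and session; names are illustrative, not from the diff.
blocks = [Block(value=v) for v in ("a", "b", "c")]

# batch_create adds all rows in one flush/commit instead of one commit each.
created = Block.batch_create(blocks, db_session=session, actor=actor)

# read_multiple returns the found subset (logging a warning for missing IDs),
# while read() preserves the old NoResultFound contract via its empty check.
found = Block.read_multiple(session, [b.id for b in created], actor=actor)
assert {b.id for b in found} == {b.id for b in created}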
letta/schemas/openai/chat_completion_request.py
CHANGED

@@ -74,7 +74,25 @@ class ToolFunctionChoice(BaseModel):
     function: FunctionCall


-
+class AnthropicToolChoiceTool(BaseModel):
+    type: str = "tool"
+    name: str
+    disable_parallel_tool_use: Optional[bool] = False
+
+
+class AnthropicToolChoiceAny(BaseModel):
+    type: str = "any"
+    disable_parallel_tool_use: Optional[bool] = False
+
+
+class AnthropicToolChoiceAuto(BaseModel):
+    type: str = "auto"
+    disable_parallel_tool_use: Optional[bool] = False
+
+
+ToolChoice = Union[
+    Literal["none", "auto", "required", "any"], ToolFunctionChoice, AnthropicToolChoiceTool, AnthropicToolChoiceAny, AnthropicToolChoiceAuto
+]


 ## tools ##

@@ -82,6 +100,7 @@ class FunctionSchema(BaseModel):
     name: str
     description: Optional[str] = None
     parameters: Optional[Dict[str, Any]] = None  # JSON Schema for the parameters
+    strict: bool = False


 class Tool(BaseModel):