letta-nightly 0.6.5.dev20241219104153__py3-none-any.whl → 0.6.6.dev20241220190343__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +1 -1
- letta/agent.py +20 -1
- letta/client/client.py +1 -1
- letta/client/streaming.py +9 -9
- letta/errors.py +6 -6
- letta/helpers/tool_rule_solver.py +82 -51
- letta/orm/custom_columns.py +5 -2
- letta/providers.py +2 -1
- letta/schemas/enums.py +1 -0
- letta/schemas/letta_message.py +76 -40
- letta/schemas/letta_response.py +9 -1
- letta/schemas/message.py +13 -13
- letta/schemas/tool_rule.py +12 -2
- letta/server/rest_api/interface.py +48 -48
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +2 -2
- letta/server/rest_api/routers/v1/agents.py +0 -8
- letta/server/rest_api/routers/v1/tools.py +2 -2
- letta/server/server.py +9 -42
- letta/services/message_manager.py +1 -0
- {letta_nightly-0.6.5.dev20241219104153.dist-info → letta_nightly-0.6.6.dev20241220190343.dist-info}/METADATA +1 -1
- {letta_nightly-0.6.5.dev20241219104153.dist-info → letta_nightly-0.6.6.dev20241220190343.dist-info}/RECORD +24 -24
- {letta_nightly-0.6.5.dev20241219104153.dist-info → letta_nightly-0.6.6.dev20241220190343.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.5.dev20241219104153.dist-info → letta_nightly-0.6.6.dev20241220190343.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.5.dev20241219104153.dist-info → letta_nightly-0.6.6.dev20241220190343.dist-info}/entry_points.txt +0 -0
letta/__init__.py
CHANGED
letta/agent.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import datetime
|
|
2
2
|
import inspect
|
|
3
|
+
import json
|
|
3
4
|
import time
|
|
4
5
|
import traceback
|
|
5
6
|
import warnings
|
|
@@ -371,6 +372,9 @@ class Agent(BaseAgent):
|
|
|
371
372
|
self._append_to_messages(added_messages=init_messages_objs)
|
|
372
373
|
self._validate_message_buffer_is_utc()
|
|
373
374
|
|
|
375
|
+
# Load last function response from message history
|
|
376
|
+
self.last_function_response = self.load_last_function_response()
|
|
377
|
+
|
|
374
378
|
# Keep track of the total number of messages throughout all time
|
|
375
379
|
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
|
|
376
380
|
self.messages_total_init = len(self._messages) - 1
|
|
@@ -389,6 +393,19 @@ class Agent(BaseAgent):
|
|
|
389
393
|
else:
|
|
390
394
|
self.supports_structured_output = True
|
|
391
395
|
|
|
396
|
+
def load_last_function_response(self):
|
|
397
|
+
"""Load the last function response from message history"""
|
|
398
|
+
for i in range(len(self._messages) - 1, -1, -1):
|
|
399
|
+
msg = self._messages[i]
|
|
400
|
+
if msg.role == MessageRole.tool and msg.text:
|
|
401
|
+
try:
|
|
402
|
+
response_json = json.loads(msg.text)
|
|
403
|
+
if response_json.get("message"):
|
|
404
|
+
return response_json["message"]
|
|
405
|
+
except (json.JSONDecodeError, KeyError):
|
|
406
|
+
raise ValueError(f"Invalid JSON format in message: {msg.text}")
|
|
407
|
+
return None
|
|
408
|
+
|
|
392
409
|
def update_memory_if_change(self, new_memory: Memory) -> bool:
|
|
393
410
|
"""
|
|
394
411
|
Update internal memory object and system prompt if there have been modifications.
|
|
@@ -586,7 +603,7 @@ class Agent(BaseAgent):
|
|
|
586
603
|
) -> ChatCompletionResponse:
|
|
587
604
|
"""Get response from LLM API with robust retry mechanism."""
|
|
588
605
|
|
|
589
|
-
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
|
|
606
|
+
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
|
|
590
607
|
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
|
591
608
|
|
|
592
609
|
allowed_functions = (
|
|
@@ -826,6 +843,7 @@ class Agent(BaseAgent):
|
|
|
826
843
|
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
|
|
827
844
|
printd(error_msg_user)
|
|
828
845
|
function_response = package_function_response(False, error_msg)
|
|
846
|
+
self.last_function_response = function_response
|
|
829
847
|
# TODO: truncate error message somehow
|
|
830
848
|
messages.append(
|
|
831
849
|
Message.dict_to_message(
|
|
@@ -861,6 +879,7 @@ class Agent(BaseAgent):
|
|
|
861
879
|
) # extend conversation with function response
|
|
862
880
|
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
|
863
881
|
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
|
|
882
|
+
self.last_function_response = function_response
|
|
864
883
|
|
|
865
884
|
else:
|
|
866
885
|
# Standard non-function reply
|
letta/client/client.py
CHANGED
|
@@ -2909,7 +2909,7 @@ class LocalClient(AbstractClient):
|
|
|
2909
2909
|
job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user)
|
|
2910
2910
|
|
|
2911
2911
|
# TODO: implement blocking vs. non-blocking
|
|
2912
|
-
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)
|
|
2912
|
+
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id, actor=self.user)
|
|
2913
2913
|
return job
|
|
2914
2914
|
|
|
2915
2915
|
def delete_file_from_source(self, source_id: str, file_id: str):
|
letta/client/streaming.py
CHANGED
|
@@ -8,9 +8,9 @@ from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
|
|
8
8
|
from letta.errors import LLMError
|
|
9
9
|
from letta.schemas.enums import MessageStreamStatus
|
|
10
10
|
from letta.schemas.letta_message import (
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
11
|
+
ToolCallMessage,
|
|
12
|
+
ToolReturnMessage,
|
|
13
|
+
ReasoningMessage,
|
|
14
14
|
)
|
|
15
15
|
from letta.schemas.letta_response import LettaStreamingResponse
|
|
16
16
|
from letta.schemas.usage import LettaUsageStatistics
|
|
@@ -53,12 +53,12 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
|
|
53
53
|
yield MessageStreamStatus(sse.data)
|
|
54
54
|
else:
|
|
55
55
|
chunk_data = json.loads(sse.data)
|
|
56
|
-
if "
|
|
57
|
-
yield
|
|
58
|
-
elif "
|
|
59
|
-
yield
|
|
60
|
-
elif "
|
|
61
|
-
yield
|
|
56
|
+
if "reasoning" in chunk_data:
|
|
57
|
+
yield ReasoningMessage(**chunk_data)
|
|
58
|
+
elif "tool_call" in chunk_data:
|
|
59
|
+
yield ToolCallMessage(**chunk_data)
|
|
60
|
+
elif "tool_return" in chunk_data:
|
|
61
|
+
yield ToolReturnMessage(**chunk_data)
|
|
62
62
|
elif "usage" in chunk_data:
|
|
63
63
|
yield LettaUsageStatistics(**chunk_data["usage"])
|
|
64
64
|
else:
|
letta/errors.py
CHANGED
|
@@ -131,16 +131,16 @@ class LettaMessageError(LettaError):
|
|
|
131
131
|
return f"{error_msg}\n\n{message_json}"
|
|
132
132
|
|
|
133
133
|
|
|
134
|
-
class
|
|
135
|
-
"""Error raised when a message is missing a
|
|
134
|
+
class MissingToolCallError(LettaMessageError):
|
|
135
|
+
"""Error raised when a message is missing a tool call."""
|
|
136
136
|
|
|
137
|
-
default_error_message = "The message is missing a
|
|
137
|
+
default_error_message = "The message is missing a tool call."
|
|
138
138
|
|
|
139
139
|
|
|
140
|
-
class
|
|
141
|
-
"""Error raised when a message uses an invalid
|
|
140
|
+
class InvalidToolCallError(LettaMessageError):
|
|
141
|
+
"""Error raised when a message uses an invalid tool call."""
|
|
142
142
|
|
|
143
|
-
default_error_message = "The message uses an invalid
|
|
143
|
+
default_error_message = "The message uses an invalid tool call or has improper usage of a tool call."
|
|
144
144
|
|
|
145
145
|
|
|
146
146
|
class MissingInnerMonologueError(LettaMessageError):
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
|
|
1
|
+
import json
|
|
2
|
+
from typing import List, Optional, Union
|
|
2
3
|
|
|
3
4
|
from pydantic import BaseModel, Field
|
|
4
5
|
|
|
@@ -6,6 +7,7 @@ from letta.schemas.enums import ToolRuleType
|
|
|
6
7
|
from letta.schemas.tool_rule import (
|
|
7
8
|
BaseToolRule,
|
|
8
9
|
ChildToolRule,
|
|
10
|
+
ConditionalToolRule,
|
|
9
11
|
InitToolRule,
|
|
10
12
|
TerminalToolRule,
|
|
11
13
|
)
|
|
@@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
|
|
|
22
24
|
init_tool_rules: List[InitToolRule] = Field(
|
|
23
25
|
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
|
|
24
26
|
)
|
|
25
|
-
tool_rules: List[ChildToolRule] = Field(
|
|
27
|
+
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
|
|
26
28
|
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
|
27
29
|
)
|
|
28
30
|
terminal_tool_rules: List[TerminalToolRule] = Field(
|
|
@@ -35,21 +37,25 @@ class ToolRulesSolver(BaseModel):
|
|
|
35
37
|
# Separate the provided tool rules into init, standard, and terminal categories
|
|
36
38
|
for rule in tool_rules:
|
|
37
39
|
if rule.type == ToolRuleType.run_first:
|
|
40
|
+
assert isinstance(rule, InitToolRule)
|
|
38
41
|
self.init_tool_rules.append(rule)
|
|
39
42
|
elif rule.type == ToolRuleType.constrain_child_tools:
|
|
43
|
+
assert isinstance(rule, ChildToolRule)
|
|
44
|
+
self.tool_rules.append(rule)
|
|
45
|
+
elif rule.type == ToolRuleType.conditional:
|
|
46
|
+
assert isinstance(rule, ConditionalToolRule)
|
|
47
|
+
self.validate_conditional_tool(rule)
|
|
40
48
|
self.tool_rules.append(rule)
|
|
41
49
|
elif rule.type == ToolRuleType.exit_loop:
|
|
50
|
+
assert isinstance(rule, TerminalToolRule)
|
|
42
51
|
self.terminal_tool_rules.append(rule)
|
|
43
52
|
|
|
44
|
-
# Validate the tool rules to ensure they form a DAG
|
|
45
|
-
if not self.validate_tool_rules():
|
|
46
|
-
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")
|
|
47
53
|
|
|
48
54
|
def update_tool_usage(self, tool_name: str):
|
|
49
55
|
"""Update the internal state to track the last tool called."""
|
|
50
56
|
self.last_tool_name = tool_name
|
|
51
57
|
|
|
52
|
-
def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
|
|
58
|
+
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
|
|
53
59
|
"""Get a list of tool names allowed based on the last tool called."""
|
|
54
60
|
if self.last_tool_name is None:
|
|
55
61
|
# Use initial tool rules if no tool has been called yet
|
|
@@ -58,18 +64,21 @@ class ToolRulesSolver(BaseModel):
|
|
|
58
64
|
# Find a matching ToolRule for the last tool used
|
|
59
65
|
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)
|
|
60
66
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
# Default to empty if no rule matches
|
|
66
|
-
message = "User provided tool rules and execution state resolved to no more possible tool calls."
|
|
67
|
-
if error_on_empty:
|
|
68
|
-
raise RuntimeError(message)
|
|
69
|
-
else:
|
|
70
|
-
# warnings.warn(message)
|
|
67
|
+
if current_rule is None:
|
|
68
|
+
if error_on_empty:
|
|
69
|
+
raise ValueError(f"No tool rule found for {self.last_tool_name}")
|
|
71
70
|
return []
|
|
72
71
|
|
|
72
|
+
# If the current rule is a conditional tool rule, use the LLM response to
|
|
73
|
+
# determine which child tool to use
|
|
74
|
+
if isinstance(current_rule, ConditionalToolRule):
|
|
75
|
+
if not last_function_response:
|
|
76
|
+
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
|
|
77
|
+
next_tool = self.evaluate_conditional_tool(current_rule, last_function_response)
|
|
78
|
+
return [next_tool] if next_tool else []
|
|
79
|
+
|
|
80
|
+
return current_rule.children if current_rule.children else []
|
|
81
|
+
|
|
73
82
|
def is_terminal_tool(self, tool_name: str) -> bool:
|
|
74
83
|
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
|
|
75
84
|
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
|
|
@@ -78,38 +87,60 @@ class ToolRulesSolver(BaseModel):
|
|
|
78
87
|
"""Check if the tool has children tools"""
|
|
79
88
|
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
|
80
89
|
|
|
81
|
-
def
|
|
82
|
-
|
|
83
|
-
Validate
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
#
|
|
110
|
-
for
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
90
|
+
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
|
91
|
+
'''
|
|
92
|
+
Validate a conditional tool rule
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
rule (ConditionalToolRule): The conditional tool rule to validate
|
|
96
|
+
|
|
97
|
+
Raises:
|
|
98
|
+
ToolRuleValidationError: If the rule is invalid
|
|
99
|
+
'''
|
|
100
|
+
if len(rule.child_output_mapping) == 0:
|
|
101
|
+
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
|
102
|
+
return True
|
|
103
|
+
|
|
104
|
+
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
|
105
|
+
'''
|
|
106
|
+
Parse function response to determine which child tool to use based on the mapping
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
tool (ConditionalToolRule): The conditional tool rule
|
|
110
|
+
last_function_response (str): The function response in JSON format
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
str: The name of the child tool to use next
|
|
114
|
+
'''
|
|
115
|
+
json_response = json.loads(last_function_response)
|
|
116
|
+
function_output = json_response["message"]
|
|
117
|
+
|
|
118
|
+
# Try to match the function output with a mapping key
|
|
119
|
+
for key in tool.child_output_mapping:
|
|
120
|
+
|
|
121
|
+
# Convert function output to match key type for comparison
|
|
122
|
+
if isinstance(key, bool):
|
|
123
|
+
typed_output = function_output.lower() == "true"
|
|
124
|
+
elif isinstance(key, int):
|
|
125
|
+
try:
|
|
126
|
+
typed_output = int(function_output)
|
|
127
|
+
except (ValueError, TypeError):
|
|
128
|
+
continue
|
|
129
|
+
elif isinstance(key, float):
|
|
130
|
+
try:
|
|
131
|
+
typed_output = float(function_output)
|
|
132
|
+
except (ValueError, TypeError):
|
|
133
|
+
continue
|
|
134
|
+
else: # string
|
|
135
|
+
if function_output == "True" or function_output == "False":
|
|
136
|
+
typed_output = function_output.lower()
|
|
137
|
+
elif function_output == "None":
|
|
138
|
+
typed_output = None
|
|
139
|
+
else:
|
|
140
|
+
typed_output = function_output
|
|
141
|
+
|
|
142
|
+
if typed_output == key:
|
|
143
|
+
return tool.child_output_mapping[key]
|
|
144
|
+
|
|
145
|
+
# If no match found, use default
|
|
146
|
+
return tool.default_child
|
letta/orm/custom_columns.py
CHANGED
|
@@ -9,7 +9,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
|
|
9
9
|
from letta.schemas.enums import ToolRuleType
|
|
10
10
|
from letta.schemas.llm_config import LLMConfig
|
|
11
11
|
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
12
|
-
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
|
12
|
+
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class EmbeddingConfigColumn(TypeDecorator):
|
|
@@ -80,7 +80,7 @@ class ToolRulesColumn(TypeDecorator):
|
|
|
80
80
|
return value
|
|
81
81
|
|
|
82
82
|
@staticmethod
|
|
83
|
-
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
|
|
83
|
+
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
|
|
84
84
|
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
|
|
85
85
|
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
|
|
86
86
|
if rule_type == ToolRuleType.run_first:
|
|
@@ -90,6 +90,9 @@ class ToolRulesColumn(TypeDecorator):
|
|
|
90
90
|
elif rule_type == ToolRuleType.constrain_child_tools:
|
|
91
91
|
rule = ChildToolRule(**data)
|
|
92
92
|
return rule
|
|
93
|
+
elif rule_type == ToolRuleType.conditional:
|
|
94
|
+
rule = ConditionalToolRule(**data)
|
|
95
|
+
return rule
|
|
93
96
|
else:
|
|
94
97
|
raise ValueError(f"Unknown tool rule type: {rule_type}")
|
|
95
98
|
|
letta/providers.py
CHANGED
|
@@ -482,7 +482,8 @@ class GoogleAIProvider(Provider):
|
|
|
482
482
|
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
|
|
483
483
|
|
|
484
484
|
# TODO remove manual filtering for gemini-pro
|
|
485
|
-
|
|
485
|
+
# Add support for all gemini models
|
|
486
|
+
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
|
|
486
487
|
|
|
487
488
|
configs = []
|
|
488
489
|
for model in model_options:
|
letta/schemas/enums.py
CHANGED
|
@@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
|
|
|
45
45
|
run_first = "InitToolRule"
|
|
46
46
|
exit_loop = "TerminalToolRule" # reasoning loop should exit
|
|
47
47
|
continue_loop = "continue_loop" # reasoning loop should continue
|
|
48
|
+
conditional = "conditional"
|
|
48
49
|
constrain_child_tools = "ToolRule"
|
|
49
50
|
require_parent_tools = "require_parent_tools"
|
letta/schemas/letta_message.py
CHANGED
|
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
9
9
|
|
|
10
10
|
class LettaMessage(BaseModel):
|
|
11
11
|
"""
|
|
12
|
-
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue,
|
|
12
|
+
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.
|
|
13
13
|
|
|
14
14
|
Attributes:
|
|
15
15
|
id (str): The ID of the message
|
|
@@ -60,32 +60,32 @@ class UserMessage(LettaMessage):
|
|
|
60
60
|
message: str
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
class
|
|
63
|
+
class ReasoningMessage(LettaMessage):
|
|
64
64
|
"""
|
|
65
|
-
Representation of an agent's internal
|
|
65
|
+
Representation of an agent's internal reasoning.
|
|
66
66
|
|
|
67
67
|
Attributes:
|
|
68
|
-
|
|
68
|
+
reasoning (str): The internal reasoning of the agent
|
|
69
69
|
id (str): The ID of the message
|
|
70
70
|
date (datetime): The date the message was created in ISO format
|
|
71
71
|
"""
|
|
72
72
|
|
|
73
|
-
message_type: Literal["
|
|
74
|
-
|
|
73
|
+
message_type: Literal["reasoning_message"] = "reasoning_message"
|
|
74
|
+
reasoning: str
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
class
|
|
77
|
+
class ToolCall(BaseModel):
|
|
78
78
|
|
|
79
79
|
name: str
|
|
80
80
|
arguments: str
|
|
81
|
-
|
|
81
|
+
tool_call_id: str
|
|
82
82
|
|
|
83
83
|
|
|
84
|
-
class
|
|
84
|
+
class ToolCallDelta(BaseModel):
|
|
85
85
|
|
|
86
86
|
name: Optional[str]
|
|
87
87
|
arguments: Optional[str]
|
|
88
|
-
|
|
88
|
+
tool_call_id: Optional[str]
|
|
89
89
|
|
|
90
90
|
# NOTE: this is a workaround to exclude None values from the JSON dump,
|
|
91
91
|
# since the OpenAI style of returning chunks doesn't include keys with null values
|
|
@@ -97,50 +97,84 @@ class FunctionCallDelta(BaseModel):
|
|
|
97
97
|
return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)
|
|
98
98
|
|
|
99
99
|
|
|
100
|
-
class
|
|
100
|
+
class ToolCallMessage(LettaMessage):
|
|
101
101
|
"""
|
|
102
|
-
A message representing a request to call a
|
|
102
|
+
A message representing a request to call a tool (generated by the LLM to trigger tool execution).
|
|
103
103
|
|
|
104
104
|
Attributes:
|
|
105
|
-
|
|
105
|
+
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
|
106
106
|
id (str): The ID of the message
|
|
107
107
|
date (datetime): The date the message was created in ISO format
|
|
108
108
|
"""
|
|
109
109
|
|
|
110
|
-
message_type: Literal["
|
|
111
|
-
|
|
110
|
+
message_type: Literal["tool_call_message"] = "tool_call_message"
|
|
111
|
+
tool_call: Union[ToolCall, ToolCallDelta]
|
|
112
112
|
|
|
113
|
-
# NOTE: this is required for the
|
|
113
|
+
# NOTE: this is required for the ToolCallDelta exclude_none to work correctly
|
|
114
114
|
def model_dump(self, *args, **kwargs):
|
|
115
115
|
kwargs["exclude_none"] = True
|
|
116
116
|
data = super().model_dump(*args, **kwargs)
|
|
117
|
-
if isinstance(data["
|
|
118
|
-
data["
|
|
117
|
+
if isinstance(data["tool_call"], dict):
|
|
118
|
+
data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None}
|
|
119
119
|
return data
|
|
120
120
|
|
|
121
121
|
class Config:
|
|
122
122
|
json_encoders = {
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
|
|
124
|
+
ToolCall: lambda v: v.model_dump(exclude_none=True),
|
|
125
125
|
}
|
|
126
126
|
|
|
127
|
-
# NOTE: this is required to cast dicts into
|
|
127
|
+
# NOTE: this is required to cast dicts into ToolCallMessage objects
|
|
128
128
|
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
|
|
129
|
-
# (instead of properly casting to
|
|
130
|
-
@field_validator("
|
|
129
|
+
# (instead of properly casting to ToolCallDelta instead of ToolCall)
|
|
130
|
+
@field_validator("tool_call", mode="before")
|
|
131
131
|
@classmethod
|
|
132
|
-
def
|
|
132
|
+
def validate_tool_call(cls, v):
|
|
133
133
|
if isinstance(v, dict):
|
|
134
|
-
if "name" in v and "arguments" in v and "
|
|
135
|
-
return
|
|
136
|
-
elif "name" in v or "arguments" in v or "
|
|
137
|
-
return
|
|
134
|
+
if "name" in v and "arguments" in v and "tool_call_id" in v:
|
|
135
|
+
return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
|
|
136
|
+
elif "name" in v or "arguments" in v or "tool_call_id" in v:
|
|
137
|
+
return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
|
|
138
138
|
else:
|
|
139
|
-
raise ValueError("
|
|
139
|
+
raise ValueError("tool_call must contain either 'name' or 'arguments'")
|
|
140
140
|
return v
|
|
141
141
|
|
|
142
142
|
|
|
143
|
-
class
|
|
143
|
+
class ToolReturnMessage(LettaMessage):
|
|
144
|
+
"""
|
|
145
|
+
A message representing the return value of a tool call (generated by Letta executing the requested tool).
|
|
146
|
+
|
|
147
|
+
Attributes:
|
|
148
|
+
tool_return (str): The return value of the tool
|
|
149
|
+
status (Literal["success", "error"]): The status of the tool call
|
|
150
|
+
id (str): The ID of the message
|
|
151
|
+
date (datetime): The date the message was created in ISO format
|
|
152
|
+
tool_call_id (str): A unique identifier for the tool call that generated this message
|
|
153
|
+
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation
|
|
154
|
+
stderr (Optional[List(str)]): Captured stderr from the tool invocation
|
|
155
|
+
"""
|
|
156
|
+
|
|
157
|
+
message_type: Literal["tool_return_message"] = "tool_return_message"
|
|
158
|
+
tool_return: str
|
|
159
|
+
status: Literal["success", "error"]
|
|
160
|
+
tool_call_id: str
|
|
161
|
+
stdout: Optional[List[str]] = None
|
|
162
|
+
stderr: Optional[List[str]] = None
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class AssistantMessage(LettaMessage):
|
|
169
|
+
message_type: Literal["assistant_message"] = "assistant_message"
|
|
170
|
+
assistant_message: str
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class LegacyFunctionCallMessage(LettaMessage):
|
|
174
|
+
function_call: str
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class LegacyFunctionReturn(LettaMessage):
|
|
144
178
|
"""
|
|
145
179
|
A message representing the return value of a function call (generated by Letta executing the requested function).
|
|
146
180
|
|
|
@@ -162,22 +196,24 @@ class FunctionReturn(LettaMessage):
|
|
|
162
196
|
stderr: Optional[List[str]] = None
|
|
163
197
|
|
|
164
198
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
class AssistantMessage(LettaMessage):
|
|
169
|
-
message_type: Literal["assistant_message"] = "assistant_message"
|
|
170
|
-
assistant_message: str
|
|
199
|
+
class LegacyInternalMonologue(LettaMessage):
|
|
200
|
+
"""
|
|
201
|
+
Representation of an agent's internal monologue.
|
|
171
202
|
|
|
203
|
+
Attributes:
|
|
204
|
+
internal_monologue (str): The internal monologue of the agent
|
|
205
|
+
id (str): The ID of the message
|
|
206
|
+
date (datetime): The date the message was created in ISO format
|
|
207
|
+
"""
|
|
172
208
|
|
|
173
|
-
|
|
174
|
-
|
|
209
|
+
message_type: Literal["internal_monologue"] = "internal_monologue"
|
|
210
|
+
internal_monologue: str
|
|
175
211
|
|
|
176
212
|
|
|
177
|
-
LegacyLettaMessage = Union[
|
|
213
|
+
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
|
178
214
|
|
|
179
215
|
|
|
180
216
|
LettaMessageUnion = Annotated[
|
|
181
|
-
Union[SystemMessage, UserMessage,
|
|
217
|
+
Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
|
|
182
218
|
Field(discriminator="message_type"),
|
|
183
219
|
]
|
letta/schemas/letta_response.py
CHANGED
|
@@ -40,14 +40,22 @@ class LettaResponse(BaseModel):
|
|
|
40
40
|
def get_formatted_content(msg):
|
|
41
41
|
if msg.message_type == "internal_monologue":
|
|
42
42
|
return f'<div class="content"><span class="internal-monologue">{html.escape(msg.internal_monologue)}</span></div>'
|
|
43
|
+
if msg.message_type == "reasoning_message":
|
|
44
|
+
return f'<div class="content"><span class="internal-monologue">{html.escape(msg.reasoning)}</span></div>'
|
|
43
45
|
elif msg.message_type == "function_call":
|
|
44
46
|
args = format_json(msg.function_call.arguments)
|
|
45
47
|
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
|
|
48
|
+
elif msg.message_type == "tool_call_message":
|
|
49
|
+
args = format_json(msg.tool_call.arguments)
|
|
50
|
+
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
|
|
46
51
|
elif msg.message_type == "function_return":
|
|
47
|
-
|
|
48
52
|
return_value = format_json(msg.function_return)
|
|
49
53
|
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
|
|
50
54
|
return f'<div class="content">{return_value}</div>'
|
|
55
|
+
elif msg.message_type == "tool_return_message":
|
|
56
|
+
return_value = format_json(msg.tool_return)
|
|
57
|
+
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
|
|
58
|
+
return f'<div class="content">{return_value}</div>'
|
|
51
59
|
elif msg.message_type == "user_message":
|
|
52
60
|
if is_json(msg.message):
|
|
53
61
|
return f'<div class="content">{format_json(msg.message)}</div>'
|
letta/schemas/message.py
CHANGED
|
@@ -16,10 +16,10 @@ from letta.schemas.enums import MessageRole
|
|
|
16
16
|
from letta.schemas.letta_base import OrmMetadataBase
|
|
17
17
|
from letta.schemas.letta_message import (
|
|
18
18
|
AssistantMessage,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
19
|
+
ToolCall as LettaToolCall,
|
|
20
|
+
ToolCallMessage,
|
|
21
|
+
ToolReturnMessage,
|
|
22
|
+
ReasoningMessage,
|
|
23
23
|
LettaMessage,
|
|
24
24
|
SystemMessage,
|
|
25
25
|
UserMessage,
|
|
@@ -145,10 +145,10 @@ class Message(BaseMessage):
|
|
|
145
145
|
if self.text is not None:
|
|
146
146
|
# This is type InnerThoughts
|
|
147
147
|
messages.append(
|
|
148
|
-
|
|
148
|
+
ReasoningMessage(
|
|
149
149
|
id=self.id,
|
|
150
150
|
date=self.created_at,
|
|
151
|
-
|
|
151
|
+
reasoning=self.text,
|
|
152
152
|
)
|
|
153
153
|
)
|
|
154
154
|
if self.tool_calls is not None:
|
|
@@ -172,18 +172,18 @@ class Message(BaseMessage):
|
|
|
172
172
|
)
|
|
173
173
|
else:
|
|
174
174
|
messages.append(
|
|
175
|
-
|
|
175
|
+
ToolCallMessage(
|
|
176
176
|
id=self.id,
|
|
177
177
|
date=self.created_at,
|
|
178
|
-
|
|
178
|
+
tool_call=LettaToolCall(
|
|
179
179
|
name=tool_call.function.name,
|
|
180
180
|
arguments=tool_call.function.arguments,
|
|
181
|
-
|
|
181
|
+
tool_call_id=tool_call.id,
|
|
182
182
|
),
|
|
183
183
|
)
|
|
184
184
|
)
|
|
185
185
|
elif self.role == MessageRole.tool:
|
|
186
|
-
# This is type
|
|
186
|
+
# This is type ToolReturnMessage
|
|
187
187
|
# Try to interpret the function return, recall that this is how we packaged:
|
|
188
188
|
# def package_function_response(was_success, response_string, timestamp=None):
|
|
189
189
|
# formatted_time = get_local_time() if timestamp is None else timestamp
|
|
@@ -208,12 +208,12 @@ class Message(BaseMessage):
|
|
|
208
208
|
messages.append(
|
|
209
209
|
# TODO make sure this is what the API returns
|
|
210
210
|
# function_return may not match exactly...
|
|
211
|
-
|
|
211
|
+
ToolReturnMessage(
|
|
212
212
|
id=self.id,
|
|
213
213
|
date=self.created_at,
|
|
214
|
-
|
|
214
|
+
tool_return=self.text,
|
|
215
215
|
status=status_enum,
|
|
216
|
-
|
|
216
|
+
tool_call_id=self.tool_call_id,
|
|
217
217
|
)
|
|
218
218
|
)
|
|
219
219
|
elif self.role == MessageRole.user:
|