letta-nightly 0.7.0.dev20250423003112__py3-none-any.whl → 0.7.1.dev20250423104245__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- letta/__init__.py +1 -1
- letta/agent.py +113 -81
- letta/agents/letta_agent.py +2 -2
- letta/agents/letta_agent_batch.py +38 -34
- letta/client/client.py +10 -2
- letta/constants.py +4 -3
- letta/functions/function_sets/multi_agent.py +1 -3
- letta/functions/helpers.py +3 -3
- letta/groups/dynamic_multi_agent.py +58 -59
- letta/groups/round_robin_multi_agent.py +43 -49
- letta/groups/sleeptime_multi_agent.py +28 -18
- letta/groups/supervisor_multi_agent.py +21 -20
- letta/helpers/converters.py +29 -0
- letta/helpers/message_helper.py +1 -0
- letta/helpers/tool_execution_helper.py +3 -3
- letta/orm/agent.py +8 -1
- letta/orm/custom_columns.py +15 -0
- letta/schemas/agent.py +6 -0
- letta/schemas/message.py +1 -0
- letta/schemas/response_format.py +78 -0
- letta/schemas/tool_execution_result.py +14 -0
- letta/server/rest_api/interface.py +2 -1
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +1 -1
- letta/server/rest_api/routers/v1/agents.py +4 -4
- letta/server/rest_api/routers/v1/groups.py +2 -2
- letta/server/rest_api/routers/v1/messages.py +32 -18
- letta/server/server.py +24 -57
- letta/services/agent_manager.py +1 -0
- letta/services/llm_batch_manager.py +28 -26
- letta/services/tool_executor/tool_execution_manager.py +37 -28
- letta/services/tool_executor/tool_execution_sandbox.py +35 -16
- letta/services/tool_executor/tool_executor.py +299 -68
- letta/services/tool_sandbox/base.py +3 -2
- letta/services/tool_sandbox/e2b_sandbox.py +5 -4
- letta/services/tool_sandbox/local_sandbox.py +11 -6
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/METADATA +1 -1
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/RECORD +40 -38
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/entry_points.txt +0 -0
letta/__init__.py
CHANGED
letta/agent.py
CHANGED
@@ -3,7 +3,7 @@ import time
 import traceback
 import warnings
 from abc import ABC, abstractmethod
-from typing import
+from typing import Dict, List, Optional, Tuple, Union

 from openai.types.beta.function_tool import FunctionTool as OpenAITool

@@ -17,6 +17,7 @@ from letta.constants import (
     LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
     LLM_MAX_TOKENS,
     REQ_HEARTBEAT_MESSAGE,
+    SEND_MESSAGE_TOOL_NAME,
 )
 from letta.errors import ContextWindowExceededError
 from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@@ -27,6 +28,7 @@ from letta.helpers import ToolRulesSolver
 from letta.helpers.composio_helpers import get_composio_api_key
 from letta.helpers.datetime_helpers import get_utc_time
 from letta.helpers.json_helpers import json_dumps, json_loads
+from letta.helpers.message_helper import prepare_input_message_create
 from letta.interface import AgentInterface
 from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
@@ -42,12 +44,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.memory import ContextWindowOverview, Memory
-from letta.schemas.message import Message, ToolReturn
+from letta.schemas.message import Message, MessageCreate, ToolReturn
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
 from letta.schemas.openai.chat_completion_response import UsageStatistics
-from letta.schemas.
+from letta.schemas.response_format import ResponseFormatType
 from letta.schemas.tool import Tool
+from letta.schemas.tool_execution_result import ToolExecutionResult
 from letta.schemas.tool_rule import TerminalToolRule
 from letta.schemas.usage import LettaUsageStatistics
 from letta.services.agent_manager import AgentManager
@@ -78,7 +81,7 @@ class BaseAgent(ABC):
     @abstractmethod
     def step(
         self,
-
+        input_messages: List[MessageCreate],
     ) -> LettaUsageStatistics:
         """
         Top-level event message handler for the agent.
@@ -255,6 +258,28 @@ class Agent(BaseAgent):
         # Return updated messages
         return messages

+    def _runtime_override_tool_json_schema(
+        self,
+        functions_list: List[Dict | None],
+    ) -> List[Dict | None]:
+        """Override the tool JSON schema at runtime for a particular tool if conditions are met."""
+
+        # Currently just injects `send_message` with a `response_format` if provided to the agent.
+        if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text:
+            for func in functions_list:
+                if func["name"] == SEND_MESSAGE_TOOL_NAME:
+                    if self.agent_state.response_format.type == ResponseFormatType.json_schema:
+                        func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"]
+                    if self.agent_state.response_format.type == ResponseFormatType.json_object:
+                        func["parameters"]["properties"]["message"] = {
+                            "type": "object",
+                            "description": "Message contents. All unicode (including emojis) are supported.",
+                            "additionalProperties": True,
+                            "properties": {},
+                        }
+                    break
+        return functions_list
+
     @trace_method
     def _get_ai_reply(
         self,
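The override above touches only the `message` parameter of the built-in `send_message` tool. A minimal sketch of the rewrite it performs, using plain dicts (the weather schema is an invented example, not taken from this package):

```python
# Illustrative sketch of what _runtime_override_tool_json_schema does for a
# json_schema response format; only the `message` property is replaced.
send_message_schema = {
    "name": "send_message",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string"}},
        "required": ["message"],
    },
}

# Example user-configured response format (shape mirrors the hunk's use of
# `response_format.json_schema["schema"]`; the schema itself is made up).
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "weather_report",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
            "required": ["city", "temp_c"],
        },
    },
}

if response_format["type"] == "json_schema":
    send_message_schema["parameters"]["properties"]["message"] = response_format["json_schema"]["schema"]

# The LLM must now pass a {city, temp_c} object to send_message instead of free text.
```

For `json_object` the hunk instead installs an open-ended object schema (`additionalProperties: True`); the stringification added later in `_handle_ai_response` keeps the `Running send_message({...})` log line printable when `message` arrives as a dict.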
@@ -268,27 +293,26 @@ class Agent(BaseAgent):
         step_count: Optional[int] = None,
         last_function_failed: bool = False,
         put_inner_thoughts_first: bool = True,
-    ) -> ChatCompletionResponse:
+    ) -> ChatCompletionResponse | None:
         """Get response from LLM API with robust retry mechanism."""
         log_telemetry(self.logger, "_get_ai_reply start")
         available_tools = set([t.name for t in self.agent_state.tools])
-        allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
-            available_tools=available_tools, last_function_response=self.last_function_response
-        )
         agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

-
-
-
-
-        )
+        # Get allowed tools or allow all if none are allowed
+        allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
+            available_tools=available_tools, last_function_response=self.last_function_response
+        ) or list(available_tools)

         # Don't allow a tool to be called if it failed last time
         if last_function_failed and self.tool_rules_solver.tool_call_history:
-
-            if not
+            allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
+            if not allowed_tool_names:
                 return None

+        allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
+        allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
+
         # For the first message, force the initial tool if one is specified
         force_tool_call = None
         if (
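Two behavioral changes are bundled into the hunk above: an empty result from the tool-rules solver now falls back to allowing every attached tool, and the tool that failed on the previous step is filtered out before the LLM call. A standalone sketch of that selection logic (the helper name is hypothetical):

```python
def select_allowed_tools(available: set[str], solver_allowed: list[str], last_failed: str | None = None) -> list[str]:
    # Empty solver output no longer means "no tools": fall back to everything attached.
    allowed = solver_allowed or sorted(available)
    # Never re-offer the tool that just failed.
    if last_failed is not None:
        allowed = [name for name in allowed if name != last_failed]
    return allowed

print(select_allowed_tools({"send_message", "web_search"}, []))
# ['send_message', 'web_search']
print(select_allowed_tools({"web_search"}, ["web_search"], last_failed="web_search"))
# [] -> _get_ai_reply returns None, hence the widened ChatCompletionResponse | None annotation
```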
@@ -418,7 +442,7 @@ class Agent(BaseAgent):
             tool_call_id = response_message.tool_calls[0].id
             assert tool_call_id is not None  # should be defined

-            # only necessary to add the
+            # only necessary to add the tool_call_id to a function call (antipattern)
             # response_message_dict = response_message.model_dump()
             # response_message_dict["tool_call_id"] = tool_call_id

@@ -513,6 +537,10 @@ class Agent(BaseAgent):
             # Failure case 3: function failed during execution
             # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
             # this is because the function/tool role message is only created once the function/tool has executed/returned
+
+            # handle cases where we return a json message
+            if "message" in function_args:
+                function_args["message"] = str(function_args.get("message", ""))
             self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
             self.chunk_index += 1
             try:
@@ -529,22 +557,23 @@ class Agent(BaseAgent):
                 },
             )

-
+            tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
+            function_response = tool_execution_result.func_return

             log_event(
                 "tool_call_ended",
                 attributes={
                     "function_response": function_response,
-                    "
+                    "tool_execution_result": tool_execution_result.model_dump(),
                 },
             )
             log_telemetry(
                 self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
             )

-            if
+            if tool_execution_result and tool_execution_result.status == "error":
                 tool_return = ToolReturn(
-                    status=
+                    status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr
                 )
                 messages = self._handle_function_error_response(
                     function_response,
@@ -598,14 +627,10 @@ class Agent(BaseAgent):
             # Step 4: check if function response is an error
             if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
                 error_msg = function_response_string
-                tool_return = (
-
-
-
-                    stderr=sandbox_run_result.stderr,
-                )
-                if sandbox_run_result
-                else None
+                tool_return = ToolReturn(
+                    status=tool_execution_result.status,
+                    stdout=tool_execution_result.stdout,
+                    stderr=tool_execution_result.stderr,
                 )
                 messages = self._handle_function_error_response(
                     error_msg,
@@ -622,14 +647,10 @@ class Agent(BaseAgent):

         # If no failures happened along the way: ...
         # Step 5: send the info on the function call and function response to GPT
-        tool_return = (
-
-
-
-            stderr=sandbox_run_result.stderr,
-        )
-        if sandbox_run_result
-        else None
+        tool_return = ToolReturn(
+            status=tool_execution_result.status,
+            stdout=tool_execution_result.stdout,
+            stderr=tool_execution_result.stderr,
         )
         messages.append(
             Message(
@@ -641,7 +662,7 @@ class Agent(BaseAgent):
                 content=[TextContent(text=function_response)],
                 tool_call_id=tool_call_id,
                 # Letta extras
-                tool_returns=[tool_return]
+                tool_returns=[tool_return],
                 group_id=group_id,
             )
         )  # extend conversation with function response
@@ -691,7 +712,7 @@ class Agent(BaseAgent):
     @trace_method
     def step(
         self,
-
+        input_messages: List[MessageCreate],
         # additional args
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
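With this change the public `step` entry point accepts `MessageCreate` inputs and converts them internally via `prepare_input_message_create`. A hedged usage sketch (assuming an already-constructed `Agent` instance named `agent`; `MessageCreate` fields beyond `role` and `content` are omitted):

```python
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate

# `agent` is assumed to be a letta.agent.Agent bound to a persisted agent state.
usage = agent.step(
    input_messages=[MessageCreate(role=MessageRole.user, content="What's on my calendar today?")],
    chaining=True,
)
print(usage.step_count, usage.total_tokens)
```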
@@ -704,7 +725,9 @@ class Agent(BaseAgent):
         # But just to be safe
         self.tool_rules_solver.clear_tool_history()

-
+        # Convert MessageCreate objects to Message objects
+        message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
+        next_input_messages = message_objects
         counter = 0
         total_usage = UsageStatistics()
         step_count = 0
@@ -715,7 +738,7 @@ class Agent(BaseAgent):
             kwargs["step_count"] = step_count
             kwargs["last_function_failed"] = function_failed
             step_response = self.inner_step(
-                messages=
+                messages=next_input_messages,
                 put_inner_thoughts_first=put_inner_thoughts_first,
                 **kwargs,
             )
@@ -745,36 +768,42 @@ class Agent(BaseAgent):
             # Chain handlers
             elif token_warning and summarizer_settings.send_memory_warning_message:
                 assert self.agent_state.created_by_id is not None
-
-
-
-
-
-
-
-
+                next_input_messages = [
+                    Message.dict_to_message(
+                        agent_id=self.agent_state.id,
+                        model=self.model,
+                        openai_message_dict={
+                            "role": "user",  # TODO: change to system?
+                            "content": get_token_limit_warning(),
+                        },
+                    ),
+                ]
                 continue  # always chain
             elif function_failed:
                 assert self.agent_state.created_by_id is not None
-
-
-
-
-
-
-
-
+                next_input_messages = [
+                    Message.dict_to_message(
+                        agent_id=self.agent_state.id,
+                        model=self.model,
+                        openai_message_dict={
+                            "role": "user",  # TODO: change to system?
+                            "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
+                        },
+                    )
+                ]
                 continue  # always chain
             elif heartbeat_request:
                 assert self.agent_state.created_by_id is not None
-
-
-
-
-
-
-
-
+                next_input_messages = [
+                    Message.dict_to_message(
+                        agent_id=self.agent_state.id,
+                        model=self.model,
+                        openai_message_dict={
+                            "role": "user",  # TODO: change to system?
+                            "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
+                        },
+                    )
+                ]
                 continue  # always chain
             # Letta no-op / yield
             else:
@@ -788,7 +817,7 @@ class Agent(BaseAgent):

     def inner_step(
         self,
-        messages:
+        messages: List[Message],
         first_message: bool = False,
         first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
         skip_verify: bool = False,
@@ -814,11 +843,8 @@ class Agent(BaseAgent):
             self.update_memory_if_changed(current_persisted_memory)

             # Step 1: add user message
-            if isinstance(messages, Message):
-                messages = [messages]
-
             if not all(isinstance(m, Message) for m in messages):
-                raise ValueError(f"messages should be a
+                raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}")

             in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
             input_message_sequence = in_context_messages + messages
@@ -1229,9 +1255,7 @@ class Agent(BaseAgent):
         return context_window_breakdown.context_window_size_current

     # TODO: Refactor into separate class v.s. large if/elses here
-    def execute_tool_and_persist_state(
-        self, function_name: str, function_args: dict, target_letta_tool: Tool
-    ) -> tuple[Any, Optional[SandboxRunResult]]:
+    def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult:
         """
         Execute tool modifications and persist the state of the agent.
         Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
@@ -1293,8 +1317,10 @@ class Agent(BaseAgent):
             )

             function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
-
-
+            return ToolExecutionResult(
+                status="error" if is_error else "success",
+                func_return=function_response,
+            )
         else:
             try:
                 # Parse the source code to extract function annotations
@@ -1311,23 +1337,29 @@ class Agent(BaseAgent):
                 agent_state_copy.tools = []
                 agent_state_copy.tool_rules = []

-
+                tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
                     agent_state=agent_state_copy
                 )
-                function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
                 assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
-                if
-                self.update_memory_if_changed(
-                return
+                if tool_execution_result.agent_state is not None:
+                    self.update_memory_if_changed(tool_execution_result.agent_state.memory)
+                return tool_execution_result
             except Exception as e:
                 # Need to catch error here, or else trunction wont happen
                 # TODO: modify to function execution error
                 function_response = get_friendly_error_msg(
                     function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
                 )
-                return
+                return ToolExecutionResult(
+                    status="error",
+                    func_return=function_response,
+                    stderr=[traceback.format_exc()],
+                )

-        return
+        return ToolExecutionResult(
+            status="success",
+            func_return=function_response,
+        )


 def save_agent(agent: Agent):
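The `ToolExecutionResult` schema itself (letta/schemas/tool_execution_result.py, +14 lines) is not shown in this diff. From the call sites above — `.status`, `.func_return`, `.stdout`, `.stderr`, `.agent_state`, and `.model_dump()` — a plausible sketch of its shape is the following; defaults and any fields beyond these are assumptions:

```python
from typing import Any, List, Literal, Optional

from pydantic import BaseModel


class ToolExecutionResult(BaseModel):
    """Sketch inferred from call sites in this diff; not the package's actual definition."""

    status: Literal["success", "error"]  # compared against "error" in _handle_ai_response
    func_return: Optional[Any] = None    # the tool's return value, surfaced as the function response
    agent_state: Optional[Any] = None    # sandbox runs can hand back an updated agent state
    stdout: Optional[List[str]] = None   # captured stdout from sandboxed execution
    stderr: Optional[List[str]] = None   # e.g. [traceback.format_exc()] on failure
```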
letta/agents/letta_agent.py
CHANGED
@@ -324,11 +324,11 @@ class LettaAgent(BaseAgent):
             tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
             # TODO: Integrate sandbox result
             log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
-
+            tool_execution_result = await tool_execution_manager.execute_tool_async(
                 function_name=tool_name, function_args=tool_args, tool=target_tool
             )
             log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
-            return
+            return tool_execution_result.func_return, True
         except Exception as e:
             return f"Failed to call tool. Error: {e}", False

letta/agents/letta_agent_batch.py
CHANGED
@@ -37,6 +37,7 @@ from letta.services.passage_manager import PassageManager
 from letta.services.sandbox_config_manager import SandboxConfigManager
 from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
 from letta.settings import tool_settings
+from letta.tracing import log_event, trace_method
 from letta.utils import united_diff

 logger = get_logger(__name__)
@@ -82,12 +83,12 @@ async def execute_tool_wrapper(params: ToolExecutionParams):
             sandbox_config=params.sbx_config,
             sandbox_env_vars=params.sbx_env_vars,
         )
-
+        tool_execution_result = await mgr.execute_tool_async(
             function_name=params.tool_call_name,
             function_args=params.tool_args,
             tool=target_tool,
         )
-        return params.agent_id, (
+        return params.agent_id, (tool_execution_result.func_return, True)
     except Exception as e:
         return params.agent_id, (f"Failed to call tool. Error: {e}", False)

@@ -120,55 +121,54 @@ class LettaAgentBatch:
         self.actor = actor
         self.max_steps = max_steps

+    @trace_method
     async def step_until_request(
         self,
         batch_requests: List[LettaBatchRequest],
         letta_batch_job_id: str,
         agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
     ) -> LettaBatchResponse:
-
+        log_event(name="validate_inputs")
         if not batch_requests:
             raise ValueError("Empty list of batch_requests passed in!")
         if agent_step_state_mapping is None:
             agent_step_state_mapping = {}

+        log_event(name="load_and_prepare_agents")
         agent_messages_mapping: Dict[str, List[Message]] = {}
         agent_tools_mapping: Dict[str, List[dict]] = {}
         agent_states = []
-
         for batch_request in batch_requests:
             agent_id = batch_request.agent_id
             agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
             agent_states.append(agent_state)
+
             agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
                 agent_state=agent_state, input_messages=batch_request.messages
             )

-            # TODO: Think about a cleaner way to do this?
             if agent_id not in agent_step_state_mapping:
                 agent_step_state_mapping[agent_id] = AgentStepState(
                     step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
                 )

-            agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(
-                agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
-            )
+            agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)

-
-        # TODO: But that doesn't really work in batch land
-        # TODO: @caren will factor this out
+        log_event(name="init_llm_client")
         llm_client = LLMClient.create(
             llm_config=agent_states[0].llm_config,
             put_inner_thoughts_first=True,
         )
-        agent_llm_config_mapping = {
+        agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
+
+        log_event(name="send_llm_batch_request")
         batch_response = await llm_client.send_llm_batch_request_async(
             agent_messages_mapping=agent_messages_mapping,
             agent_tools_mapping=agent_tools_mapping,
             agent_llm_config_mapping=agent_llm_config_mapping,
         )

-
+        log_event(name="persist_llm_batch_job")
         llm_batch_job = self.batch_manager.create_llm_batch_job(
             llm_provider=ProviderType.anthropic,  # TODO: Expand to more providers
             create_batch_response=batch_response,
@@ -177,24 +177,26 @@ class LettaAgentBatch:
             letta_batch_job_id=letta_batch_job_id,
         )

-
+        log_event(name="prepare_batch_items")
         batch_items = []
-        for
-
-
-
-
-
-
-
-
+        for state in agent_states:
+            step_state = agent_step_state_mapping[state.id]
+            batch_items.append(
+                LLMBatchItem(
+                    llm_batch_id=llm_batch_job.id,
+                    agent_id=state.id,
+                    llm_config=state.llm_config,
+                    request_status=JobStatus.created,
+                    step_status=AgentStepStatus.paused,
+                    step_state=step_state,
+                )
             )
-            batch_items.append(batch_item)

-        # Create all batch items at once using the bulk operation
         if batch_items:
+            log_event(name="bulk_create_batch_items")
             self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)

+        log_event(name="return_batch_response")
         return LettaBatchResponse(
             letta_batch_id=llm_batch_job.letta_batch_job_id,
             last_llm_batch_id=llm_batch_job.id,
@@ -204,27 +206,27 @@ class LettaAgentBatch:
             created_at=llm_batch_job.created_at,
         )

+    @trace_method
     async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
-
+        log_event(name="load_context")
         llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor)
         ctx = await self._collect_resume_context(llm_batch_id)

-
+        log_event(name="update_statuses")
         self._update_request_statuses(ctx.request_status_updates)

-
+        log_event(name="exec_tools")
         exec_results = await self._execute_tools(ctx)

-
+        log_event(name="persist_messages")
         msg_map = self._persist_tool_messages(exec_results, ctx)

-
+        log_event(name="mark_steps_done")
         self._mark_steps_complete(llm_batch_id, ctx.agent_ids)

-
+        log_event(name="prepare_next")
         next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
         if len(next_reqs) == 0:
-            # mark batch job as completed
             self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
             return LettaBatchResponse(
                 letta_batch_id=llm_batch_job.letta_batch_job_id,
@@ -235,15 +237,16 @@ class LettaAgentBatch:
                 created_at=llm_batch_job.created_at,
             )

-        # 7. recurse into the normal stepping pipeline
         return await self.step_until_request(
             batch_requests=next_reqs,
             letta_batch_job_id=letta_batch_id,
             agent_step_state_mapping=next_step_state,
         )

+    @trace_method
     async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
-
+        # NOTE: We only continue for items with successful results
+        batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)

         agent_ids, agent_state_map = [], {}
         provider_results, name_map, args_map, cont_map = {}, {}, {}, {}
@@ -300,6 +303,7 @@ class LettaAgentBatch:
         env = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(cfg.id, actor=self.actor, limit=100)
         return cfg, env

+    @trace_method
     async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
         sbx_cfg, sbx_env = self._build_sandbox()
         params = [
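The batch changes are largely observability: every phase of `step_until_request` and `resume_step_after_request` is now bracketed by `log_event(...)` calls, and the coroutines gain `@trace_method`. A minimal stand-in showing how such a decorator-plus-event pair fits together (illustrative only, not the `letta.tracing` implementation, and it only handles async functions):

```python
import functools
import time
from typing import Any, Awaitable, Callable


def log_event(name: str, attributes: dict | None = None) -> None:
    # Stand-in: the real log_event presumably emits to a tracer/telemetry sink.
    print(f"[event] {name} {attributes or {}}")


def trace_method(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """Time an async method and log paired start/end events."""

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        log_event(name=f"{fn.__qualname__}.start")
        try:
            return await fn(*args, **kwargs)
        finally:
            elapsed = round(time.perf_counter() - start, 3)
            log_event(name=f"{fn.__qualname__}.end", attributes={"seconds": elapsed})

    return wrapper
```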
|