letta-nightly 0.7.0.dev20250423003112__py3-none-any.whl → 0.7.1.dev20250423104245__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +113 -81
- letta/agents/letta_agent.py +2 -2
- letta/agents/letta_agent_batch.py +38 -34
- letta/client/client.py +10 -2
- letta/constants.py +4 -3
- letta/functions/function_sets/multi_agent.py +1 -3
- letta/functions/helpers.py +3 -3
- letta/groups/dynamic_multi_agent.py +58 -59
- letta/groups/round_robin_multi_agent.py +43 -49
- letta/groups/sleeptime_multi_agent.py +28 -18
- letta/groups/supervisor_multi_agent.py +21 -20
- letta/helpers/converters.py +29 -0
- letta/helpers/message_helper.py +1 -0
- letta/helpers/tool_execution_helper.py +3 -3
- letta/orm/agent.py +8 -1
- letta/orm/custom_columns.py +15 -0
- letta/schemas/agent.py +6 -0
- letta/schemas/message.py +1 -0
- letta/schemas/response_format.py +78 -0
- letta/schemas/tool_execution_result.py +14 -0
- letta/server/rest_api/interface.py +2 -1
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +1 -1
- letta/server/rest_api/routers/v1/agents.py +4 -4
- letta/server/rest_api/routers/v1/groups.py +2 -2
- letta/server/rest_api/routers/v1/messages.py +32 -18
- letta/server/server.py +24 -57
- letta/services/agent_manager.py +1 -0
- letta/services/llm_batch_manager.py +28 -26
- letta/services/tool_executor/tool_execution_manager.py +37 -28
- letta/services/tool_executor/tool_execution_sandbox.py +35 -16
- letta/services/tool_executor/tool_executor.py +299 -68
- letta/services/tool_sandbox/base.py +3 -2
- letta/services/tool_sandbox/e2b_sandbox.py +5 -4
- letta/services/tool_sandbox/local_sandbox.py +11 -6
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/METADATA +1 -1
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/RECORD +40 -38
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/entry_points.txt +0 -0
@@ -9,7 +9,7 @@ from letta.interface import AgentInterface
|
|
9
9
|
from letta.orm import User
|
10
10
|
from letta.orm.enums import ToolType
|
11
11
|
from letta.schemas.letta_message_content import TextContent
|
12
|
-
from letta.schemas.message import
|
12
|
+
from letta.schemas.message import MessageCreate
|
13
13
|
from letta.schemas.tool import Tool
|
14
14
|
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
15
15
|
from letta.schemas.usage import LettaUsageStatistics
|
@@ -37,17 +37,18 @@ class SupervisorMultiAgent(Agent):
|
|
37
37
|
|
38
38
|
def step(
|
39
39
|
self,
|
40
|
-
|
40
|
+
input_messages: List[MessageCreate],
|
41
41
|
chaining: bool = True,
|
42
42
|
max_chaining_steps: Optional[int] = None,
|
43
43
|
put_inner_thoughts_first: bool = True,
|
44
44
|
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
45
45
|
**kwargs,
|
46
46
|
) -> LettaUsageStatistics:
|
47
|
+
# Load settings
|
47
48
|
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
48
49
|
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
49
50
|
|
50
|
-
#
|
51
|
+
# Prepare supervisor agent
|
51
52
|
if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
|
52
53
|
multi_agent_tool = Tool(
|
53
54
|
name=send_message_to_all_agents_in_group.__name__,
|
@@ -64,7 +65,6 @@ class SupervisorMultiAgent(Agent):
|
|
64
65
|
)
|
65
66
|
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
|
66
67
|
|
67
|
-
# override tool rules
|
68
68
|
old_tool_rules = self.agent_state.tool_rules
|
69
69
|
self.agent_state.tool_rules = [
|
70
70
|
InitToolRule(
|
@@ -79,24 +79,25 @@ class SupervisorMultiAgent(Agent):
|
|
79
79
|
),
|
80
80
|
]
|
81
81
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
content=[TextContent(text=message.content)]
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
tool_call_id=None,
|
91
|
-
group_id=self.group_id,
|
92
|
-
otid=message.otid,
|
93
|
-
)
|
94
|
-
for message in messages
|
95
|
-
]
|
82
|
+
# Prepare new messages
|
83
|
+
new_messages = []
|
84
|
+
for message in input_messages:
|
85
|
+
if isinstance(message.content, str):
|
86
|
+
message.content = [TextContent(text=message.content)]
|
87
|
+
message.group_id = self.group_id
|
88
|
+
new_messages.append(message)
|
89
|
+
|
96
90
|
try:
|
97
|
-
|
91
|
+
# Load supervisor agent
|
92
|
+
supervisor_agent = Agent(
|
93
|
+
agent_state=self.agent_state,
|
94
|
+
interface=self.interface,
|
95
|
+
user=self.user,
|
96
|
+
)
|
97
|
+
|
98
|
+
# Perform supervisor step
|
98
99
|
usage_stats = supervisor_agent.step(
|
99
|
-
|
100
|
+
input_messages=new_messages,
|
100
101
|
chaining=chaining,
|
101
102
|
max_chaining_steps=max_chaining_steps,
|
102
103
|
stream=token_streaming,
|
letta/helpers/converters.py
CHANGED
@@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
|
|
22
22
|
)
|
23
23
|
from letta.schemas.llm_config import LLMConfig
|
24
24
|
from letta.schemas.message import ToolReturn
|
25
|
+
from letta.schemas.response_format import (
|
26
|
+
JsonObjectResponseFormat,
|
27
|
+
JsonSchemaResponseFormat,
|
28
|
+
ResponseFormatType,
|
29
|
+
ResponseFormatUnion,
|
30
|
+
TextResponseFormat,
|
31
|
+
)
|
25
32
|
from letta.schemas.tool_rule import (
|
26
33
|
ChildToolRule,
|
27
34
|
ConditionalToolRule,
|
@@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
|
|
371
378
|
return None
|
372
379
|
|
373
380
|
return AgentStepState(**data)
|
381
|
+
|
382
|
+
|
383
|
+
# --------------------------
|
384
|
+
# Response Format Serialization
|
385
|
+
# --------------------------
|
386
|
+
|
387
|
+
|
388
|
+
def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
    """Convert a response-format model into a JSON-compatible dict.

    Returns ``None`` when ``response_format`` is falsy (e.g. ``None``);
    otherwise returns ``response_format.model_dump(mode="json")``.
    """
    if response_format:
        return response_format.model_dump(mode="json")
    return None
|
392
|
+
|
393
|
+
|
394
|
+
def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
    """Rebuild a response-format model from its stored dict form.

    Returns ``None`` when ``data`` is falsy, and also when ``data["type"]``
    matches none of the known format types. A non-empty dict without a
    ``"type"`` key raises ``KeyError``.
    """
    if not data:
        return None

    format_type = data["type"]
    # Checked in order with `==`, so both enum members and their plain string
    # values select the matching model class.
    known_formats = (
        (ResponseFormatType.text, TextResponseFormat),
        (ResponseFormatType.json_schema, JsonSchemaResponseFormat),
        (ResponseFormatType.json_object, JsonObjectResponseFormat),
    )
    for tag, model_cls in known_formats:
        if format_type == tag:
            return model_cls(**data)

    # Unrecognised type: nothing to construct.
    return None
|
letta/helpers/message_helper.py
CHANGED
@@ -160,12 +160,12 @@ def execute_external_tool(
|
|
160
160
|
else:
|
161
161
|
agent_state_copy = None
|
162
162
|
|
163
|
-
|
164
|
-
function_response, updated_agent_state =
|
163
|
+
tool_execution_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
|
164
|
+
function_response, updated_agent_state = tool_execution_result.func_return, tool_execution_result.agent_state
|
165
165
|
# TODO: Bring this back
|
166
166
|
# if allow_agent_state_modifications and updated_agent_state is not None:
|
167
167
|
# self.update_memory_if_changed(updated_agent_state.memory)
|
168
|
-
return function_response,
|
168
|
+
return function_response, tool_execution_result
|
169
169
|
except Exception as e:
|
170
170
|
# Need to catch error here, or else truncation won't happen
|
171
171
|
# TODO: modify to function execution error
|
letta/orm/agent.py
CHANGED
@@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
|
|
5
5
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
6
6
|
|
7
7
|
from letta.orm.block import Block
|
8
|
-
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
|
8
|
+
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
|
9
9
|
from letta.orm.identity import Identity
|
10
10
|
from letta.orm.mixins import OrganizationMixin
|
11
11
|
from letta.orm.organization import Organization
|
@@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
|
|
15
15
|
from letta.schemas.embedding_config import EmbeddingConfig
|
16
16
|
from letta.schemas.llm_config import LLMConfig
|
17
17
|
from letta.schemas.memory import Memory
|
18
|
+
from letta.schemas.response_format import ResponseFormatUnion
|
18
19
|
from letta.schemas.tool_rule import ToolRule
|
19
20
|
|
20
21
|
if TYPE_CHECKING:
|
@@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
|
48
49
|
# This is dangerously flexible with the JSON type
|
49
50
|
message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
|
50
51
|
|
52
|
+
# Response Format
|
53
|
+
response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
|
54
|
+
ResponseFormatColumn, nullable=True, doc="The response format for the agent."
|
55
|
+
)
|
56
|
+
|
51
57
|
# Metadata and configs
|
52
58
|
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
53
59
|
llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
|
@@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
|
168
174
|
"multi_agent_group": None,
|
169
175
|
"tool_exec_environment_variables": [],
|
170
176
|
"enable_sleeptime": None,
|
177
|
+
"response_format": self.response_format,
|
171
178
|
}
|
172
179
|
|
173
180
|
# Optional fields: only included if requested
|
letta/orm/custom_columns.py
CHANGED
@@ -9,6 +9,7 @@ from letta.helpers.converters import (
|
|
9
9
|
deserialize_llm_config,
|
10
10
|
deserialize_message_content,
|
11
11
|
deserialize_poll_batch_response,
|
12
|
+
deserialize_response_format,
|
12
13
|
deserialize_tool_calls,
|
13
14
|
deserialize_tool_returns,
|
14
15
|
deserialize_tool_rules,
|
@@ -20,6 +21,7 @@ from letta.helpers.converters import (
|
|
20
21
|
serialize_llm_config,
|
21
22
|
serialize_message_content,
|
22
23
|
serialize_poll_batch_response,
|
24
|
+
serialize_response_format,
|
23
25
|
serialize_tool_calls,
|
24
26
|
serialize_tool_returns,
|
25
27
|
serialize_tool_rules,
|
@@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):
|
|
168
170
|
|
169
171
|
def process_result_value(self, value, dialect):
|
170
172
|
return deserialize_agent_step_state(value)
|
173
|
+
|
174
|
+
|
175
|
+
class ResponseFormatColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a response format as JSON.

    Values are converted with ``serialize_response_format`` on the way into
    the database and with ``deserialize_response_format`` on the way out.
    """

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Python-side value -> JSON-compatible value bound to the statement.
        return serialize_response_format(value)

    def process_result_value(self, value, dialect):
        # Stored JSON value -> Python-side value returned to the caller.
        return deserialize_response_format(value)
|
letta/schemas/agent.py
CHANGED
@@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
|
|
14
14
|
from letta.schemas.memory import Memory
|
15
15
|
from letta.schemas.message import Message, MessageCreate
|
16
16
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
17
|
+
from letta.schemas.response_format import ResponseFormatUnion
|
17
18
|
from letta.schemas.source import Source
|
18
19
|
from letta.schemas.tool import Tool
|
19
20
|
from letta.schemas.tool_rule import ToolRule
|
@@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
|
66
67
|
# llm information
|
67
68
|
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
68
69
|
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
70
|
+
response_format: Optional[ResponseFormatUnion] = Field(
|
71
|
+
None, description="The response format used by the agent when returning from `send_message`."
|
72
|
+
)
|
69
73
|
|
70
74
|
# This is an object representing the in-process state of a running `Agent`
|
71
75
|
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
|
@@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
180
184
|
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
181
185
|
)
|
182
186
|
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
|
187
|
+
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
|
183
188
|
|
184
189
|
@field_validator("name")
|
185
190
|
@classmethod
|
@@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
|
|
259
264
|
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
260
265
|
)
|
261
266
|
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
|
267
|
+
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
|
262
268
|
|
263
269
|
class Config:
|
264
270
|
extra = "ignore" # Ignores extra fields
|
letta/schemas/message.py
CHANGED
@@ -82,6 +82,7 @@ class MessageCreate(BaseModel):
|
|
82
82
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
83
83
|
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
|
84
84
|
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
|
85
|
+
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
|
85
86
|
|
86
87
|
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
87
88
|
data = super().model_dump(**kwargs)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Annotated, Any, Dict, Literal, Union
|
3
|
+
|
4
|
+
from pydantic import BaseModel, Field, validator
|
5
|
+
|
6
|
+
|
7
|
+
class ResponseFormatType(str, Enum):
    """Enum defining the possible response format types."""

    # Members are also `str` instances, so they compare equal to their values.
    text = "text"
    json_schema = "json_schema"
    json_object = "json_object"
|
13
|
+
|
14
|
+
|
15
|
+
class ResponseFormat(BaseModel):
    """Base class for all response formats."""

    # Required on the base class; each concrete subclass narrows it to a single
    # Literal value with a default.
    # NOTE: `example` is handed to Field as an extra keyword; the original
    # author left an open question about why it is used.
    type: ResponseFormatType = Field(..., description="The type of the response format.", example=ResponseFormatType.text)
|
24
|
+
|
25
|
+
|
26
|
+
# ---------------------
|
27
|
+
# Response Format Types
|
28
|
+
# ---------------------
|
29
|
+
|
30
|
+
# SQLAlchemy type for database mapping
|
31
|
+
ResponseFormatDict = Dict[str, Any]
|
32
|
+
|
33
|
+
|
34
|
+
class TextResponseFormat(ResponseFormat):
    """Response format for plain text responses."""

    # Fixed discriminator value; it is also the default, so callers may omit it.
    type: Literal[ResponseFormatType.text] = Field(ResponseFormatType.text, description="The type of the response format.")
|
41
|
+
|
42
|
+
|
43
|
+
class JsonSchemaResponseFormat(ResponseFormat):
    """Response format for JSON schema-based responses."""

    # Fixed discriminator value; it is also the default, so callers may omit it.
    type: Literal[ResponseFormatType.json_schema] = Field(
        ResponseFormatType.json_schema,
        description="The type of the response format.",
    )
    json_schema: Dict[str, Any] = Field(
        ...,
        description="The JSON schema of the response.",
    )

    @validator("json_schema")
    def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Check that the payload is a dict carrying a top-level "schema" key.

        Only the presence of the key is checked; the schema it holds is not
        validated any further.

        Returns:
            The payload unchanged.

        Raises:
            ValueError: if ``v`` is not a dict or has no "schema" key.
        """
        if not isinstance(v, dict):
            raise ValueError("JSON schema must be a dictionary")
        if "schema" not in v:
            # The message names the key that is actually tested; it previously
            # said "$schema", which is not the key this check looks for.
            raise ValueError('JSON schema should include a "schema" property')
        return v
|
63
|
+
|
64
|
+
|
65
|
+
class JsonObjectResponseFormat(ResponseFormat):
    """Response format for JSON object responses."""

    # Fixed discriminator value; it is also the default, so callers may omit it.
    type: Literal[ResponseFormatType.json_object] = Field(ResponseFormatType.json_object, description="The type of the response format.")
|
72
|
+
|
73
|
+
|
74
|
+
# Pydantic type for validation: a tagged union of the concrete response
# formats, with the `type` field as the discriminator.
# The members are listed directly in `Union[...]`; wrapping a `A | B | C`
# expression inside `Union[...]` was redundant.
ResponseFormatUnion = Annotated[
    Union[TextResponseFormat, JsonSchemaResponseFormat, JsonObjectResponseFormat],
    Field(discriminator="type"),
]
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from typing import Any, List, Literal, Optional
|
2
|
+
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
|
5
|
+
from letta.schemas.agent import AgentState
|
6
|
+
|
7
|
+
|
8
|
+
# Result envelope for a single tool invocation: outcome status, the function's
# return value, the agent state, captured stdout/stderr, and the fingerprint of
# the sandbox config. Only `status` is required; the rest default to None.
class ToolExecutionResult(BaseModel):
    status: Literal["success", "error"] = Field(
        ...,
        description="The status of the tool execution and return object",
    )
    func_return: Optional[Any] = Field(
        None,
        description="The function return object",
    )
    agent_state: Optional[AgentState] = Field(
        None,
        description="The agent state",
    )
    stdout: Optional[List[str]] = Field(
        None,
        description="Captured stdout (prints, logs) from function invocation",
    )
    stderr: Optional[List[str]] = Field(
        None,
        description="Captured stderr from the function invocation",
    )
    sandbox_config_fingerprint: Optional[str] = Field(
        None,
        description="The fingerprint of the config for the sandbox",
    )
|
@@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|
1240
1240
|
and function_call.function.name == self.assistant_message_tool_name
|
1241
1241
|
and self.assistant_message_tool_kwarg in func_args
|
1242
1242
|
):
|
1243
|
+
# Coerce content to `str` in cases where it's a JSON due to `response_format` being a JSON
|
1243
1244
|
processed_chunk = AssistantMessage(
|
1244
1245
|
id=msg_obj.id,
|
1245
1246
|
date=msg_obj.created_at,
|
1246
|
-
content=func_args[self.assistant_message_tool_kwarg],
|
1247
|
+
content=str(func_args[self.assistant_message_tool_kwarg]),
|
1247
1248
|
name=msg_obj.name,
|
1248
1249
|
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
1249
1250
|
)
|
@@ -111,7 +111,7 @@ async def send_message_to_agent_chat_completions(
|
|
111
111
|
server.send_messages,
|
112
112
|
actor=actor,
|
113
113
|
agent_id=letta_agent.agent_state.id,
|
114
|
-
|
114
|
+
input_messages=messages,
|
115
115
|
interface=streaming_interface,
|
116
116
|
put_inner_thoughts_first=False,
|
117
117
|
)
|
@@ -412,7 +412,7 @@ def list_blocks(
|
|
412
412
|
"""
|
413
413
|
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
414
414
|
try:
|
415
|
-
agent = server.agent_manager.get_agent_by_id(agent_id, actor
|
415
|
+
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
416
416
|
return agent.memory.blocks
|
417
417
|
except NoResultFound as e:
|
418
418
|
raise HTTPException(status_code=404, detail=str(e))
|
@@ -640,7 +640,7 @@ async def send_message(
|
|
640
640
|
result = await server.send_message_to_agent(
|
641
641
|
agent_id=agent_id,
|
642
642
|
actor=actor,
|
643
|
-
|
643
|
+
input_messages=request.messages,
|
644
644
|
stream_steps=False,
|
645
645
|
stream_tokens=False,
|
646
646
|
# Support for AssistantMessage
|
@@ -703,7 +703,7 @@ async def send_message_streaming(
|
|
703
703
|
result = await server.send_message_to_agent(
|
704
704
|
agent_id=agent_id,
|
705
705
|
actor=actor,
|
706
|
-
|
706
|
+
input_messages=request.messages,
|
707
707
|
stream_steps=True,
|
708
708
|
stream_tokens=request.stream_tokens,
|
709
709
|
# Support for AssistantMessage
|
@@ -730,7 +730,7 @@ async def process_message_background(
|
|
730
730
|
result = await server.send_message_to_agent(
|
731
731
|
agent_id=agent_id,
|
732
732
|
actor=actor,
|
733
|
-
|
733
|
+
input_messages=messages,
|
734
734
|
stream_steps=False, # NOTE(matt)
|
735
735
|
stream_tokens=False,
|
736
736
|
use_assistant_message=use_assistant_message,
|
@@ -128,7 +128,7 @@ async def send_group_message(
|
|
128
128
|
result = await server.send_group_message_to_agent(
|
129
129
|
group_id=group_id,
|
130
130
|
actor=actor,
|
131
|
-
|
131
|
+
input_messages=request.messages,
|
132
132
|
stream_steps=False,
|
133
133
|
stream_tokens=False,
|
134
134
|
# Support for AssistantMessage
|
@@ -167,7 +167,7 @@ async def send_group_message_streaming(
|
|
167
167
|
result = await server.send_group_message_to_agent(
|
168
168
|
group_id=group_id,
|
169
169
|
actor=actor,
|
170
|
-
|
170
|
+
input_messages=request.messages,
|
171
171
|
stream_steps=True,
|
172
172
|
stream_tokens=request.stream_tokens,
|
173
173
|
# Support for AssistantMessage
|
@@ -7,7 +7,7 @@ from starlette.requests import Request
|
|
7
7
|
from letta.agents.letta_agent_batch import LettaAgentBatch
|
8
8
|
from letta.log import get_logger
|
9
9
|
from letta.orm.errors import NoResultFound
|
10
|
-
from letta.schemas.job import BatchJob, JobStatus, JobType
|
10
|
+
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
|
11
11
|
from letta.schemas.letta_request import CreateBatch
|
12
12
|
from letta.server.rest_api.utils import get_letta_server
|
13
13
|
from letta.server.server import SyncServer
|
@@ -43,18 +43,18 @@ async def create_messages_batch(
|
|
43
43
|
if length > max_bytes:
|
44
44
|
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
|
45
45
|
|
46
|
+
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
47
|
+
batch_job = BatchJob(
|
48
|
+
user_id=actor.id,
|
49
|
+
status=JobStatus.running,
|
50
|
+
metadata={
|
51
|
+
"job_type": "batch_messages",
|
52
|
+
},
|
53
|
+
callback_url=str(payload.callback_url),
|
54
|
+
)
|
55
|
+
|
46
56
|
try:
|
47
|
-
|
48
|
-
|
49
|
-
# Create a new job
|
50
|
-
batch_job = BatchJob(
|
51
|
-
user_id=actor.id,
|
52
|
-
status=JobStatus.created,
|
53
|
-
metadata={
|
54
|
-
"job_type": "batch_messages",
|
55
|
-
},
|
56
|
-
callback_url=str(payload.callback_url),
|
57
|
-
)
|
57
|
+
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
|
58
58
|
|
59
59
|
# create the batch runner
|
60
60
|
batch_runner = LettaAgentBatch(
|
@@ -67,14 +67,17 @@ async def create_messages_batch(
|
|
67
67
|
job_manager=server.job_manager,
|
68
68
|
actor=actor,
|
69
69
|
)
|
70
|
-
|
70
|
+
await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
|
71
71
|
|
72
72
|
# TODO: update run metadata
|
73
|
-
|
74
|
-
except Exception:
|
73
|
+
except Exception as e:
|
75
74
|
import traceback
|
76
75
|
|
76
|
+
print("Error creating batch job", e)
|
77
77
|
traceback.print_exc()
|
78
|
+
|
79
|
+
# mark job as failed
|
80
|
+
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
|
78
81
|
raise
|
79
82
|
return batch_job
|
80
83
|
|
@@ -125,8 +128,19 @@ async def cancel_batch_run(
|
|
125
128
|
|
126
129
|
try:
|
127
130
|
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
128
|
-
job.
|
129
|
-
|
130
|
-
#
|
131
|
+
job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
132
|
+
|
133
|
+
# Get related llm batch jobs
|
134
|
+
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
|
135
|
+
for llm_batch_job in llm_batch_jobs:
|
136
|
+
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
|
137
|
+
# TODO: Extend to providers beyond anthropic
|
138
|
+
# TODO: For now, we only support anthropic
|
139
|
+
# Cancel the job
|
140
|
+
anthropic_batch_id = llm_batch_job.create_batch_response.id
|
141
|
+
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
|
142
|
+
|
143
|
+
# Update all the batch_job statuses
|
144
|
+
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
|
131
145
|
except NoResultFound:
|
132
146
|
raise HTTPException(status_code=404, detail="Run not found")
|