letta-nightly 0.7.0.dev20250423003112__py3-none-any.whl → 0.7.1.dev20250423104245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +113 -81
  3. letta/agents/letta_agent.py +2 -2
  4. letta/agents/letta_agent_batch.py +38 -34
  5. letta/client/client.py +10 -2
  6. letta/constants.py +4 -3
  7. letta/functions/function_sets/multi_agent.py +1 -3
  8. letta/functions/helpers.py +3 -3
  9. letta/groups/dynamic_multi_agent.py +58 -59
  10. letta/groups/round_robin_multi_agent.py +43 -49
  11. letta/groups/sleeptime_multi_agent.py +28 -18
  12. letta/groups/supervisor_multi_agent.py +21 -20
  13. letta/helpers/converters.py +29 -0
  14. letta/helpers/message_helper.py +1 -0
  15. letta/helpers/tool_execution_helper.py +3 -3
  16. letta/orm/agent.py +8 -1
  17. letta/orm/custom_columns.py +15 -0
  18. letta/schemas/agent.py +6 -0
  19. letta/schemas/message.py +1 -0
  20. letta/schemas/response_format.py +78 -0
  21. letta/schemas/tool_execution_result.py +14 -0
  22. letta/server/rest_api/interface.py +2 -1
  23. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +1 -1
  24. letta/server/rest_api/routers/v1/agents.py +4 -4
  25. letta/server/rest_api/routers/v1/groups.py +2 -2
  26. letta/server/rest_api/routers/v1/messages.py +32 -18
  27. letta/server/server.py +24 -57
  28. letta/services/agent_manager.py +1 -0
  29. letta/services/llm_batch_manager.py +28 -26
  30. letta/services/tool_executor/tool_execution_manager.py +37 -28
  31. letta/services/tool_executor/tool_execution_sandbox.py +35 -16
  32. letta/services/tool_executor/tool_executor.py +299 -68
  33. letta/services/tool_sandbox/base.py +3 -2
  34. letta/services/tool_sandbox/e2b_sandbox.py +5 -4
  35. letta/services/tool_sandbox/local_sandbox.py +11 -6
  36. {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/METADATA +1 -1
  37. {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/RECORD +40 -38
  38. {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/LICENSE +0 -0
  39. {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/WHEEL +0 -0
  40. {letta_nightly-0.7.0.dev20250423003112.dist-info → letta_nightly-0.7.1.dev20250423104245.dist-info}/entry_points.txt +0 -0
letta/groups/supervisor_multi_agent.py CHANGED
@@ -9,7 +9,7 @@ from letta.interface import AgentInterface
 from letta.orm import User
 from letta.orm.enums import ToolType
 from letta.schemas.letta_message_content import TextContent
-from letta.schemas.message import Message, MessageCreate
+from letta.schemas.message import MessageCreate
 from letta.schemas.tool import Tool
 from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
 from letta.schemas.usage import LettaUsageStatistics
@@ -37,17 +37,18 @@ class SupervisorMultiAgent(Agent):
 
     def step(
         self,
-        messages: List[MessageCreate],
+        input_messages: List[MessageCreate],
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         put_inner_thoughts_first: bool = True,
         assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
         **kwargs,
     ) -> LettaUsageStatistics:
+        # Load settings
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
 
-        # add multi agent tool
+        # Prepare supervisor agent
        if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
            multi_agent_tool = Tool(
                name=send_message_to_all_agents_in_group.__name__,
@@ -64,7 +65,6 @@ class SupervisorMultiAgent(Agent):
            )
            self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
 
-        # override tool rules
        old_tool_rules = self.agent_state.tool_rules
        self.agent_state.tool_rules = [
            InitToolRule(
@@ -79,24 +79,25 @@ class SupervisorMultiAgent(Agent):
            ),
        ]
 
-        supervisor_messages = [
-            Message(
-                agent_id=self.agent_state.id,
-                role="user",
-                content=[TextContent(text=message.content)],
-                name=None,
-                model=None,
-                tool_calls=None,
-                tool_call_id=None,
-                group_id=self.group_id,
-                otid=message.otid,
-            )
-            for message in messages
-        ]
+        # Prepare new messages
+        new_messages = []
+        for message in input_messages:
+            if isinstance(message.content, str):
+                message.content = [TextContent(text=message.content)]
+            message.group_id = self.group_id
+            new_messages.append(message)
+
        try:
-            supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
+            # Load supervisor agent
+            supervisor_agent = Agent(
+                agent_state=self.agent_state,
+                interface=self.interface,
+                user=self.user,
+            )
+
+            # Perform supervisor step
            usage_stats = supervisor_agent.step(
-                messages=supervisor_messages,
+                input_messages=new_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
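
The supervisor's `step()` now takes `input_messages` and normalizes the incoming `MessageCreate` objects in place, rather than copying their fields into new `Message` objects. A minimal sketch of that normalization in isolation; the dataclasses below are simplified stand-ins for the letta schemas, not the packaged models:

```python
# Simplified stand-ins mirroring letta's MessageCreate/TextContent shapes.
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class TextContent:
    text: str


@dataclass
class MessageCreate:
    role: str
    content: Union[str, List[TextContent]]
    group_id: Optional[str] = None


def normalize(input_messages: List[MessageCreate], group_id: str) -> List[MessageCreate]:
    new_messages = []
    for message in input_messages:
        # Coerce plain-string content into a list of content parts
        if isinstance(message.content, str):
            message.content = [TextContent(text=message.content)]
        # Stamp the multi-agent group id onto each message
        message.group_id = group_id
        new_messages.append(message)
    return new_messages


print(normalize([MessageCreate(role="user", content="hi")], "group-123"))
```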
letta/helpers/converters.py CHANGED
@@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
 )
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import ToolReturn
+from letta.schemas.response_format import (
+    JsonObjectResponseFormat,
+    JsonSchemaResponseFormat,
+    ResponseFormatType,
+    ResponseFormatUnion,
+    TextResponseFormat,
+)
 from letta.schemas.tool_rule import (
     ChildToolRule,
     ConditionalToolRule,
@@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepState]:
        return None
 
    return AgentStepState(**data)
+
+
+# --------------------------
+# Response Format Serialization
+# --------------------------
+
+
+def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
+    if not response_format:
+        return None
+    return response_format.model_dump(mode="json")
+
+
+def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
+    if not data:
+        return None
+    if data["type"] == ResponseFormatType.text:
+        return TextResponseFormat(**data)
+    if data["type"] == ResponseFormatType.json_schema:
+        return JsonSchemaResponseFormat(**data)
+    if data["type"] == ResponseFormatType.json_object:
+        return JsonObjectResponseFormat(**data)
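
The new converter pair round-trips a response format through its JSON form, dispatching on the `type` discriminator when deserializing. A quick usage sketch, assuming this wheel is installed:

```python
from letta.helpers.converters import (
    deserialize_response_format,
    serialize_response_format,
)
from letta.schemas.response_format import JsonObjectResponseFormat

rf = JsonObjectResponseFormat()
data = serialize_response_format(rf)  # {"type": "json_object"}

# Deserialization dispatches on data["type"] back to the concrete subclass
assert deserialize_response_format(data) == rf
assert deserialize_response_format(None) is None
```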
letta/helpers/message_helper.py CHANGED
@@ -40,4 +40,5 @@ def prepare_input_message_create(
        tool_call_id=None,
        otid=message.otid,
        sender_id=message.sender_id,
+        group_id=message.group_id,
    )
letta/helpers/tool_execution_helper.py CHANGED
@@ -160,12 +160,12 @@ def execute_external_tool(
        else:
            agent_state_copy = None
 
-        sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
-        function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
+        tool_execution_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
+        function_response, updated_agent_state = tool_execution_result.func_return, tool_execution_result.agent_state
        # TODO: Bring this back
        # if allow_agent_state_modifications and updated_agent_state is not None:
        #     self.update_memory_if_changed(updated_agent_state.memory)
-        return function_response, sandbox_run_result
+        return function_response, tool_execution_result
    except Exception as e:
        # Need to catch the error here, or else truncation won't happen
        # TODO: modify to function execution error
letta/orm/agent.py CHANGED
@@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from letta.orm.block import Block
-from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
+from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
 from letta.orm.identity import Identity
 from letta.orm.mixins import OrganizationMixin
 from letta.orm.organization import Organization
@@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import Memory
+from letta.schemas.response_format import ResponseFormatUnion
 from letta.schemas.tool_rule import ToolRule
 
 if TYPE_CHECKING:
@@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
    # This is dangerously flexible with the JSON type
    message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
 
+    # Response Format
+    response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
+        ResponseFormatColumn, nullable=True, doc="The response format for the agent."
+    )
+
    # Metadata and configs
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
@@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
        "multi_agent_group": None,
        "tool_exec_environment_variables": [],
        "enable_sleeptime": None,
+        "response_format": self.response_format,
    }
 
    # Optional fields: only included if requested
letta/orm/custom_columns.py CHANGED
@@ -9,6 +9,7 @@ from letta.helpers.converters import (
     deserialize_llm_config,
     deserialize_message_content,
     deserialize_poll_batch_response,
+    deserialize_response_format,
     deserialize_tool_calls,
     deserialize_tool_returns,
     deserialize_tool_rules,
@@ -20,6 +21,7 @@ from letta.helpers.converters import (
     serialize_llm_config,
     serialize_message_content,
     serialize_poll_batch_response,
+    serialize_response_format,
     serialize_tool_calls,
     serialize_tool_returns,
     serialize_tool_rules,
@@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):
 
    def process_result_value(self, value, dialect):
        return deserialize_agent_step_state(value)
+
+
+class ResponseFormatColumn(TypeDecorator):
+    """Custom SQLAlchemy column type for storing a ResponseFormat as JSON."""
+
+    impl = JSON
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        return serialize_response_format(value)
+
+    def process_result_value(self, value, dialect):
+        return deserialize_response_format(value)
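
`ResponseFormatColumn` follows the same `TypeDecorator` pattern as the other custom columns: a pydantic model in Python, plain JSON in the database. A self-contained sketch of the pattern, assuming SQLAlchemy 2.x and pydantic v2; the table and model below are illustrative, not letta's:

```python
from pydantic import BaseModel
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.types import JSON, TypeDecorator


class ResponseFormat(BaseModel):
    type: str = "text"


class ResponseFormatColumn(TypeDecorator):
    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Model -> plain dict on the way into the database
        return value.model_dump(mode="json") if value else None

    def process_result_value(self, value, dialect):
        # Plain dict -> model on the way out
        return ResponseFormat(**value) if value else None


Base = declarative_base()


class Agent(Base):
    __tablename__ = "agents"
    id = Column(Integer, primary_key=True)
    response_format = Column(ResponseFormatColumn, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Agent(id=1, response_format=ResponseFormat(type="json_object")))
    session.commit()
    assert session.get(Agent, 1).response_format.type == "json_object"
```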
letta/schemas/agent.py CHANGED
@@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import Memory
 from letta.schemas.message import Message, MessageCreate
 from letta.schemas.openai.chat_completion_response import UsageStatistics
+from letta.schemas.response_format import ResponseFormatUnion
 from letta.schemas.source import Source
 from letta.schemas.tool import Tool
 from letta.schemas.tool_rule import ToolRule
@@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
    # llm information
    llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
    embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
+    response_format: Optional[ResponseFormatUnion] = Field(
+        None, description="The response format used by the agent when returning from `send_message`."
+    )
 
    # This is an object representing the in-process state of a running `Agent`
    # Fields in this object can theoretically be edited by tools, and will be persisted by the ORM
@@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True):
        description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
    )
    enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
+    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
 
    @field_validator("name")
    @classmethod
@@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
        None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
    )
    enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
+    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
 
    class Config:
        extra = "ignore"  # Ignores extra fields
letta/schemas/message.py CHANGED
@@ -82,6 +82,7 @@ class MessageCreate(BaseModel):
    name: Optional[str] = Field(None, description="The name of the participant.")
    otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
    sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
+    group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
 
    def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
        data = super().model_dump(**kwargs)
letta/schemas/response_format.py ADDED
@@ -0,0 +1,78 @@
+from enum import Enum
+from typing import Annotated, Any, Dict, Literal, Union
+
+from pydantic import BaseModel, Field, validator
+
+
+class ResponseFormatType(str, Enum):
+    """Enum defining the possible response format types."""
+
+    text = "text"
+    json_schema = "json_schema"
+    json_object = "json_object"
+
+
+class ResponseFormat(BaseModel):
+    """Base class for all response formats."""
+
+    type: ResponseFormatType = Field(
+        ...,
+        description="The type of the response format.",
+        # why use this?
+        example=ResponseFormatType.text,
+    )
+
+
+# ---------------------
+# Response Format Types
+# ---------------------
+
+# SQLAlchemy type for database mapping
+ResponseFormatDict = Dict[str, Any]
+
+
+class TextResponseFormat(ResponseFormat):
+    """Response format for plain text responses."""
+
+    type: Literal[ResponseFormatType.text] = Field(
+        ResponseFormatType.text,
+        description="The type of the response format.",
+    )
+
+
+class JsonSchemaResponseFormat(ResponseFormat):
+    """Response format for JSON schema-based responses."""
+
+    type: Literal[ResponseFormatType.json_schema] = Field(
+        ResponseFormatType.json_schema,
+        description="The type of the response format.",
+    )
+    json_schema: Dict[str, Any] = Field(
+        ...,
+        description="The JSON schema of the response.",
+    )
+
+    @validator("json_schema")
+    def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
+        """Validate that the provided schema is a valid JSON schema."""
+        if not isinstance(v, dict):
+            raise ValueError("JSON schema must be a dictionary")
+        if "schema" not in v:
+            raise ValueError("JSON schema should include a schema property")
+        return v
+
+
+class JsonObjectResponseFormat(ResponseFormat):
+    """Response format for JSON object responses."""
+
+    type: Literal[ResponseFormatType.json_object] = Field(
+        ResponseFormatType.json_object,
+        description="The type of the response format.",
+    )
+
+
+# Pydantic type for validation
+ResponseFormatUnion = Annotated[
+    Union[TextResponseFormat | JsonSchemaResponseFormat | JsonObjectResponseFormat],
+    Field(discriminator="type"),
+]
letta/schemas/tool_execution_result.py ADDED
@@ -0,0 +1,14 @@
+from typing import Any, List, Literal, Optional
+
+from pydantic import BaseModel, Field
+
+from letta.schemas.agent import AgentState
+
+
+class ToolExecutionResult(BaseModel):
+    status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
+    func_return: Optional[Any] = Field(None, description="The function return object")
+    agent_state: Optional[AgentState] = Field(None, description="The agent state")
+    stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
+    stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
+    sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
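
This schema is what `ToolExecutionSandbox.run()` now returns (hence the `tool_execution_result` rename above). Downstream callers branch on `status` and unpack `func_return`, as in this sketch, assuming this wheel is installed; the `handle` helper is illustrative:

```python
from letta.schemas.tool_execution_result import ToolExecutionResult


def handle(result: ToolExecutionResult) -> str:
    # func_return carries the tool's return object; agent_state is only
    # populated when the sandboxed tool mutated agent state.
    if result.status == "error":
        raise RuntimeError(f"tool failed: {result.stderr}")
    return str(result.func_return)


print(handle(ToolExecutionResult(status="success", func_return=42)))  # "42"
```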
letta/server/rest_api/interface.py CHANGED
@@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
                and function_call.function.name == self.assistant_message_tool_name
                and self.assistant_message_tool_kwarg in func_args
            ):
+                # Coerce content to `str` in cases where it's JSON, due to `response_format` being JSON
                processed_chunk = AssistantMessage(
                    id=msg_obj.id,
                    date=msg_obj.created_at,
-                    content=func_args[self.assistant_message_tool_kwarg],
+                    content=str(func_args[self.assistant_message_tool_kwarg]),
                    name=msg_obj.name,
                    otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
                )
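
The `str()` coercion guards against `func_args` arriving with the assistant-message kwarg already parsed into a dict when a JSON `response_format` is in effect. A minimal illustration with hypothetical values:

```python
# With a JSON response_format, the assistant-message kwarg can be a
# parsed dict rather than a string.
func_args = {"message": {"answer": 42}}  # parsed JSON payload
content = str(func_args["message"])      # coerced: "{'answer': 42}"
assert isinstance(content, str)
```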
letta/server/rest_api/routers/openai/chat_completions/chat_completions.py CHANGED
@@ -111,7 +111,7 @@ async def send_message_to_agent_chat_completions(
        server.send_messages,
        actor=actor,
        agent_id=letta_agent.agent_state.id,
-        messages=messages,
+        input_messages=messages,
        interface=streaming_interface,
        put_inner_thoughts_first=False,
    )
letta/server/rest_api/routers/v1/agents.py CHANGED
@@ -412,7 +412,7 @@ def list_blocks(
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    try:
-        agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
+        agent = server.agent_manager.get_agent_by_id(agent_id, actor)
        return agent.memory.blocks
    except NoResultFound as e:
        raise HTTPException(status_code=404, detail=str(e))
@@ -640,7 +640,7 @@ async def send_message(
    result = await server.send_message_to_agent(
        agent_id=agent_id,
        actor=actor,
-        messages=request.messages,
+        input_messages=request.messages,
        stream_steps=False,
        stream_tokens=False,
        # Support for AssistantMessage
@@ -703,7 +703,7 @@ async def send_message_streaming(
    result = await server.send_message_to_agent(
        agent_id=agent_id,
        actor=actor,
-        messages=request.messages,
+        input_messages=request.messages,
        stream_steps=True,
        stream_tokens=request.stream_tokens,
        # Support for AssistantMessage
@@ -730,7 +730,7 @@ async def process_message_background(
    result = await server.send_message_to_agent(
        agent_id=agent_id,
        actor=actor,
-        messages=messages,
+        input_messages=messages,
        stream_steps=False,  # NOTE(matt)
        stream_tokens=False,
        use_assistant_message=use_assistant_message,
letta/server/rest_api/routers/v1/groups.py CHANGED
@@ -128,7 +128,7 @@ async def send_group_message(
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
-        messages=request.messages,
+        input_messages=request.messages,
        stream_steps=False,
        stream_tokens=False,
        # Support for AssistantMessage
@@ -167,7 +167,7 @@ async def send_group_message_streaming(
    result = await server.send_group_message_to_agent(
        group_id=group_id,
        actor=actor,
-        messages=request.messages,
+        input_messages=request.messages,
        stream_steps=True,
        stream_tokens=request.stream_tokens,
        # Support for AssistantMessage
letta/server/rest_api/routers/v1/messages.py CHANGED
@@ -7,7 +7,7 @@ from starlette.requests import Request
 from letta.agents.letta_agent_batch import LettaAgentBatch
 from letta.log import get_logger
 from letta.orm.errors import NoResultFound
-from letta.schemas.job import BatchJob, JobStatus, JobType
+from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
 from letta.schemas.letta_request import CreateBatch
 from letta.server.rest_api.utils import get_letta_server
 from letta.server.server import SyncServer
@@ -43,18 +43,18 @@ async def create_messages_batch(
    if length > max_bytes:
        raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
 
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+    batch_job = BatchJob(
+        user_id=actor.id,
+        status=JobStatus.running,
+        metadata={
+            "job_type": "batch_messages",
+        },
+        callback_url=str(payload.callback_url),
+    )
+
    try:
-        actor = server.user_manager.get_user_or_default(user_id=actor_id)
-
-        # Create a new job
-        batch_job = BatchJob(
-            user_id=actor.id,
-            status=JobStatus.created,
-            metadata={
-                "job_type": "batch_messages",
-            },
-            callback_url=str(payload.callback_url),
-        )
+        batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
 
        # create the batch runner
        batch_runner = LettaAgentBatch(
@@ -67,14 +67,17 @@ async def create_messages_batch(
            job_manager=server.job_manager,
            actor=actor,
        )
-        llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
+        await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
 
        # TODO: update run metadata
-        batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
-    except Exception:
+    except Exception as e:
        import traceback
 
+        print("Error creating batch job", e)
        traceback.print_exc()
+
+        # mark job as failed
+        server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
        raise
    return batch_job
 
@@ -125,8 +128,19 @@ async def cancel_batch_run(
 
    try:
        job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
-        job.status = JobStatus.cancelled
-        server.job_manager.update_job_by_id(job_id=job, job=job)
-        # TODO: actually cancel it
+        job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
+
+        # Get related llm batch jobs
+        llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
+        for llm_batch_job in llm_batch_jobs:
+            if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
+                # TODO: Extend to providers beyond anthropic
+                # TODO: For now, we only support anthropic
+                # Cancel the job
+                anthropic_batch_id = llm_batch_job.create_batch_response.id
+                await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
+
+            # Update all the batch_job statuses
+            server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")
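
The reordering in `create_messages_batch` persists the job (now `running` rather than `created`) before the batch is submitted, so a failure during submission can be recorded against an existing job id. A hedged sketch of that contract, using illustrative stand-ins rather than letta's actual managers:

```python
# Illustrative stand-ins; not letta's JobManager or batch runner.
class JobManager:
    def __init__(self):
        self.jobs = {}

    def create_job(self, job_id: str, status: str) -> str:
        self.jobs[job_id] = status  # persisted before any batch work runs
        return job_id

    def update_job(self, job_id: str, status: str) -> None:
        self.jobs[job_id] = status


def create_messages_batch(manager: JobManager, run_batch) -> str:
    job_id = manager.create_job("job-1", status="running")
    try:
        run_batch(job_id)
    except Exception:
        # The job row already exists, so the failure is visible to pollers
        manager.update_job(job_id, status="failed")
        raise
    return job_id


manager = JobManager()
try:
    create_messages_batch(manager, lambda job_id: 1 / 0)  # simulated failure
except ZeroDivisionError:
    pass
print(manager.jobs)  # {'job-1': 'failed'}
```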