letta-nightly 0.6.48.dev20250406104033__py3-none-any.whl → 0.6.49.dev20250408030511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of letta-nightly might be problematic.

Files changed (87)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +47 -12
  3. letta/agents/base_agent.py +7 -4
  4. letta/agents/helpers.py +52 -0
  5. letta/agents/letta_agent.py +105 -42
  6. letta/agents/voice_agent.py +2 -2
  7. letta/constants.py +13 -1
  8. letta/errors.py +10 -3
  9. letta/functions/function_sets/base.py +65 -0
  10. letta/functions/interface.py +2 -2
  11. letta/functions/mcp_client/base_client.py +18 -1
  12. letta/{dynamic_multi_agent.py → groups/dynamic_multi_agent.py} +3 -0
  13. letta/groups/helpers.py +113 -0
  14. letta/{round_robin_multi_agent.py → groups/round_robin_multi_agent.py} +2 -0
  15. letta/groups/sleeptime_multi_agent.py +259 -0
  16. letta/{supervisor_multi_agent.py → groups/supervisor_multi_agent.py} +1 -0
  17. letta/helpers/converters.py +109 -7
  18. letta/helpers/message_helper.py +1 -0
  19. letta/helpers/tool_rule_solver.py +40 -23
  20. letta/interface.py +12 -5
  21. letta/interfaces/anthropic_streaming_interface.py +329 -0
  22. letta/llm_api/anthropic.py +12 -1
  23. letta/llm_api/anthropic_client.py +65 -14
  24. letta/llm_api/azure_openai.py +2 -2
  25. letta/llm_api/google_ai_client.py +13 -2
  26. letta/llm_api/google_constants.py +3 -0
  27. letta/llm_api/google_vertex_client.py +2 -2
  28. letta/llm_api/llm_api_tools.py +1 -1
  29. letta/llm_api/llm_client.py +7 -0
  30. letta/llm_api/llm_client_base.py +2 -7
  31. letta/llm_api/openai.py +7 -1
  32. letta/llm_api/openai_client.py +250 -0
  33. letta/orm/__init__.py +4 -0
  34. letta/orm/agent.py +6 -0
  35. letta/orm/block.py +32 -2
  36. letta/orm/block_history.py +46 -0
  37. letta/orm/custom_columns.py +60 -0
  38. letta/orm/enums.py +7 -0
  39. letta/orm/group.py +6 -0
  40. letta/orm/groups_blocks.py +13 -0
  41. letta/orm/llm_batch_items.py +55 -0
  42. letta/orm/llm_batch_job.py +48 -0
  43. letta/orm/message.py +7 -1
  44. letta/orm/organization.py +2 -0
  45. letta/orm/sqlalchemy_base.py +18 -15
  46. letta/prompts/system/memgpt_sleeptime_chat.txt +52 -0
  47. letta/prompts/system/sleeptime.txt +26 -0
  48. letta/schemas/agent.py +13 -1
  49. letta/schemas/enums.py +17 -2
  50. letta/schemas/group.py +14 -1
  51. letta/schemas/letta_message.py +5 -3
  52. letta/schemas/llm_batch_job.py +53 -0
  53. letta/schemas/llm_config.py +14 -4
  54. letta/schemas/message.py +44 -0
  55. letta/schemas/tool.py +3 -0
  56. letta/schemas/usage.py +1 -0
  57. letta/server/db.py +2 -0
  58. letta/server/rest_api/app.py +1 -1
  59. letta/server/rest_api/chat_completions_interface.py +8 -3
  60. letta/server/rest_api/interface.py +36 -7
  61. letta/server/rest_api/routers/v1/agents.py +53 -39
  62. letta/server/rest_api/routers/v1/runs.py +14 -2
  63. letta/server/rest_api/utils.py +15 -4
  64. letta/server/server.py +120 -71
  65. letta/services/agent_manager.py +70 -6
  66. letta/services/block_manager.py +190 -2
  67. letta/services/group_manager.py +68 -0
  68. letta/services/helpers/agent_manager_helper.py +6 -4
  69. letta/services/llm_batch_manager.py +139 -0
  70. letta/services/message_manager.py +17 -31
  71. letta/services/tool_executor/tool_execution_sandbox.py +1 -3
  72. letta/services/tool_executor/tool_executor.py +9 -20
  73. letta/services/tool_manager.py +14 -3
  74. letta/services/tool_sandbox/__init__.py +0 -0
  75. letta/services/tool_sandbox/base.py +188 -0
  76. letta/services/tool_sandbox/e2b_sandbox.py +116 -0
  77. letta/services/tool_sandbox/local_sandbox.py +221 -0
  78. letta/sleeptime_agent.py +61 -0
  79. letta/streaming_interface.py +20 -10
  80. letta/utils.py +4 -0
  81. {letta_nightly-0.6.48.dev20250406104033.dist-info → letta_nightly-0.6.49.dev20250408030511.dist-info}/METADATA +2 -2
  82. {letta_nightly-0.6.48.dev20250406104033.dist-info → letta_nightly-0.6.49.dev20250408030511.dist-info}/RECORD +85 -69
  83. letta/offline_memory_agent.py +0 -173
  84. letta/services/tool_executor/async_tool_execution_sandbox.py +0 -397
  85. {letta_nightly-0.6.48.dev20250406104033.dist-info → letta_nightly-0.6.49.dev20250408030511.dist-info}/LICENSE +0 -0
  86. {letta_nightly-0.6.48.dev20250406104033.dist-info → letta_nightly-0.6.49.dev20250408030511.dist-info}/WHEEL +0 -0
  87. {letta_nightly-0.6.48.dev20250406104033.dist-info → letta_nightly-0.6.49.dev20250408030511.dist-info}/entry_points.txt +0 -0
letta/functions/mcp_client/base_client.py

@@ -2,6 +2,7 @@ import asyncio
 from typing import List, Optional, Tuple
 
 from mcp import ClientSession
+from mcp.types import TextContent
 
 from letta.functions.mcp_client.exceptions import MCPTimeoutError
 from letta.functions.mcp_client.types import BaseServerConfig, MCPTool
@@ -60,7 +61,23 @@ class BaseMCPClient:
             result = self.loop.run_until_complete(
                 asyncio.wait_for(self.session.call_tool(tool_name, tool_args), timeout=tool_settings.mcp_execute_tool_timeout)
             )
-            return str(result.content), result.isError
+
+            parsed_content = []
+            for content_piece in result.content:
+                if isinstance(content_piece, TextContent):
+                    parsed_content.append(content_piece.text)
+                    print("parsed_content (text)", parsed_content)
+                else:
+                    parsed_content.append(str(content_piece))
+                    print("parsed_content (other)", parsed_content)
+
+            if len(parsed_content) > 0:
+                final_content = " ".join(parsed_content)
+            else:
+                # TODO move hardcoding to constants
+                final_content = "Empty response from tool"
+
+            return final_content, result.isError
         except asyncio.TimeoutError:
             logger.error(
                 f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
letta/groups/dynamic_multi_agent.py

@@ -99,6 +99,7 @@ class DynamicMultiAgent(Agent):
                     tool_calls=None,
                     tool_call_id=None,
                     group_id=self.group_id,
+                    otid=message.otid,
                 )
             )
 
@@ -125,6 +126,7 @@ class DynamicMultiAgent(Agent):
                 role="system",
                 content=message.content,
                 name=participant_agent.agent_state.name,
+                otid=message.otid,
             )
             for message in assistant_messages
         ]
@@ -271,4 +273,5 @@ class DynamicMultiAgent(Agent):
             tool_calls=None,
             tool_call_id=None,
             group_id=self.group_id,
+            otid=Message.generate_otid(),
         )
letta/groups/helpers.py (new file)

@@ -0,0 +1,113 @@
+import json
+from typing import Dict, Optional, Union
+
+from letta.agent import Agent
+from letta.functions.mcp_client.base_client import BaseMCPClient
+from letta.interface import AgentInterface
+from letta.orm.group import Group
+from letta.orm.user import User
+from letta.schemas.agent import AgentState
+from letta.schemas.group import ManagerType
+from letta.schemas.message import Message
+
+
+def load_multi_agent(
+    group: Group,
+    agent_state: Optional[AgentState],
+    actor: User,
+    interface: Union[AgentInterface, None] = None,
+    mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
+) -> Agent:
+    if len(group.agent_ids) == 0:
+        raise ValueError("Empty group: group must have at least one agent")
+
+    if not agent_state:
+        raise ValueError("Empty manager agent state: manager agent state must be provided")
+
+    match group.manager_type:
+        case ManagerType.round_robin:
+            from letta.groups.round_robin_multi_agent import RoundRobinMultiAgent
+
+            return RoundRobinMultiAgent(
+                agent_state=agent_state,
+                interface=interface,
+                user=actor,
+                group_id=group.id,
+                agent_ids=group.agent_ids,
+                description=group.description,
+                max_turns=group.max_turns,
+            )
+        case ManagerType.dynamic:
+            from letta.groups.dynamic_multi_agent import DynamicMultiAgent
+
+            return DynamicMultiAgent(
+                agent_state=agent_state,
+                interface=interface,
+                user=actor,
+                group_id=group.id,
+                agent_ids=group.agent_ids,
+                description=group.description,
+                max_turns=group.max_turns,
+                termination_token=group.termination_token,
+            )
+        case ManagerType.supervisor:
+            from letta.groups.supervisor_multi_agent import SupervisorMultiAgent
+
+            return SupervisorMultiAgent(
+                agent_state=agent_state,
+                interface=interface,
+                user=actor,
+                group_id=group.id,
+                agent_ids=group.agent_ids,
+                description=group.description,
+            )
+        case ManagerType.sleeptime:
+            if not agent_state.enable_sleeptime:
+                return Agent(
+                    agent_state=agent_state,
+                    interface=interface,
+                    user=actor,
+                    mcp_clients=mcp_clients,
+                )
+
+            from letta.groups.sleeptime_multi_agent import SleeptimeMultiAgent
+
+            return SleeptimeMultiAgent(
+                agent_state=agent_state,
+                interface=interface,
+                user=actor,
+                group_id=group.id,
+                agent_ids=group.agent_ids,
+                description=group.description,
+                sleeptime_agent_frequency=group.sleeptime_agent_frequency,
+            )
+        case _:
+            raise ValueError(f"Type {group.manager_type} is not supported.")
+
+
+def stringify_message(message: Message, use_assistant_name: bool = False) -> str | None:
+    assistant_name = message.name or "assistant" if use_assistant_name else "assistant"
+    if message.role == "user":
+        content = json.loads(message.content[0].text)
+        if content["type"] == "user_message":
+            return f"{message.name or 'user'}: {content['message']}"
+        else:
+            return None
+    elif message.role == "assistant":
+        messages = []
+        if message.tool_calls:
+            if message.tool_calls[0].function.name == "send_message":
+                messages.append(f"{assistant_name}: {json.loads(message.tool_calls[0].function.arguments)['message']}")
+            else:
+                messages.append(f"{assistant_name}: Calling tool {message.tool_calls[0].function.name}")
+        return "\n".join(messages)
+    elif message.role == "tool":
+        if message.content:
+            content = json.loads(message.content[0].text)
+            if content["message"] != "None" and content["message"] != None:
+                return f"{assistant_name}: Tool call returned {content['message']}"
+        return None
+    elif message.role == "system":
+        return None
+    else:
+        return f"{message.name or 'user'}: {message.content[0].text}"
letta/groups/round_robin_multi_agent.py

@@ -69,6 +69,7 @@ class RoundRobinMultiAgent(Agent):
                     tool_calls=None,
                     tool_call_id=None,
                     group_id=self.group_id,
+                    otid=message.otid,
                 )
             )
 
@@ -92,6 +93,7 @@ class RoundRobinMultiAgent(Agent):
                 role="system",
                 content=message.content,
                 name=message.name,
+                otid=message.otid,
             )
             for message in assistant_messages
         ]
letta/groups/sleeptime_multi_agent.py (new file)

@@ -0,0 +1,259 @@
+import asyncio
+import threading
+from datetime import datetime
+from typing import List, Optional
+
+from letta.agent import Agent, AgentState
+from letta.groups.helpers import stringify_message
+from letta.interface import AgentInterface
+from letta.orm import User
+from letta.schemas.enums import JobStatus
+from letta.schemas.job import JobUpdate
+from letta.schemas.letta_message_content import TextContent
+from letta.schemas.message import Message, MessageCreate
+from letta.schemas.run import Run
+from letta.schemas.usage import LettaUsageStatistics
+from letta.server.rest_api.interface import StreamingServerInterface
+from letta.services.group_manager import GroupManager
+from letta.services.job_manager import JobManager
+from letta.services.message_manager import MessageManager
+
+
+class SleeptimeMultiAgent(Agent):
+
+    def __init__(
+        self,
+        interface: AgentInterface,
+        agent_state: AgentState,
+        user: User,
+        # custom
+        group_id: str = "",
+        agent_ids: List[str] = [],
+        description: str = "",
+        sleeptime_agent_frequency: Optional[int] = None,
+    ):
+        super().__init__(interface, agent_state, user)
+        self.group_id = group_id
+        self.agent_ids = agent_ids
+        self.description = description
+        self.sleeptime_agent_frequency = sleeptime_agent_frequency
+        self.group_manager = GroupManager()
+        self.message_manager = MessageManager()
+        self.job_manager = JobManager()
+
+    def _run_async_in_new_thread(self, coro):
+        """Run an async coroutine in a new thread with its own event loop"""
+        result = None
+
+        def run_async():
+            nonlocal result
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                result = loop.run_until_complete(coro)
+            finally:
+                loop.close()
+                asyncio.set_event_loop(None)
+
+        thread = threading.Thread(target=run_async)
+        thread.start()
+        thread.join()
+        return result
+
+    async def _issue_background_task(
+        self,
+        participant_agent_id: str,
+        messages: List[Message],
+        chaining: bool,
+        max_chaining_steps: Optional[int],
+        token_streaming: bool,
+        metadata: Optional[dict],
+        put_inner_thoughts_first: bool,
+        last_processed_message_id: str,
+    ) -> str:
+        run = Run(
+            user_id=self.user.id,
+            status=JobStatus.created,
+            metadata={
+                "job_type": "background_agent_send_message_async",
+                "agent_id": participant_agent_id,
+            },
+        )
+        run = self.job_manager.create_job(pydantic_job=run, actor=self.user)
+
+        asyncio.create_task(
+            self._perform_background_agent_step(
+                participant_agent_id=participant_agent_id,
+                messages=messages,
+                chaining=chaining,
+                max_chaining_steps=max_chaining_steps,
+                token_streaming=token_streaming,
+                metadata=metadata,
+                put_inner_thoughts_first=put_inner_thoughts_first,
+                last_processed_message_id=last_processed_message_id,
+                run_id=run.id,
+            )
+        )
+
+        return run.id
+
+    async def _perform_background_agent_step(
+        self,
+        participant_agent_id: str,
+        messages: List[Message],
+        chaining: bool,
+        max_chaining_steps: Optional[int],
+        token_streaming: bool,
+        metadata: Optional[dict],
+        put_inner_thoughts_first: bool,
+        last_processed_message_id: str,
+        run_id: str,
+    ) -> LettaUsageStatistics:
+        try:
+            participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
+            participant_agent = Agent(
+                agent_state=participant_agent_state,
+                interface=StreamingServerInterface(),
+                user=self.user,
+            )
+
+            prior_messages = []
+            if self.sleeptime_agent_frequency:
+                try:
+                    prior_messages = self.message_manager.list_messages_for_agent(
+                        agent_id=self.agent_state.id,
+                        actor=self.user,
+                        after=last_processed_message_id,
+                        before=messages[0].id,
+                    )
+                except Exception as e:
+                    print(f"Error fetching prior messages: {str(e)}")
+                    # continue with just latest messages
+
+            transcript_summary = [stringify_message(message) for message in prior_messages + messages]
+            transcript_summary = [summary for summary in transcript_summary if summary is not None]
+            message_text = "\n".join(transcript_summary)
+
+            participant_agent_messages = [
+                Message(
+                    id=Message.generate_id(),
+                    agent_id=participant_agent.agent_state.id,
+                    role="user",
+                    content=[TextContent(text=message_text)],
+                    group_id=self.group_id,
+                )
+            ]
+            result = participant_agent.step(
+                messages=participant_agent_messages,
+                chaining=chaining,
+                max_chaining_steps=max_chaining_steps,
+                stream=token_streaming,
+                skip_verify=True,
+                metadata=metadata,
+                put_inner_thoughts_first=put_inner_thoughts_first,
+            )
+            job_update = JobUpdate(
+                status=JobStatus.completed,
+                completed_at=datetime.utcnow(),
+                metadata={
+                    "result": result.model_dump(mode="json"),
+                    "agent_id": participant_agent.agent_state.id,
+                },
+            )
+            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
+            return result
+        except Exception as e:
+            job_update = JobUpdate(
+                status=JobStatus.failed,
+                completed_at=datetime.utcnow(),
+                metadata={"error": str(e)},
+            )
+            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
+            raise
+
+    def step(
+        self,
+        messages: List[MessageCreate],
+        chaining: bool = True,
+        max_chaining_steps: Optional[int] = None,
+        put_inner_thoughts_first: bool = True,
+        **kwargs,
+    ) -> LettaUsageStatistics:
+        run_ids = []
+
+        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
+        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
+
+        messages = [
+            Message(
+                id=Message.generate_id(),
+                agent_id=self.agent_state.id,
+                role=message.role,
+                content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
+                name=message.name,
+                model=None,
+                tool_calls=None,
+                tool_call_id=None,
+                group_id=self.group_id,
+                otid=message.otid,
+            )
+            for message in messages
+        ]
+
+        try:
+            main_agent = Agent(
+                agent_state=self.agent_state,
+                interface=self.interface,
+                user=self.user,
+            )
+            usage_stats = main_agent.step(
+                messages=messages,
+                chaining=chaining,
+                max_chaining_steps=max_chaining_steps,
+                stream=token_streaming,
+                skip_verify=True,
+                metadata=metadata,
+                put_inner_thoughts_first=put_inner_thoughts_first,
+            )
+
+            turns_counter = None
+            if self.sleeptime_agent_frequency is not None and self.sleeptime_agent_frequency > 0:
+                turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)
+
+            if self.sleeptime_agent_frequency is None or (
+                turns_counter is not None and turns_counter % self.sleeptime_agent_frequency == 0
+            ):
+                last_response_messages = [message for sublist in usage_stats.steps_messages for message in sublist]
+                last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
+                    group_id=self.group_id, last_processed_message_id=last_response_messages[-1].id, actor=self.user
+                )
+                for participant_agent_id in self.agent_ids:
+                    try:
+                        run_id = self._run_async_in_new_thread(
+                            self._issue_background_task(
+                                participant_agent_id,
+                                last_response_messages,
+                                chaining,
+                                max_chaining_steps,
+                                token_streaming,
+                                metadata,
+                                put_inner_thoughts_first,
+                                last_processed_message_id,
+                            )
+                        )
+                        run_ids.append(run_id)
+
+                    except Exception as e:
+                        # Handle individual task failures
+                        print(f"Agent processing failed: {str(e)}")
+                        raise e
+
+        except Exception as e:
+            raise e
+        finally:
+            self.interface.step_yield()
+
+        self.interface.step_complete()
+
+        usage_stats.run_ids = run_ids
+        return LettaUsageStatistics(**usage_stats.model_dump())
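One pattern worth noting in this new file: step() is synchronous, but the participant (sleeptime) agents are kicked off via _issue_background_task, an async coroutine, so _run_async_in_new_thread bridges the two by running the coroutine to completion on a fresh event loop inside a worker thread. A standalone sketch of that pattern (the coroutine below is a trivial stand-in, not the real background step):

import asyncio
import threading


def run_coro_in_new_thread(coro):
    # Run a coroutine on a fresh event loop in a worker thread and block until it
    # finishes, following the same shape as SleeptimeMultiAgent._run_async_in_new_thread.
    result = None

    def runner():
        nonlocal result
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(coro)
        finally:
            loop.close()
            asyncio.set_event_loop(None)

    thread = threading.Thread(target=runner)
    thread.start()
    thread.join()
    return result


async def fake_issue_task() -> str:
    await asyncio.sleep(0)  # stand-in for creating the Run and scheduling the background step
    return "run-123"


print(run_coro_in_new_thread(fake_issue_task()))  # run-123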
letta/groups/supervisor_multi_agent.py

@@ -89,6 +89,7 @@ class SupervisorMultiAgent(Agent):
                 tool_calls=None,
                 tool_call_id=None,
                 group_id=self.group_id,
+                otid=message.otid,
             )
             for message in messages
         ]
letta/helpers/converters.py

@@ -2,12 +2,14 @@ import base64
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
+from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
 from sqlalchemy import Dialect
 
+from letta.schemas.agent import AgentStepState
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import ToolRuleType
+from letta.schemas.enums import ProviderType, ToolRuleType
 from letta.schemas.letta_message_content import (
     MessageContent,
     MessageContentType,
@@ -38,7 +40,7 @@ from letta.schemas.tool_rule import (
 def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
     """Convert an LLMConfig object into a JSON-serializable dictionary."""
     if config and isinstance(config, LLMConfig):
-        return config.model_dump()
+        return config.model_dump(mode="json")
     return config
 
 
@@ -55,7 +57,7 @@ def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
 def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
     """Convert an EmbeddingConfig object into a JSON-serializable dictionary."""
     if config and isinstance(config, EmbeddingConfig):
-        return config.model_dump()
+        return config.model_dump(mode="json")
     return config
 
 
@@ -75,7 +77,9 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
     if not tool_rules:
         return []
 
-    data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules]  # Convert Enum to string for JSON compatibility
+    data = [
+        {**rule.model_dump(mode="json"), "type": rule.type.value} for rule in tool_rules
+    ]  # Convert Enum to string for JSON compatibility
 
     # Validate ToolRule structure
     for rule_data in data:
@@ -130,7 +134,7 @@ def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]
     serialized_calls = []
     for call in tool_calls:
         if isinstance(call, OpenAIToolCall):
-            serialized_calls.append(call.model_dump())
+            serialized_calls.append(call.model_dump(mode="json"))
         elif isinstance(call, dict):
             serialized_calls.append(call)  # Already a dictionary, leave it as-is
         else:
@@ -166,7 +170,7 @@ def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]]
     serialized_tool_returns = []
     for tool_return in tool_returns:
         if isinstance(tool_return, ToolReturn):
-            serialized_tool_returns.append(tool_return.model_dump())
+            serialized_tool_returns.append(tool_return.model_dump(mode="json"))
         elif isinstance(tool_return, dict):
             serialized_tool_returns.append(tool_return)  # Already a dictionary, leave it as-is
         else:
@@ -201,7 +205,7 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
     serialized_message_content = []
     for content in message_content:
         if isinstance(content, MessageContent):
-            serialized_message_content.append(content.model_dump())
+            serialized_message_content.append(content.model_dump(mode="json"))
        elif isinstance(content, dict):
            serialized_message_content.append(content)  # Already a dictionary, leave it as-is
        else:
@@ -266,3 +270,101 @@ def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.n
     data = base64.b64decode(data)
 
     return np.frombuffer(data, dtype=np.float32)
+
+
+# --------------------------
+# Batch Request Serialization
+# --------------------------
+
+
+def serialize_create_batch_response(create_batch_response: Union[BetaMessageBatch]) -> Dict[str, Any]:
+    """Convert a list of ToolRules into a JSON-serializable format."""
+    llm_provider_type = None
+    if isinstance(create_batch_response, BetaMessageBatch):
+        llm_provider_type = ProviderType.anthropic.value
+
+    if not llm_provider_type:
+        raise ValueError(f"Could not determine llm provider from create batch response object type: {create_batch_response}")
+
+    return {"data": create_batch_response.model_dump(mode="json"), "type": llm_provider_type}
+
+
+def deserialize_create_batch_response(data: Dict) -> Union[BetaMessageBatch]:
+    provider_type = ProviderType(data.get("type"))
+
+    if provider_type == ProviderType.anthropic:
+        return BetaMessageBatch(**data.get("data"))
+
+    raise ValueError(f"Unknown ProviderType type: {provider_type}")
+
+
+# TODO: Note that this is the same as above for Anthropic, but this is not the case for all providers
+# TODO: Some have different types based on the create v.s. poll requests
+def serialize_poll_batch_response(poll_batch_response: Optional[Union[BetaMessageBatch]]) -> Optional[Dict[str, Any]]:
+    """Convert a list of ToolRules into a JSON-serializable format."""
+    if not poll_batch_response:
+        return None
+
+    llm_provider_type = None
+    if isinstance(poll_batch_response, BetaMessageBatch):
+        llm_provider_type = ProviderType.anthropic.value
+
+    if not llm_provider_type:
+        raise ValueError(f"Could not determine llm provider from poll batch response object type: {poll_batch_response}")
+
+    return {"data": poll_batch_response.model_dump(mode="json"), "type": llm_provider_type}
+
+
+def deserialize_poll_batch_response(data: Optional[Dict]) -> Optional[Union[BetaMessageBatch]]:
+    if not data:
+        return None
+
+    provider_type = ProviderType(data.get("type"))
+
+    if provider_type == ProviderType.anthropic:
+        return BetaMessageBatch(**data.get("data"))
+
+    raise ValueError(f"Unknown ProviderType type: {provider_type}")
+
+
+def serialize_batch_request_result(
+    batch_individual_response: Optional[Union[BetaMessageBatchIndividualResponse]],
+) -> Optional[Dict[str, Any]]:
+    """Convert a list of ToolRules into a JSON-serializable format."""
+    if not batch_individual_response:
+        return None
+
+    llm_provider_type = None
+    if isinstance(batch_individual_response, BetaMessageBatchIndividualResponse):
+        llm_provider_type = ProviderType.anthropic.value
+
+    if not llm_provider_type:
+        raise ValueError(f"Could not determine llm provider from batch result object type: {batch_individual_response}")
+
+    return {"data": batch_individual_response.model_dump(mode="json"), "type": llm_provider_type}
+
+
+def deserialize_batch_request_result(data: Optional[Dict]) -> Optional[Union[BetaMessageBatchIndividualResponse]]:
+    if not data:
+        return None
+    provider_type = ProviderType(data.get("type"))
+
+    if provider_type == ProviderType.anthropic:
+        return BetaMessageBatchIndividualResponse(**data.get("data"))
+
+    raise ValueError(f"Unknown ProviderType type: {provider_type}")
+
+
+def serialize_agent_step_state(agent_step_state: Optional[AgentStepState]) -> Optional[Dict[str, Any]]:
+    """Convert a list of ToolRules into a JSON-serializable format."""
+    if not agent_step_state:
+        return None
+
+    return agent_step_state.model_dump(mode="json")
+
+
+def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepState]:
+    if not data:
+        return None
+
+    return AgentStepState(**data)
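The new batch converters all follow the same provider-tagged envelope: dump the provider object under "data", record the provider under "type", and dispatch on that tag when deserializing. A minimal round-trip sketch of the pattern, using stand-in types rather than the real anthropic/letta classes:

from dataclasses import asdict, dataclass
from enum import Enum


class Provider(str, Enum):
    # Stand-in for letta's ProviderType enum.
    anthropic = "anthropic"


@dataclass
class FakeBatch:
    # Stand-in for anthropic's BetaMessageBatch.
    id: str
    processing_status: str


def serialize_batch(batch: FakeBatch) -> dict:
    # Tag the payload with its provider so deserialization knows which class to rebuild.
    return {"data": asdict(batch), "type": Provider.anthropic.value}


def deserialize_batch(data: dict) -> FakeBatch:
    if Provider(data["type"]) == Provider.anthropic:
        return FakeBatch(**data["data"])
    raise ValueError(f"Unknown provider: {data['type']}")


original = FakeBatch(id="batch_1", processing_status="ended")
assert deserialize_batch(serialize_batch(original)) == original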
letta/helpers/message_helper.py

@@ -38,4 +38,5 @@ def prepare_input_message_create(
         model=None,  # assigned later?
         tool_calls=None,  # irrelevant
         tool_call_id=None,
+        otid=message.otid,
     )