jl-ecms-client 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jl_ecms_client-0.2.0.dist-info/METADATA +295 -0
- jl_ecms_client-0.2.0.dist-info/RECORD +51 -0
- jl_ecms_client-0.2.0.dist-info/WHEEL +5 -0
- jl_ecms_client-0.2.0.dist-info/licenses/LICENSE +190 -0
- jl_ecms_client-0.2.0.dist-info/top_level.txt +1 -0
- mirix/client/__init__.py +72 -0
- mirix/client/client.py +2594 -0
- mirix/client/remote_client.py +1136 -0
- mirix/helpers/__init__.py +1 -0
- mirix/helpers/converters.py +429 -0
- mirix/helpers/datetime_helpers.py +90 -0
- mirix/helpers/json_helpers.py +47 -0
- mirix/helpers/message_helpers.py +74 -0
- mirix/helpers/tool_rule_solver.py +166 -0
- mirix/schemas/__init__.py +1 -0
- mirix/schemas/agent.py +400 -0
- mirix/schemas/block.py +188 -0
- mirix/schemas/cloud_file_mapping.py +29 -0
- mirix/schemas/embedding_config.py +114 -0
- mirix/schemas/enums.py +69 -0
- mirix/schemas/environment_variables.py +82 -0
- mirix/schemas/episodic_memory.py +170 -0
- mirix/schemas/file.py +57 -0
- mirix/schemas/health.py +10 -0
- mirix/schemas/knowledge_vault.py +181 -0
- mirix/schemas/llm_config.py +187 -0
- mirix/schemas/memory.py +318 -0
- mirix/schemas/message.py +1315 -0
- mirix/schemas/mirix_base.py +107 -0
- mirix/schemas/mirix_message.py +411 -0
- mirix/schemas/mirix_message_content.py +230 -0
- mirix/schemas/mirix_request.py +39 -0
- mirix/schemas/mirix_response.py +183 -0
- mirix/schemas/openai/__init__.py +1 -0
- mirix/schemas/openai/chat_completion_request.py +122 -0
- mirix/schemas/openai/chat_completion_response.py +144 -0
- mirix/schemas/openai/chat_completions.py +127 -0
- mirix/schemas/openai/embedding_response.py +11 -0
- mirix/schemas/openai/openai.py +229 -0
- mirix/schemas/organization.py +38 -0
- mirix/schemas/procedural_memory.py +151 -0
- mirix/schemas/providers.py +816 -0
- mirix/schemas/resource_memory.py +134 -0
- mirix/schemas/sandbox_config.py +132 -0
- mirix/schemas/semantic_memory.py +162 -0
- mirix/schemas/source.py +96 -0
- mirix/schemas/step.py +53 -0
- mirix/schemas/tool.py +241 -0
- mirix/schemas/tool_rule.py +209 -0
- mirix/schemas/usage.py +31 -0
- mirix/schemas/user.py +67 -0
mirix/schemas/message.py
ADDED
|
@@ -0,0 +1,1315 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import json
|
|
5
|
+
import uuid
|
|
6
|
+
import warnings
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field, field_validator
|
|
12
|
+
|
|
13
|
+
from mirix.constants import (
|
|
14
|
+
DEFAULT_MESSAGE_TOOL,
|
|
15
|
+
DEFAULT_MESSAGE_TOOL_KWARG,
|
|
16
|
+
TOOL_CALL_ID_MAX_LEN,
|
|
17
|
+
)
|
|
18
|
+
from mirix.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
|
19
|
+
from mirix.helpers.json_helpers import json_dumps
|
|
20
|
+
from mirix.schemas.enums import MessageRole
|
|
21
|
+
from mirix.schemas.mirix_base import OrmMetadataBase
|
|
22
|
+
from mirix.schemas.mirix_message import (
|
|
23
|
+
AssistantMessage,
|
|
24
|
+
HiddenReasoningMessage,
|
|
25
|
+
MirixMessage,
|
|
26
|
+
ReasoningMessage,
|
|
27
|
+
SystemMessage,
|
|
28
|
+
ToolCall,
|
|
29
|
+
ToolCallMessage,
|
|
30
|
+
ToolReturnMessage,
|
|
31
|
+
UserMessage,
|
|
32
|
+
)
|
|
33
|
+
from mirix.schemas.mirix_message_content import (
|
|
34
|
+
CloudFileContent,
|
|
35
|
+
FileContent,
|
|
36
|
+
ImageContent,
|
|
37
|
+
MirixMessageContentUnion,
|
|
38
|
+
ReasoningContent,
|
|
39
|
+
RedactedReasoningContent,
|
|
40
|
+
TextContent,
|
|
41
|
+
get_mirix_message_content_union_str_json_schema,
|
|
42
|
+
)
|
|
43
|
+
from mirix.schemas.openai.openai import Function as OpenAIFunction
|
|
44
|
+
from mirix.schemas.openai.openai import ToolCall as OpenAIToolCall
|
|
45
|
+
from mirix.system import unpack_message
|
|
46
|
+
from mirix.helpers.json_helpers import parse_json
|
|
47
|
+
|
|
48
|
+
class BaseMessage(OrmMetadataBase):
    """Shared base for message schemas; pins the ID prefix for generated message ids."""

    # Used by the inherited generate_id_field() to mint ids of the form
    # "message-<uuid>" (see Message.to_mirix_message, which strips "message-").
    __id_prefix__ = "message"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class MessageCreate(BaseModel):
    """Request to create a message"""

    # The simplified creation format only accepts the plain conversational roles.
    role: Literal[
        MessageRole.user,
        MessageRole.system,
    ] = Field(..., description="The role of the participant.")
    content: Union[str, List[MirixMessageContentUnion]] = Field(
        ...,
        description="The content of the message.",
        json_schema_extra=get_mirix_message_content_union_str_json_schema(),
    )
    name: Optional[str] = Field(None, description="The name of the participant.")
    otid: Optional[str] = Field(
        None, description="The offline threading id associated with this message"
    )
    sender_id: Optional[str] = Field(
        None,
        description="The id of the sender of the message, can be an identity id or agent id",
    )
    group_id: Optional[str] = Field(
        None, description="The multi-agent group that the message was sent in"
    )
    filter_tags: Optional[Dict[str, Any]] = Field(
        None,
        description="Optional tags for filtering and categorizing this message and related memories"
    )

    def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
        """Dump the model to a dict.

        With ``to_orm=True`` a plain-string ``content`` is normalized into a
        one-element ``TextContent`` list (the ORM-side shape).
        NOTE(review): list-valued content is dumped to plain dicts by pydantic
        while the str path inserts TextContent objects — confirm downstream
        consumers accept both shapes.
        """
        dumped = super().model_dump(**kwargs)
        if to_orm and isinstance(dumped.get("content"), str):
            dumped["content"] = [TextContent(text=dumped["content"])]
        return dumped
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class MessageUpdate(BaseModel):
    """Request to update a message"""

    role: Optional[MessageRole] = Field(
        None, description="The role of the participant."
    )
    content: Optional[Union[str, List[MirixMessageContentUnion]]] = Field(
        None,
        description="The content of the message.",
        json_schema_extra=get_mirix_message_content_union_str_json_schema(),
    )
    # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
    # NOTE: we probably shouldn't allow updating the model field, otherwise this loses meaning
    name: Optional[str] = Field(None, description="The name of the participant.")
    # NOTE: we probably shouldn't allow updating the created_at field, right?
    # Fixed: annotation was List[OpenAIToolCall,] (stray trailing comma inside the
    # subscript); now written as List[OpenAIToolCall], consistent with Message.tool_calls.
    tool_calls: Optional[List[OpenAIToolCall]] = Field(
        None, description="The list of tool calls requested."
    )
    tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")

    def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
        """Dump the model to a dict.

        With ``to_orm=True`` a plain-string ``content`` is normalized into a
        one-element ``TextContent`` list (the ORM-side shape).
        """
        data = super().model_dump(**kwargs)
        if to_orm and "content" in data:
            if isinstance(data["content"], str):
                data["content"] = [TextContent(text=data["content"])]
        return data
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Message(BaseMessage):
    """
    Mirix's internal representation of a message. Includes methods to convert to/from LLM provider formats.

    Attributes:
        id (str): The unique identifier of the message.
        role (MessageRole): The role of the participant.
        content (List[MirixMessageContentUnion]): The content parts of the message.
        user_id (str): The unique identifier of the user.
        agent_id (str): The unique identifier of the agent.
        model (str): The model used to make the function call.
        name (str): The name of the participant.
        created_at (datetime): The time the message was created.
        tool_calls (List[OpenAIToolCall]): The list of tool calls requested.
        tool_call_id (str): The id of the tool call.

    """

    id: str = BaseMessage.generate_id_field()
    organization_id: Optional[str] = Field(
        None, description="The unique identifier of the organization."
    )
    user_id: Optional[str] = Field(
        None, description="The unique identifier of the user."
    )
    agent_id: Optional[str] = Field(
        None, description="The unique identifier of the agent."
    )
    model: Optional[str] = Field(
        None, description="The model used to make the function call."
    )
    # Basic OpenAI-style fields
    role: MessageRole = Field(..., description="The role of the participant.")
    content: Optional[List[MirixMessageContentUnion]] = Field(
        None, description="The content of the message."
    )
    # NOTE: in OpenAI, this field is only used for roles 'user', 'assistant', and 'function' (now deprecated). 'tool' does not use it.
    name: Optional[str] = Field(
        None,
        description="For role user/assistant: the (optional) name of the participant. For role tool/function: the name of the function called.",
    )
    tool_calls: Optional[List[OpenAIToolCall]] = Field(
        None,
        description="The list of tool calls requested. Only applicable for role assistant.",
    )
    tool_call_id: Optional[str] = Field(
        None, description="The ID of the tool call. Only applicable for role tool."
    )
    # Extras
    step_id: Optional[str] = Field(
        None, description="The id of the step that this message was created in."
    )
    otid: Optional[str] = Field(
        None, description="The offline threading id associated with this message"
    )
    # ToolReturn is not imported above — presumably declared later in this module
    # (lazy annotations via `from __future__ import annotations`); verify.
    tool_returns: Optional[List[ToolReturn]] = Field(
        None, description="Tool execution return information for prior tool calls"
    )
    group_id: Optional[str] = Field(
        None, description="The multi-agent group that the message was sent in"
    )
    sender_id: Optional[str] = Field(
        None,
        description="The id of the sender of the message, can be an identity id or agent id",
    )
    # This overrides the optional base orm schema, created_at MUST exist on all messages objects
    created_at: datetime = Field(
        default_factory=get_utc_time,
        description="The timestamp when the object was created.",
    )

    # NEW: Filter tags for flexible filtering and categorization
    filter_tags: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Custom filter tags for filtering and categorization",
        examples=[
            {
                "project_id": "proj-abc",
                "session_id": "sess-xyz",
                "tags": ["important", "work"],
                "priority": "high"
            }
        ]
    )
|
|
205
|
+
|
|
206
|
+
@field_validator("role")
|
|
207
|
+
@classmethod
|
|
208
|
+
def validate_role(cls, v: str) -> str:
|
|
209
|
+
roles = ["system", "assistant", "user", "tool"]
|
|
210
|
+
assert v in roles, f"Role must be one of {roles}"
|
|
211
|
+
return v
|
|
212
|
+
|
|
213
|
+
def to_json(self):
|
|
214
|
+
json_message = vars(self)
|
|
215
|
+
if json_message["tool_calls"] is not None:
|
|
216
|
+
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
|
|
217
|
+
# turn datetime to ISO format
|
|
218
|
+
# also if the created_at is missing a timezone, add UTC
|
|
219
|
+
if not is_utc_datetime(self.created_at):
|
|
220
|
+
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
|
|
221
|
+
json_message["created_at"] = self.created_at.isoformat()
|
|
222
|
+
return json_message
|
|
223
|
+
|
|
224
|
+
@staticmethod
|
|
225
|
+
def generate_otid():
|
|
226
|
+
return str(uuid.uuid4())
|
|
227
|
+
|
|
228
|
+
    @staticmethod
    def to_mirix_messages_from_list(
        messages: List[Message],
        use_assistant_message: bool = True,
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
        reverse: bool = True,
    ) -> List[MirixMessage]:
        """Convert a list of DB-format Messages into API-format MirixMessages.

        When ``use_assistant_message`` is True, tool-role messages that are the
        return of an "assistant message" tool call (a call whose function name
        matches ``assistant_message_tool_name``) are dropped before conversion,
        since the call itself is rendered as an AssistantMessage.
        Each remaining Message is expanded via ``to_mirix_message`` (which may
        yield several MirixMessages per Message).
        """
        if use_assistant_message:
            message_ids_to_remove = []
            # Map every tool_call id to the assistant Message that issued it.
            assistant_messages_by_tool_call = {
                tool_call.id: msg
                for msg in messages
                if msg.role == MessageRole.assistant and msg.tool_calls
                for tool_call in msg.tool_calls
            }
            for message in messages:
                # Drop tool returns whose originating assistant message used the
                # special assistant-message tool (its content is surfaced elsewhere).
                if (
                    message.role == MessageRole.tool
                    and message.tool_call_id in assistant_messages_by_tool_call
                    and assistant_messages_by_tool_call[message.tool_call_id].tool_calls
                    and assistant_message_tool_name
                    in [
                        tool_call.function.name
                        for tool_call in assistant_messages_by_tool_call[
                            message.tool_call_id
                        ].tool_calls
                    ]
                ):
                    message_ids_to_remove.append(message.id)

            messages = [msg for msg in messages if msg.id not in message_ids_to_remove]

        # Convert messages to MirixMessages
        return [
            msg
            for m in messages
            for msg in m.to_mirix_message(
                use_assistant_message=use_assistant_message,
                assistant_message_tool_name=assistant_message_tool_name,
                assistant_message_tool_kwarg=assistant_message_tool_kwarg,
                reverse=reverse,
            )
        ]
|
|
272
|
+
|
|
273
|
+
    def to_mirix_message(
        self,
        use_assistant_message: bool = False,
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
        reverse: bool = True,
    ) -> List[MirixMessage]:
        """Convert message object (in DB format) to the style used by the original Mirix API.

        A single DB Message can expand into several MirixMessages (e.g. an
        assistant message with reasoning content plus multiple tool calls).
        Dispatches on ``self.role``:
          - assistant: content parts become Reasoning/HiddenReasoning messages;
            tool calls become ToolCallMessages, or an AssistantMessage when
            ``use_assistant_message`` is set and the call's function name matches
            ``assistant_message_tool_name`` (its text pulled from
            ``assistant_message_tool_kwarg`` in the arguments).
          - tool: a single-TextContent body parsed as a packaged function return
            ({"status": ..., "message": ..., "time": ...}) → ToolReturnMessage.
          - user/system: single text body → UserMessage / SystemMessage.
        Raises ValueError on malformed content or an unknown role.
        ``reverse=True`` reverses the produced list before returning.
        """
        messages = []

        if self.role == MessageRole.assistant:
            # Handle reasoning
            if self.content:
                # Check for ReACT-style COT inside of TextContent
                if len(self.content) == 1 and isinstance(self.content[0], TextContent):
                    otid = Message.generate_otid_from_id(self.id, len(messages))
                    messages.append(
                        ReasoningMessage(
                            id=self.id,
                            date=self.created_at,
                            reasoning=self.content[0].text,
                            name=self.name,
                            otid=otid,
                            sender_id=self.sender_id,
                        )
                    )
                # Otherwise, we may have a list of multiple types
                else:
                    # TODO we can probably collapse these two cases into a single loop
                    for content_part in self.content:
                        otid = Message.generate_otid_from_id(self.id, len(messages))
                        if isinstance(content_part, TextContent):
                            # COT
                            messages.append(
                                ReasoningMessage(
                                    id=self.id,
                                    date=self.created_at,
                                    reasoning=content_part.text,
                                    name=self.name,
                                    otid=otid,
                                    sender_id=self.sender_id,
                                )
                            )
                        elif isinstance(content_part, ReasoningContent):
                            # "native" COT
                            # NOTE(review): sender_id is not forwarded on this branch,
                            # unlike the TextContent/Redacted branches — confirm intentional.
                            messages.append(
                                ReasoningMessage(
                                    id=self.id,
                                    date=self.created_at,
                                    reasoning=content_part.reasoning,
                                    source="reasoner_model",  # TODO do we want to tag like this?
                                    signature=content_part.signature,
                                    name=self.name,
                                    otid=otid,
                                )
                            )
                        elif isinstance(content_part, RedactedReasoningContent):
                            # "native" redacted/hidden COT
                            messages.append(
                                HiddenReasoningMessage(
                                    id=self.id,
                                    date=self.created_at,
                                    state="redacted",
                                    hidden_reasoning=content_part.data,
                                    name=self.name,
                                    otid=otid,
                                    sender_id=self.sender_id,
                                )
                            )
                        else:
                            warnings.warn(
                                f"Unrecognized content part in assistant message: {content_part}"
                            )

            if self.tool_calls is not None:
                # This is type FunctionCall
                for tool_call in self.tool_calls:
                    otid = Message.generate_otid_from_id(self.id, len(messages))
                    # If we're supporting using assistant message,
                    # then we want to treat certain function calls as a special case
                    if (
                        use_assistant_message
                        and tool_call.function.name == assistant_message_tool_name
                    ):
                        # We need to unpack the actual message contents from the function call
                        try:
                            func_args = parse_json(tool_call.function.arguments)
                            message_string = func_args[assistant_message_tool_kwarg]
                        except KeyError:
                            raise ValueError(
                                f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument"
                            )
                        messages.append(
                            AssistantMessage(
                                id=self.id,
                                date=self.created_at,
                                content=message_string,
                                name=self.name,
                                otid=otid,
                                sender_id=self.sender_id,
                            )
                        )
                    else:
                        messages.append(
                            ToolCallMessage(
                                id=self.id,
                                date=self.created_at,
                                tool_call=ToolCall(
                                    name=tool_call.function.name,
                                    arguments=tool_call.function.arguments,
                                    tool_call_id=tool_call.id,
                                ),
                                name=self.name,
                                otid=otid,
                                sender_id=self.sender_id,
                            )
                        )
        elif self.role == MessageRole.tool:
            # This is type ToolReturnMessage
            # Try to interpret the function return, recall that this is how we packaged:
            # def package_function_response(was_success, response_string, timestamp=None):
            #     formatted_time = get_local_time() if timestamp is None else timestamp
            #     packaged_message = {
            #         "status": "OK" if was_success else "Failed",
            #         "message": response_string,
            #         "time": formatted_time,
            #     }
            if (
                self.content
                and len(self.content) == 1
                and isinstance(self.content[0], TextContent)
            ):
                text_content = self.content[0].text
            else:
                raise ValueError(
                    f"Invalid tool return (no text object on message): {self.content}"
                )

            try:
                function_return = parse_json(text_content)
                status = function_return["status"]
                if status == "OK":
                    status_enum = "success"
                elif status == "Failed":
                    status_enum = "error"
                else:
                    raise ValueError(f"Invalid status: {status}")
            except json.JSONDecodeError:
                raise ValueError(f"Failed to decode function return: {text_content}")
            assert self.tool_call_id is not None
            messages.append(
                # TODO make sure this is what the API returns
                # function_return may not match exactly...
                ToolReturnMessage(
                    id=self.id,
                    date=self.created_at,
                    tool_return=text_content,
                    # Prefer the recorded ToolReturn status when available,
                    # otherwise the status derived from the packaged JSON above.
                    status=self.tool_returns[0].status
                    if self.tool_returns
                    else status_enum,
                    tool_call_id=self.tool_call_id,
                    stdout=self.tool_returns[0].stdout if self.tool_returns else None,
                    stderr=self.tool_returns[0].stderr if self.tool_returns else None,
                    name=self.name,
                    otid=self.id.replace("message-", ""),
                    sender_id=self.sender_id,
                )
            )
        elif self.role == MessageRole.user:
            # This is type UserMessage
            if (
                self.content
                and len(self.content) == 1
                and isinstance(self.content[0], TextContent)
            ):
                text_content = self.content[0].text
            elif self.content and len(self.content) > 1:
                # Multi-part content is flattened into a single string with
                # <image>/<file>/<cloud_file> markers around non-text parts.
                text_content = ""
                for content in self.content:
                    if isinstance(content, TextContent):
                        text_content += content.text
                    elif isinstance(content, ImageContent):
                        text_content += "<image>" + content.image_id + "</image>"
                    elif isinstance(content, FileContent):
                        text_content += "<file>" + content.file_id + "</file>"
                    elif isinstance(content, CloudFileContent):
                        text_content += (
                            "<cloud_file>" + content.cloud_file_uri + "</cloud_file>"
                        )
            else:
                raise ValueError(
                    f"Invalid user message (no text object on message): {self.content}"
                )

            # Strip any system-side packaging; fall back to the raw text if
            # unpack_message returns a falsy value.
            message_str = unpack_message(text_content)
            messages.append(
                UserMessage(
                    id=self.id,
                    date=self.created_at,
                    content=message_str or text_content,
                    name=self.name,
                    otid=self.otid,
                    sender_id=self.sender_id,
                )
            )
        elif self.role == MessageRole.system:
            # This is type SystemMessage
            if (
                self.content
                and len(self.content) == 1
                and isinstance(self.content[0], TextContent)
            ):
                text_content = self.content[0].text
            else:
                raise ValueError(
                    f"Invalid system message (no text object on system): {self.content}"
                )

            messages.append(
                SystemMessage(
                    id=self.id,
                    date=self.created_at,
                    content=text_content,
                    name=self.name,
                    otid=self.otid,
                    sender_id=self.sender_id,
                )
            )
        else:
            raise ValueError(self.role)

        if reverse:
            messages.reverse()

        return messages
|
|
508
|
+
|
|
509
|
+
    @staticmethod
    def dict_to_message(
        agent_id: str,
        openai_message_dict: dict,
        model: Optional[str] = None,  # model used to make function call
        allow_functions_style: bool = False,  # allow deprecated functions style?
        created_at: Optional[datetime] = None,
        id: Optional[str] = None,
        name: Optional[str] = None,
        group_id: Optional[str] = None,
        tool_returns: Optional[List[ToolReturn]] = None,
    ):
        """Convert a ChatCompletion message object into a Message object (synced to DB).

        Handles three input shapes:
          1. the deprecated role=="function" form (converted to role tool),
          2. the deprecated assistant "function_call" form (converted to tool_calls),
          3. the modern tool-call style.
        The deprecated forms require ``allow_functions_style=True``, otherwise a
        DeprecationWarning is raised. When ``id`` is given it is used as the
        Message id; otherwise the Message generates its own.
        """
        if not created_at:
            # timestamp for creation
            created_at = get_utc_time()

        assert "role" in openai_message_dict, openai_message_dict
        assert "content" in openai_message_dict, openai_message_dict

        # TODO(caren) implicit support for only non-parts/list content types
        if (
            openai_message_dict["content"] is not None
            and type(openai_message_dict["content"]) is not str
        ):
            raise ValueError(
                f"Invalid content type: {type(openai_message_dict['content'])}"
            )
        # Wrap a non-empty string in a single TextContent; None/"" becomes [].
        content = (
            [TextContent(text=openai_message_dict["content"])]
            if openai_message_dict["content"]
            else []
        )

        # TODO(caren) bad assumption here that "reasoning_content" always comes before "redacted_reasoning_content"
        if (
            "reasoning_content" in openai_message_dict
            and openai_message_dict["reasoning_content"]
        ):
            content.append(
                ReasoningContent(
                    reasoning=openai_message_dict["reasoning_content"],
                    is_native=True,
                    # NOTE(review): "reasoning_content_signature" is read without a
                    # key-presence check — KeyError if reasoning_content is set but
                    # the signature key is absent; confirm upstream always sets both.
                    signature=(
                        openai_message_dict["reasoning_content_signature"]
                        if openai_message_dict["reasoning_content_signature"]
                        else None
                    ),
                ),
            )
        if (
            "redacted_reasoning_content" in openai_message_dict
            and openai_message_dict["redacted_reasoning_content"]
        ):
            content.append(
                RedactedReasoningContent(
                    data=openai_message_dict["redacted_reasoning_content"]
                    if "redacted_reasoning_content" in openai_message_dict
                    else None,
                ),
            )

        # If we're going from deprecated function form
        if openai_message_dict["role"] == "function":
            if not allow_functions_style:
                raise DeprecationWarning(openai_message_dict)
            assert "tool_call_id" in openai_message_dict, openai_message_dict

            # Convert from 'function' response to a 'tool' response
            if id is not None:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole.tool,  # NOTE
                    content=content,
                    name=name,
                    tool_calls=openai_message_dict["tool_calls"]
                    if "tool_calls" in openai_message_dict
                    else None,
                    tool_call_id=openai_message_dict["tool_call_id"]
                    if "tool_call_id" in openai_message_dict
                    else None,
                    created_at=created_at,
                    id=str(id),
                    tool_returns=tool_returns,
                    group_id=group_id,
                )
            else:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole.tool,  # NOTE
                    content=content,
                    name=name,
                    tool_calls=openai_message_dict["tool_calls"]
                    if "tool_calls" in openai_message_dict
                    else None,
                    tool_call_id=openai_message_dict["tool_call_id"]
                    if "tool_call_id" in openai_message_dict
                    else None,
                    created_at=created_at,
                    tool_returns=tool_returns,
                    group_id=group_id,
                )

        elif (
            "function_call" in openai_message_dict
            and openai_message_dict["function_call"] is not None
        ):
            if not allow_functions_style:
                raise DeprecationWarning(openai_message_dict)
            assert openai_message_dict["role"] == "assistant", openai_message_dict
            assert "tool_call_id" in openai_message_dict, openai_message_dict

            # Convert a function_call (from an assistant message) into a tool_call
            # NOTE: this does not conventionally include a tool_call_id (ToolCall.id), it's on the caster to provide it
            tool_calls = [
                OpenAIToolCall(
                    id=openai_message_dict[
                        "tool_call_id"
                    ],  # NOTE: unconventional source, not to spec
                    type="function",
                    function=OpenAIFunction(
                        name=openai_message_dict["function_call"]["name"],
                        arguments=openai_message_dict["function_call"]["arguments"],
                    ),
                )
            ]

            if id is not None:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole(openai_message_dict["role"]),
                    content=content,
                    name=name,
                    tool_calls=tool_calls,
                    tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
                    created_at=created_at,
                    id=str(id),
                    tool_returns=tool_returns,
                    group_id=group_id,
                )
            else:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole(openai_message_dict["role"]),
                    content=content,
                    name=openai_message_dict["name"]
                    if "name" in openai_message_dict
                    else None,
                    tool_calls=tool_calls,
                    tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
                    created_at=created_at,
                    tool_returns=tool_returns,
                    group_id=group_id,
                )

        else:
            # Basic sanity check
            # tool messages must carry a tool_call_id; other roles must not.
            if openai_message_dict["role"] == "tool":
                assert (
                    "tool_call_id" in openai_message_dict
                    and openai_message_dict["tool_call_id"] is not None
                ), openai_message_dict
            else:
                if "tool_call_id" in openai_message_dict:
                    assert openai_message_dict["tool_call_id"] is None, (
                        openai_message_dict
                    )

            if (
                "tool_calls" in openai_message_dict
                and openai_message_dict["tool_calls"] is not None
            ):
                assert openai_message_dict["role"] == "assistant", openai_message_dict

                tool_calls = [
                    OpenAIToolCall(
                        id=tool_call["id"],
                        type=tool_call["type"],
                        function=tool_call["function"],
                    )
                    for tool_call in openai_message_dict["tool_calls"]
                ]
            else:
                tool_calls = None

            # If we're going from tool-call style
            if id is not None:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole(openai_message_dict["role"]),
                    content=content,
                    name=openai_message_dict["name"]
                    if "name" in openai_message_dict
                    else name,
                    tool_calls=tool_calls,
                    tool_call_id=openai_message_dict["tool_call_id"]
                    if "tool_call_id" in openai_message_dict
                    else None,
                    created_at=created_at,
                    id=str(id),
                    tool_returns=tool_returns,
                    group_id=group_id,
                )
            else:
                return Message(
                    agent_id=agent_id,
                    model=model,
                    # standard fields expected in an OpenAI ChatCompletion message object
                    role=MessageRole(openai_message_dict["role"]),
                    content=content,
                    name=openai_message_dict["name"]
                    if "name" in openai_message_dict
                    else name,
                    tool_calls=tool_calls,
                    tool_call_id=openai_message_dict["tool_call_id"]
                    if "tool_call_id" in openai_message_dict
                    else None,
                    created_at=created_at,
                    tool_returns=tool_returns,
                    group_id=group_id,
                )
|
|
740
|
+
|
|
741
|
+
def to_openai_dict_search_results(
|
|
742
|
+
self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN
|
|
743
|
+
) -> dict:
|
|
744
|
+
result_json = self.to_openai_dict()
|
|
745
|
+
search_result_json = {
|
|
746
|
+
"timestamp": self.created_at,
|
|
747
|
+
"message": {"content": result_json["content"], "role": result_json["role"]},
|
|
748
|
+
}
|
|
749
|
+
return search_result_json
|
|
750
|
+
|
|
751
|
+
    def to_openai_dict(
        self,
        max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
        put_inner_thoughts_in_kwargs: bool = False,
        use_developer_message: bool = False,
    ) -> dict:
        """Go from Message class to ChatCompletion message object.

        Args:
            max_tool_id_length: if truthy, tool-call ids (and tool_call_id on
                tool messages) are truncated to this many characters.
            put_inner_thoughts_in_kwargs: when True, assistant 'content' is
                emitted as None (the thoughts are assumed to live in the tool
                call kwargs instead).
            use_developer_message: emit role "developer" instead of "system"
                for system messages (newer OpenAI API convention).

        Returns:
            dict shaped like an OpenAI ChatCompletion message.

        Raises:
            ValueError: on an unknown role or an unknown content-part type.
        """

        # TODO change to pydantic casting, eg `return SystemMessageModel(self)`
        # If we only have one content part and it's text, treat it as COT
        parse_content_parts = False
        if (
            self.content
            and len(self.content) == 1
            and isinstance(self.content[0], TextContent)
        ):
            content = self.content[0].text
        # Otherwise, check if we have TextContent and multiple other parts
        elif self.content and len(self.content) > 1:
            # Multi-part content: serialize each part into its dict form.
            content = []
            text_content_count = 0

            for content_part in self.content:
                if isinstance(content_part, TextContent):
                    content.append(
                        {
                            "type": "text",
                            "text": content_part.text,
                        }
                    )
                    text_content_count += 1
                elif isinstance(content_part, ImageContent):
                    content.append(
                        {
                            "type": content_part.type,
                            "image_id": content_part.image_id,
                            "detail": content_part.detail,
                        }
                    )
                elif isinstance(content_part, FileContent):
                    content.append(
                        {
                            "type": content_part.type,
                            "file_id": content_part.file_id,
                        }
                    )
                elif isinstance(content_part, CloudFileContent):
                    content.append(
                        {
                            "type": content_part.type,
                            "cloud_file_uri": content_part.cloud_file_uri,
                        }
                    )
                else:
                    raise ValueError(f"Invalid content type: {content_part.type}")

            if text_content_count > 1:
                # More than one text part: treat extras as reasoning parts and
                # surface them after the base message is built (see below).
                # TODO: (yu) @caren check this
                parse_content_parts = True
        else:
            content = None

        # TODO(caren) we should eventually support multiple content parts here?
        # ie, actually make dict['content'] type list
        # But for now, it's OK until we support multi-modal,
        # since the only "parts" we have are for supporting various COT

        if self.role == "system":
            assert all([v is not None for v in [self.role]]), vars(self)
            openai_message = {
                "content": content,
                "role": "developer" if use_developer_message else self.role,
            }

        elif self.role == "user":
            assert all([v is not None for v in [content, self.role]]), vars(self)
            openai_message = {
                "content": content,
                "role": self.role,
            }

        elif self.role == "assistant":
            # An assistant message must carry a tool call and/or text content.
            assert self.tool_calls is not None or content is not None
            openai_message = {
                "content": None if put_inner_thoughts_in_kwargs else content,
                "role": self.role,
            }

            if self.tool_calls is not None:
                openai_message["tool_calls"] = [
                    tool_call.model_dump() for tool_call in self.tool_calls
                ]
                if max_tool_id_length:
                    # Some providers reject long tool-call ids; truncate.
                    for tool_call_dict in openai_message["tool_calls"]:
                        tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]

        elif self.role == "tool":
            assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(
                self
            )
            openai_message = {
                "content": content,
                "role": self.role,
                "tool_call_id": self.tool_call_id[:max_tool_id_length]
                if max_tool_id_length
                else self.tool_call_id,
            }

        else:
            raise ValueError(self.role)

        if parse_content_parts:
            # Attach reasoning parts as provider-specific extension fields.
            # NOTE(review): this rebinds `content` while looping; harmless here
            # since the base message is already built, but easy to trip over.
            for content in self.content:
                if isinstance(content, ReasoningContent):
                    openai_message["reasoning_content"] = content.reasoning
                    if content.signature:
                        openai_message["reasoning_content_signature"] = (
                            content.signature
                        )
                if isinstance(content, RedactedReasoningContent):
                    openai_message["redacted_reasoning_content"] = content.data

        return openai_message
|
|
874
|
+
|
|
875
|
+
def to_anthropic_dict(
|
|
876
|
+
self,
|
|
877
|
+
inner_thoughts_xml_tag="thinking",
|
|
878
|
+
put_inner_thoughts_in_kwargs: bool = False,
|
|
879
|
+
) -> dict:
|
|
880
|
+
"""
|
|
881
|
+
Convert to an Anthropic message dictionary
|
|
882
|
+
|
|
883
|
+
Args:
|
|
884
|
+
inner_thoughts_xml_tag (str): The XML tag to wrap around inner thoughts
|
|
885
|
+
"""
|
|
886
|
+
|
|
887
|
+
# Check for COT
|
|
888
|
+
if (
|
|
889
|
+
self.content
|
|
890
|
+
and len(self.content) == 1
|
|
891
|
+
and isinstance(self.content[0], TextContent)
|
|
892
|
+
):
|
|
893
|
+
content = self.content[0].text
|
|
894
|
+
elif self.content and len(self.content) > 1:
|
|
895
|
+
assert self.role == "user"
|
|
896
|
+
content = []
|
|
897
|
+
|
|
898
|
+
for content_part in self.content:
|
|
899
|
+
if isinstance(content_part, TextContent):
|
|
900
|
+
content.append(
|
|
901
|
+
{
|
|
902
|
+
"type": "text",
|
|
903
|
+
"text": content_part.text,
|
|
904
|
+
}
|
|
905
|
+
)
|
|
906
|
+
elif isinstance(content_part, ImageContent):
|
|
907
|
+
content.append(
|
|
908
|
+
{"type": "image_url", "image_id": content_part.image_id}
|
|
909
|
+
)
|
|
910
|
+
elif isinstance(content_part, FileContent):
|
|
911
|
+
content.append(
|
|
912
|
+
{
|
|
913
|
+
"type": "file_uri",
|
|
914
|
+
"file_id": content_part.file_id,
|
|
915
|
+
}
|
|
916
|
+
)
|
|
917
|
+
elif isinstance(content_part, CloudFileContent):
|
|
918
|
+
content.append(
|
|
919
|
+
{
|
|
920
|
+
"type": "cloud_file_uri",
|
|
921
|
+
"cloud_file_uri": content_part.cloud_file_uri,
|
|
922
|
+
}
|
|
923
|
+
)
|
|
924
|
+
|
|
925
|
+
else:
|
|
926
|
+
content = None
|
|
927
|
+
|
|
928
|
+
def add_xml_tag(string: str, xml_tag: Optional[str]):
|
|
929
|
+
# NOTE: Anthropic docs recommends using <thinking> tag when using CoT + tool use
|
|
930
|
+
if f"<{xml_tag}>" in string and f"</{xml_tag}>" in string:
|
|
931
|
+
# don't nest if tags already exist
|
|
932
|
+
return string
|
|
933
|
+
return f"<{xml_tag}>{string}</{xml_tag}" if xml_tag else string
|
|
934
|
+
|
|
935
|
+
if self.role == "system":
|
|
936
|
+
# NOTE: this is not for system instructions, but instead system "events"
|
|
937
|
+
|
|
938
|
+
assert all([v is not None for v in [content, self.role]]), vars(self)
|
|
939
|
+
# Two options here, we would use system.package_system_message,
|
|
940
|
+
# or use a more Anthropic-specific packaging ie xml tags
|
|
941
|
+
user_system_event = add_xml_tag(
|
|
942
|
+
string=f"SYSTEM ALERT: {content}", xml_tag="event"
|
|
943
|
+
)
|
|
944
|
+
anthropic_message = {
|
|
945
|
+
"content": user_system_event,
|
|
946
|
+
"role": "user",
|
|
947
|
+
}
|
|
948
|
+
|
|
949
|
+
elif self.role == "user":
|
|
950
|
+
assert all([v is not None for v in [content, self.role]]), vars(self)
|
|
951
|
+
anthropic_message = {
|
|
952
|
+
"content": content,
|
|
953
|
+
"role": self.role,
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
elif self.role == "assistant":
|
|
957
|
+
assert self.tool_calls is not None or content is not None
|
|
958
|
+
anthropic_message = {
|
|
959
|
+
"role": self.role,
|
|
960
|
+
}
|
|
961
|
+
content = []
|
|
962
|
+
# COT / reasoning / thinking
|
|
963
|
+
if len(self.content) > 1:
|
|
964
|
+
for content_part in self.content:
|
|
965
|
+
if isinstance(content_part, ReasoningContent):
|
|
966
|
+
content.append(
|
|
967
|
+
{
|
|
968
|
+
"type": "thinking",
|
|
969
|
+
"thinking": content_part.reasoning,
|
|
970
|
+
"signature": content_part.signature,
|
|
971
|
+
}
|
|
972
|
+
)
|
|
973
|
+
if isinstance(content_part, RedactedReasoningContent):
|
|
974
|
+
content.append(
|
|
975
|
+
{
|
|
976
|
+
"type": "redacted_thinking",
|
|
977
|
+
"data": content_part.data,
|
|
978
|
+
}
|
|
979
|
+
)
|
|
980
|
+
elif content is not None:
|
|
981
|
+
content.append(
|
|
982
|
+
{
|
|
983
|
+
"type": "text",
|
|
984
|
+
"text": add_xml_tag(
|
|
985
|
+
string=content, xml_tag=inner_thoughts_xml_tag
|
|
986
|
+
),
|
|
987
|
+
}
|
|
988
|
+
)
|
|
989
|
+
# Tool calling
|
|
990
|
+
if self.tool_calls is not None:
|
|
991
|
+
for tool_call in self.tool_calls:
|
|
992
|
+
tool_call_input = parse_json(tool_call.function.arguments)
|
|
993
|
+
|
|
994
|
+
content.append(
|
|
995
|
+
{
|
|
996
|
+
"type": "tool_use",
|
|
997
|
+
"id": tool_call.id,
|
|
998
|
+
"name": tool_call.function.name,
|
|
999
|
+
"input": tool_call_input,
|
|
1000
|
+
}
|
|
1001
|
+
)
|
|
1002
|
+
|
|
1003
|
+
# If the only content was text, unpack it back into a singleton
|
|
1004
|
+
# TODO support multi-modal
|
|
1005
|
+
anthropic_message["content"] = content
|
|
1006
|
+
|
|
1007
|
+
elif self.role == "tool":
|
|
1008
|
+
# NOTE: Anthropic uses role "user" for "tool" responses
|
|
1009
|
+
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(
|
|
1010
|
+
self
|
|
1011
|
+
)
|
|
1012
|
+
anthropic_message = {
|
|
1013
|
+
"role": "user", # NOTE: diff
|
|
1014
|
+
"content": [
|
|
1015
|
+
# TODO support error types etc
|
|
1016
|
+
{
|
|
1017
|
+
"type": "tool_result",
|
|
1018
|
+
"tool_use_id": self.tool_call_id,
|
|
1019
|
+
"content": content,
|
|
1020
|
+
}
|
|
1021
|
+
],
|
|
1022
|
+
}
|
|
1023
|
+
|
|
1024
|
+
else:
|
|
1025
|
+
raise ValueError(self.role)
|
|
1026
|
+
|
|
1027
|
+
return anthropic_message
|
|
1028
|
+
|
|
1029
|
+
def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
|
|
1030
|
+
"""
|
|
1031
|
+
Go from Message class to Google AI REST message object
|
|
1032
|
+
"""
|
|
1033
|
+
# type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content
|
|
1034
|
+
# parts[]: Part
|
|
1035
|
+
# role: str ('user' or 'model')
|
|
1036
|
+
if (
|
|
1037
|
+
self.content
|
|
1038
|
+
and len(self.content) == 1
|
|
1039
|
+
and isinstance(self.content[0], TextContent)
|
|
1040
|
+
):
|
|
1041
|
+
text_content = self.content[0].text
|
|
1042
|
+
contents = [{"text": text_content}]
|
|
1043
|
+
elif self.content:
|
|
1044
|
+
assert self.role == "user"
|
|
1045
|
+
contents = []
|
|
1046
|
+
for content in self.content:
|
|
1047
|
+
if isinstance(content, ImageContent):
|
|
1048
|
+
contents.append({"image_id": content.image_id})
|
|
1049
|
+
elif isinstance(content, TextContent):
|
|
1050
|
+
contents.append({"text": content.text})
|
|
1051
|
+
elif isinstance(content, FileContent):
|
|
1052
|
+
contents.append({"file_id": content.file_id})
|
|
1053
|
+
elif isinstance(content, CloudFileContent):
|
|
1054
|
+
contents.append({"cloud_file_uri": content.cloud_file_uri})
|
|
1055
|
+
else:
|
|
1056
|
+
raise ValueError(f"Invalid content type: {content.type}")
|
|
1057
|
+
else:
|
|
1058
|
+
text_content = None
|
|
1059
|
+
contents = None
|
|
1060
|
+
|
|
1061
|
+
if self.role != "tool" and self.name is not None:
|
|
1062
|
+
warnings.warn(
|
|
1063
|
+
f"Using Google AI with non-null 'name' field (name={self.name} role={self.role}), not yet supported."
|
|
1064
|
+
)
|
|
1065
|
+
|
|
1066
|
+
if self.role == "system":
|
|
1067
|
+
# NOTE: Gemini API doesn't have a 'system' role, use 'user' instead
|
|
1068
|
+
# https://www.reddit.com/r/Bard/comments/1b90i8o/does_gemini_have_a_system_prompt_option_while/
|
|
1069
|
+
google_ai_message = {
|
|
1070
|
+
"role": "user", # NOTE: no 'system'
|
|
1071
|
+
"parts": [{"text": text_content}],
|
|
1072
|
+
}
|
|
1073
|
+
|
|
1074
|
+
elif self.role == "user":
|
|
1075
|
+
if not all([v is not None for v in [contents, self.role]]):
|
|
1076
|
+
import ipdb
|
|
1077
|
+
|
|
1078
|
+
ipdb.set_trace()
|
|
1079
|
+
|
|
1080
|
+
assert all([v is not None for v in [contents, self.role]]), vars(self)
|
|
1081
|
+
google_ai_message = {
|
|
1082
|
+
"role": "user",
|
|
1083
|
+
"parts": contents,
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
elif self.role == "assistant":
|
|
1087
|
+
assert self.tool_calls is not None or text_content is not None
|
|
1088
|
+
google_ai_message = {
|
|
1089
|
+
"role": "model", # NOTE: different
|
|
1090
|
+
}
|
|
1091
|
+
|
|
1092
|
+
# NOTE: Google AI API doesn't allow non-null content + function call
|
|
1093
|
+
# To get around this, just two a two part message, inner thoughts first then
|
|
1094
|
+
parts = []
|
|
1095
|
+
if not put_inner_thoughts_in_kwargs and text_content is not None:
|
|
1096
|
+
# NOTE: ideally we do multi-part for CoT / inner thoughts + function call, but Google AI API doesn't allow it
|
|
1097
|
+
raise NotImplementedError
|
|
1098
|
+
parts.append({"text": text_content})
|
|
1099
|
+
|
|
1100
|
+
if self.tool_calls is not None:
|
|
1101
|
+
# NOTE: implied support for multiple calls
|
|
1102
|
+
for tool_call in self.tool_calls:
|
|
1103
|
+
function_name = tool_call.function.name
|
|
1104
|
+
function_args = tool_call.function.arguments
|
|
1105
|
+
try:
|
|
1106
|
+
# NOTE: Google AI wants actual JSON objects, not strings
|
|
1107
|
+
function_args = parse_json(function_args)
|
|
1108
|
+
except Exception:
|
|
1109
|
+
raise UserWarning(
|
|
1110
|
+
f"Failed to parse JSON function args: {function_args}"
|
|
1111
|
+
)
|
|
1112
|
+
|
|
1113
|
+
parts.append(
|
|
1114
|
+
{
|
|
1115
|
+
"functionCall": {
|
|
1116
|
+
"name": function_name,
|
|
1117
|
+
"args": function_args,
|
|
1118
|
+
}
|
|
1119
|
+
}
|
|
1120
|
+
)
|
|
1121
|
+
else:
|
|
1122
|
+
assert text_content is not None
|
|
1123
|
+
parts.append({"text": text_content})
|
|
1124
|
+
google_ai_message["parts"] = parts
|
|
1125
|
+
|
|
1126
|
+
elif self.role == "tool":
|
|
1127
|
+
# NOTE: Significantly different tool calling format, more similar to function calling format
|
|
1128
|
+
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(
|
|
1129
|
+
self
|
|
1130
|
+
)
|
|
1131
|
+
|
|
1132
|
+
if self.name is None:
|
|
1133
|
+
warnings.warn(
|
|
1134
|
+
"Couldn't find function name on tool call, defaulting to tool ID instead."
|
|
1135
|
+
)
|
|
1136
|
+
function_name = self.tool_call_id
|
|
1137
|
+
else:
|
|
1138
|
+
function_name = self.name
|
|
1139
|
+
|
|
1140
|
+
# NOTE: Google AI API wants the function response as JSON only, no string
|
|
1141
|
+
try:
|
|
1142
|
+
function_response = parse_json(text_content)
|
|
1143
|
+
except Exception:
|
|
1144
|
+
function_response = {"function_response": text_content}
|
|
1145
|
+
|
|
1146
|
+
google_ai_message = {
|
|
1147
|
+
"role": "function",
|
|
1148
|
+
"parts": [
|
|
1149
|
+
{
|
|
1150
|
+
"functionResponse": {
|
|
1151
|
+
"name": function_name,
|
|
1152
|
+
"response": {
|
|
1153
|
+
"name": function_name, # NOTE: name twice... why?
|
|
1154
|
+
"content": function_response,
|
|
1155
|
+
},
|
|
1156
|
+
}
|
|
1157
|
+
}
|
|
1158
|
+
],
|
|
1159
|
+
}
|
|
1160
|
+
|
|
1161
|
+
else:
|
|
1162
|
+
raise ValueError(self.role)
|
|
1163
|
+
|
|
1164
|
+
# Validate that parts is never empty before returning
|
|
1165
|
+
if "parts" not in google_ai_message or not google_ai_message["parts"]:
|
|
1166
|
+
# If parts is empty, add a default text part
|
|
1167
|
+
google_ai_message["parts"] = [{"text": "empty message"}]
|
|
1168
|
+
warnings.warn(
|
|
1169
|
+
f"Empty 'parts' detected in message with role '{self.role}'. Added default empty text part. Full message:\n{vars(self)}"
|
|
1170
|
+
)
|
|
1171
|
+
|
|
1172
|
+
return google_ai_message
|
|
1173
|
+
|
|
1174
|
+
def to_cohere_dict(
|
|
1175
|
+
self,
|
|
1176
|
+
function_call_role: Optional[str] = "SYSTEM",
|
|
1177
|
+
function_call_prefix: Optional[str] = "[CHATBOT called function]",
|
|
1178
|
+
function_response_role: Optional[str] = "SYSTEM",
|
|
1179
|
+
function_response_prefix: Optional[str] = "[CHATBOT function returned]",
|
|
1180
|
+
inner_thoughts_as_kwarg: Optional[bool] = False,
|
|
1181
|
+
) -> List[dict]:
|
|
1182
|
+
"""
|
|
1183
|
+
Cohere chat_history dicts only have 'role' and 'message' fields
|
|
1184
|
+
"""
|
|
1185
|
+
|
|
1186
|
+
# NOTE: returns a list of dicts so that we can convert:
|
|
1187
|
+
# assistant [cot]: "I'll send a message"
|
|
1188
|
+
# assistant [func]: send_message("hi")
|
|
1189
|
+
# tool: {'status': 'OK'}
|
|
1190
|
+
# to:
|
|
1191
|
+
# CHATBOT.text: "I'll send a message"
|
|
1192
|
+
# SYSTEM.text: [CHATBOT called function] send_message("hi")
|
|
1193
|
+
# SYSTEM.text: [CHATBOT function returned] {'status': 'OK'}
|
|
1194
|
+
|
|
1195
|
+
# TODO: update this prompt style once guidance from Cohere on
|
|
1196
|
+
# embedded function calls in multi-turn conversation become more clear
|
|
1197
|
+
if (
|
|
1198
|
+
self.content
|
|
1199
|
+
and len(self.content) == 1
|
|
1200
|
+
and isinstance(self.content[0], TextContent)
|
|
1201
|
+
):
|
|
1202
|
+
text_content = self.content[0].text
|
|
1203
|
+
else:
|
|
1204
|
+
text_content = None
|
|
1205
|
+
if self.role == "system":
|
|
1206
|
+
"""
|
|
1207
|
+
The chat_history parameter should not be used for SYSTEM messages in most cases.
|
|
1208
|
+
Instead, to add a SYSTEM role message at the beginning of a conversation, the preamble parameter should be used.
|
|
1209
|
+
"""
|
|
1210
|
+
raise UserWarning(
|
|
1211
|
+
"role 'system' messages should go in 'preamble' field for Cohere API"
|
|
1212
|
+
)
|
|
1213
|
+
|
|
1214
|
+
elif self.role == "user":
|
|
1215
|
+
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
|
1216
|
+
cohere_message = [
|
|
1217
|
+
{
|
|
1218
|
+
"role": "USER",
|
|
1219
|
+
"message": text_content,
|
|
1220
|
+
}
|
|
1221
|
+
]
|
|
1222
|
+
|
|
1223
|
+
elif self.role == "assistant":
|
|
1224
|
+
# NOTE: we may break this into two message - an inner thought and a function call
|
|
1225
|
+
# Optionally, we could just make this a function call with the inner thought inside
|
|
1226
|
+
assert self.tool_calls is not None or text_content is not None
|
|
1227
|
+
|
|
1228
|
+
if text_content and self.tool_calls:
|
|
1229
|
+
if inner_thoughts_as_kwarg:
|
|
1230
|
+
raise NotImplementedError
|
|
1231
|
+
cohere_message = [
|
|
1232
|
+
{
|
|
1233
|
+
"role": "CHATBOT",
|
|
1234
|
+
"message": text_content,
|
|
1235
|
+
},
|
|
1236
|
+
]
|
|
1237
|
+
for tc in self.tool_calls:
|
|
1238
|
+
function_name = tc.function["name"]
|
|
1239
|
+
function_args = parse_json(tc.function["arguments"])
|
|
1240
|
+
function_args_str = ",".join(
|
|
1241
|
+
[f"{k}={v}" for k, v in function_args.items()]
|
|
1242
|
+
)
|
|
1243
|
+
function_call_text = f"{function_name}({function_args_str})"
|
|
1244
|
+
cohere_message.append(
|
|
1245
|
+
{
|
|
1246
|
+
"role": function_call_role,
|
|
1247
|
+
"message": f"{function_call_prefix} {function_call_text}",
|
|
1248
|
+
}
|
|
1249
|
+
)
|
|
1250
|
+
elif not text_content and self.tool_calls:
|
|
1251
|
+
cohere_message = []
|
|
1252
|
+
for tc in self.tool_calls:
|
|
1253
|
+
# TODO better way to pack?
|
|
1254
|
+
function_call_text = json_dumps(tc.to_dict())
|
|
1255
|
+
cohere_message.append(
|
|
1256
|
+
{
|
|
1257
|
+
"role": function_call_role,
|
|
1258
|
+
"message": f"{function_call_prefix} {function_call_text}",
|
|
1259
|
+
}
|
|
1260
|
+
)
|
|
1261
|
+
elif text_content and not self.tool_calls:
|
|
1262
|
+
cohere_message = [
|
|
1263
|
+
{
|
|
1264
|
+
"role": "CHATBOT",
|
|
1265
|
+
"message": text_content,
|
|
1266
|
+
}
|
|
1267
|
+
]
|
|
1268
|
+
else:
|
|
1269
|
+
raise ValueError("Message does not have content nor tool_calls")
|
|
1270
|
+
|
|
1271
|
+
elif self.role == "tool":
|
|
1272
|
+
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(
|
|
1273
|
+
self
|
|
1274
|
+
)
|
|
1275
|
+
function_response_text = text_content
|
|
1276
|
+
cohere_message = [
|
|
1277
|
+
{
|
|
1278
|
+
"role": function_response_role,
|
|
1279
|
+
"message": f"{function_response_prefix} {function_response_text}",
|
|
1280
|
+
}
|
|
1281
|
+
]
|
|
1282
|
+
|
|
1283
|
+
else:
|
|
1284
|
+
raise ValueError(self.role)
|
|
1285
|
+
|
|
1286
|
+
return cohere_message
|
|
1287
|
+
|
|
1288
|
+
@staticmethod
|
|
1289
|
+
def generate_otid_from_id(message_id: str, index: int) -> str:
|
|
1290
|
+
"""
|
|
1291
|
+
Convert message id to bits and change the list bit to the index
|
|
1292
|
+
"""
|
|
1293
|
+
if not 0 <= index < 128:
|
|
1294
|
+
raise ValueError("Index must be between 0 and 127")
|
|
1295
|
+
|
|
1296
|
+
message_uuid = message_id.replace("message-", "")
|
|
1297
|
+
uuid_int = int(message_uuid.replace("-", ""), 16)
|
|
1298
|
+
|
|
1299
|
+
# Clear last 7 bits and set them to index; supports up to 128 unique indices
|
|
1300
|
+
uuid_int = (uuid_int & ~0x7F) | (index & 0x7F)
|
|
1301
|
+
|
|
1302
|
+
hex_str = f"{uuid_int:032x}"
|
|
1303
|
+
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:]}"
|
|
1304
|
+
|
|
1305
|
+
|
|
1306
|
+
class ToolReturn(BaseModel):
|
|
1307
|
+
status: Literal["success", "error"] = Field(
|
|
1308
|
+
..., description="The status of the tool call"
|
|
1309
|
+
)
|
|
1310
|
+
stdout: Optional[List[str]] = Field(
|
|
1311
|
+
None, description="Captured stdout (e.g. prints, logs) from the tool invocation"
|
|
1312
|
+
)
|
|
1313
|
+
stderr: Optional[List[str]] = Field(
|
|
1314
|
+
None, description="Captured stderr from the tool invocation"
|
|
1315
|
+
)
|