mistralai 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +2 -2
- mistralai/beta.py +20 -0
- mistralai/conversations.py +2657 -0
- mistralai/extra/__init__.py +10 -2
- mistralai/extra/exceptions.py +14 -0
- mistralai/extra/mcp/__init__.py +0 -0
- mistralai/extra/mcp/auth.py +166 -0
- mistralai/extra/mcp/base.py +155 -0
- mistralai/extra/mcp/sse.py +165 -0
- mistralai/extra/mcp/stdio.py +22 -0
- mistralai/extra/run/__init__.py +0 -0
- mistralai/extra/run/context.py +295 -0
- mistralai/extra/run/result.py +212 -0
- mistralai/extra/run/tools.py +225 -0
- mistralai/extra/run/utils.py +36 -0
- mistralai/extra/tests/test_struct_chat.py +1 -1
- mistralai/mistral_agents.py +1158 -0
- mistralai/models/__init__.py +470 -1
- mistralai/models/agent.py +129 -0
- mistralai/models/agentconversation.py +71 -0
- mistralai/models/agentcreationrequest.py +109 -0
- mistralai/models/agenthandoffdoneevent.py +33 -0
- mistralai/models/agenthandoffentry.py +75 -0
- mistralai/models/agenthandoffstartedevent.py +33 -0
- mistralai/models/agents_api_v1_agents_getop.py +16 -0
- mistralai/models/agents_api_v1_agents_listop.py +24 -0
- mistralai/models/agents_api_v1_agents_update_versionop.py +21 -0
- mistralai/models/agents_api_v1_agents_updateop.py +23 -0
- mistralai/models/agents_api_v1_conversations_append_streamop.py +28 -0
- mistralai/models/agents_api_v1_conversations_appendop.py +28 -0
- mistralai/models/agents_api_v1_conversations_getop.py +33 -0
- mistralai/models/agents_api_v1_conversations_historyop.py +16 -0
- mistralai/models/agents_api_v1_conversations_listop.py +37 -0
- mistralai/models/agents_api_v1_conversations_messagesop.py +16 -0
- mistralai/models/agents_api_v1_conversations_restart_streamop.py +26 -0
- mistralai/models/agents_api_v1_conversations_restartop.py +26 -0
- mistralai/models/agentupdaterequest.py +111 -0
- mistralai/models/builtinconnectors.py +13 -0
- mistralai/models/codeinterpretertool.py +17 -0
- mistralai/models/completionargs.py +100 -0
- mistralai/models/completionargsstop.py +13 -0
- mistralai/models/completionjobout.py +3 -3
- mistralai/models/conversationappendrequest.py +35 -0
- mistralai/models/conversationappendstreamrequest.py +37 -0
- mistralai/models/conversationevents.py +72 -0
- mistralai/models/conversationhistory.py +58 -0
- mistralai/models/conversationinputs.py +14 -0
- mistralai/models/conversationmessages.py +28 -0
- mistralai/models/conversationrequest.py +133 -0
- mistralai/models/conversationresponse.py +51 -0
- mistralai/models/conversationrestartrequest.py +42 -0
- mistralai/models/conversationrestartstreamrequest.py +44 -0
- mistralai/models/conversationstreamrequest.py +135 -0
- mistralai/models/conversationusageinfo.py +63 -0
- mistralai/models/documentlibrarytool.py +22 -0
- mistralai/models/functioncallentry.py +76 -0
- mistralai/models/functioncallentryarguments.py +15 -0
- mistralai/models/functioncallevent.py +36 -0
- mistralai/models/functionresultentry.py +69 -0
- mistralai/models/functiontool.py +21 -0
- mistralai/models/imagegenerationtool.py +17 -0
- mistralai/models/inputentries.py +18 -0
- mistralai/models/messageentries.py +18 -0
- mistralai/models/messageinputcontentchunks.py +26 -0
- mistralai/models/messageinputentry.py +89 -0
- mistralai/models/messageoutputcontentchunks.py +30 -0
- mistralai/models/messageoutputentry.py +100 -0
- mistralai/models/messageoutputevent.py +93 -0
- mistralai/models/modelconversation.py +127 -0
- mistralai/models/outputcontentchunks.py +30 -0
- mistralai/models/responsedoneevent.py +25 -0
- mistralai/models/responseerrorevent.py +27 -0
- mistralai/models/responsestartedevent.py +24 -0
- mistralai/models/ssetypes.py +18 -0
- mistralai/models/toolexecutiondoneevent.py +34 -0
- mistralai/models/toolexecutionentry.py +70 -0
- mistralai/models/toolexecutionstartedevent.py +31 -0
- mistralai/models/toolfilechunk.py +61 -0
- mistralai/models/toolreferencechunk.py +61 -0
- mistralai/models/websearchpremiumtool.py +17 -0
- mistralai/models/websearchtool.py +17 -0
- mistralai/sdk.py +3 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/METADATA +42 -7
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/RECORD +86 -10
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/LICENSE +0 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import typing
|
|
4
|
+
from contextlib import AsyncExitStack
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Union, Optional
|
|
10
|
+
|
|
11
|
+
import pydantic
|
|
12
|
+
|
|
13
|
+
from mistralai.extra import (
|
|
14
|
+
response_format_from_pydantic_model,
|
|
15
|
+
)
|
|
16
|
+
from mistralai.extra.exceptions import RunException
|
|
17
|
+
from mistralai.extra.mcp.base import MCPClientProtocol
|
|
18
|
+
from mistralai.extra.run.result import RunResult
|
|
19
|
+
from mistralai.types.basemodel import OptionalNullable, BaseModel, UNSET
|
|
20
|
+
from mistralai.models import (
|
|
21
|
+
ResponseFormat,
|
|
22
|
+
FunctionCallEntry,
|
|
23
|
+
Tools,
|
|
24
|
+
ToolsTypedDict,
|
|
25
|
+
CompletionArgs,
|
|
26
|
+
CompletionArgsTypedDict,
|
|
27
|
+
FunctionResultEntry,
|
|
28
|
+
ConversationInputs,
|
|
29
|
+
ConversationInputsTypedDict,
|
|
30
|
+
FunctionTool,
|
|
31
|
+
MessageInputEntry,
|
|
32
|
+
InputEntries,
|
|
33
|
+
ResponseFormatTypedDict,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
from logging import getLogger
|
|
37
|
+
|
|
38
|
+
from mistralai.extra.run.tools import (
|
|
39
|
+
create_function_result,
|
|
40
|
+
RunFunction,
|
|
41
|
+
create_tool_call,
|
|
42
|
+
RunTool,
|
|
43
|
+
RunMCPTool,
|
|
44
|
+
RunCoroutine,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
if typing.TYPE_CHECKING:
|
|
48
|
+
from mistralai import Beta, OptionalNullable
|
|
49
|
+
|
|
50
|
+
logger = getLogger(__name__)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AgentRequestKwargs(typing.TypedDict):
    """Keyword arguments for a conversation request that targets an existing agent."""

    # Identifier of the agent the conversation request is routed to.
    agent_id: str
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ModelRequestKwargs(typing.TypedDict):
    """Keyword arguments for a conversation request that targets a raw model."""

    # Name of the model to run the conversation with.
    model: str
    # Optional system instructions passed alongside the request.
    instructions: OptionalNullable[str]
    # Tools made available to the model, as model objects or typed dicts.
    tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]]
    # Extra completion arguments (e.g. the response_format).
    completion_args: OptionalNullable[Union[CompletionArgs, CompletionArgsTypedDict]]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class RunContext:
    """A context for running a conversation with an agent or a model.

    The context can be used to execute function calls, connect to MCP servers,
    and keep track of information about the run.

    Args:
        conversation_id (Optional[str]): The unique identifier for the conversation. This is
            passed if the user wants to continue an existing conversation.
        model (Optional[str]): The model name to be used for the conversation. Can't be used
            along with 'agent_id'.
        agent_id (Optional[str]): The agent id to be used for the conversation. Can't be used
            along with 'model'.
        output_format (Optional[type[BaseModel]]): The output format expected from the
            conversation. It represents the `response_format` which is part of the
            `CompletionArgs`.
        request_count (int): The number of requests made in the current `RunContext`.
        continue_on_fn_error (bool): Flag to determine if the conversation should continue
            when function execution resulted in an error.
    """

    # Owns the lifetime of resources opened by registered MCP clients.
    _exit_stack: AsyncExitStack = field(init=False)
    # Locally executable tools, keyed by function name.
    _callable_tools: dict[str, RunTool] = field(init=False, default_factory=dict)
    # MCP clients registered on this context; closed on context exit.
    _mcp_clients: list[MCPClientProtocol] = field(init=False, default_factory=list)

    conversation_id: Optional[str] = field(default=None)
    model: Optional[str] = field(default=None)
    agent_id: Optional[str] = field(default=None)
    output_format: Optional[type[BaseModel]] = field(default=None)
    request_count: int = field(default=0)
    continue_on_fn_error: bool = field(default=False)

    def __post_init__(self):
        # 'model' and 'agent_id' are mutually exclusive run targets.
        if self.model and self.agent_id:
            raise RunException("Only one for model or agent_id should be set")
        self._exit_stack = AsyncExitStack()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release exit-stack-managed resources first, then close each client.
        await self._exit_stack.aclose()
        for mcp_client in self._mcp_clients:
            await mcp_client.aclose()

    def register_func(self, func: Callable):
        """Add a function to the context.

        The function is exposed as a tool under its own name: coroutine
        functions become awaitable tools, plain functions synchronous tools.

        Returns:
            A wrapper around ``func`` that logs each execution.

        Raises:
            RunException: If ``func`` is not a plain function object.
        """
        if not inspect.isfunction(func):
            raise RunException(
                "Only object of type function can be registered at the moment."
            )

        if inspect.iscoroutinefunction(func):
            self._callable_tools[func.__name__] = RunCoroutine(
                name=func.__name__,
                awaitable=func,
                tool=create_tool_call(func),
            )
        else:
            self._callable_tools[func.__name__] = RunFunction(
                name=func.__name__,
                callable=func,
                tool=create_tool_call(func),
            )

        @wraps(func)
        def wrapper(*args, **kwargs):
            logger.info(f"Executing {func.__name__}")
            return func(*args, **kwargs)

        return wrapper

    async def register_mcp_clients(self, mcp_clients: list[MCPClientProtocol]) -> None:
        """Registering multiple MCP clients at the same time in the same asyncio.Task."""
        for mcp_client in mcp_clients:
            await self.register_mcp_client(mcp_client)

    async def register_mcp_client(self, mcp_client: MCPClientProtocol) -> None:
        """Add a MCP client to the context and expose its tools."""
        await mcp_client.initialize(exit_stack=self._exit_stack)
        tools = await mcp_client.get_tools()
        for tool in tools:
            logger.info(
                f"Adding tool {tool.function.name} from {mcp_client._name or 'mcp client'}"
            )
            self._callable_tools[tool.function.name] = RunMCPTool(
                name=tool.function.name,
                tool=tool,
                mcp_client=mcp_client,
            )
        self._mcp_clients.append(mcp_client)

    async def execute_function_calls(
        self, function_calls: list[FunctionCallEntry]
    ) -> list[FunctionResultEntry]:
        """Execute function calls and create function results from them.

        Returns an empty list (stopping the run) if any requested function is
        not registered on this context.
        """
        if not all(
            function_call.name in self._callable_tools
            for function_call in function_calls
        ):
            logger.warning("Can't execute all functions, stopping run here")
            return []
        # Run all calls concurrently; gather preserves input order.
        results = await asyncio.gather(
            *(
                create_function_result(
                    function_call=function_call,
                    run_tool=self._callable_tools[function_call.name],
                    continue_on_fn_error=self.continue_on_fn_error,
                )
                for function_call in function_calls
            )
        )
        return list(results)

    def get_tools(self) -> list[FunctionTool]:
        """Get the tools that are part of the context."""
        return [run_tool.tool for run_tool in self._callable_tools.values()]

    async def prepare_agent_request(self, beta_client: "Beta") -> AgentRequestKwargs:
        """Prepare an agent request with the functions added to the context.

        Update the agent definition before making the request.

        Raises:
            RunException: If no ``agent_id`` is set on the context.
        """
        if self.agent_id is None:
            raise RunException(
                "Can't prepare an agent request, if no agent_id is provided"
            )
        agent = await beta_client.agents.get_async(agent_id=self.agent_id)
        updated_tools = []
        for tool in agent.tools or []:
            # Functions that are re-registered locally are skipped here so they
            # are not duplicated when the context's tools are appended below.
            if tool.type == "function" and tool.function.name in self._callable_tools:
                continue
            updated_tools.append(tool)
        updated_tools += self.get_tools()
        completion_args = (
            CompletionArgs(response_format=self.response_format)
            if self.output_format
            else None
        )
        # NOTE(review): this synchronous call runs inside an async method and
        # will block the event loop — confirm whether `update_async` should be
        # used here to match `get_async` above.
        beta_client.agents.update(
            agent_id=self.agent_id, tools=updated_tools, completion_args=completion_args
        )
        return AgentRequestKwargs(agent_id=self.agent_id)

    async def prepare_model_request(
        self,
        tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
        completion_args: OptionalNullable[
            Union[CompletionArgs, CompletionArgsTypedDict]
        ] = UNSET,
        instructions: OptionalNullable[str] = None,
    ) -> ModelRequestKwargs:
        """Prepare a raw-model request combining caller tools with context tools.

        Raises:
            RunException: If no ``model`` is set on the context.
        """
        if self.model is None:
            raise RunException("Can't prepare a model request, if no model is provided")
        # Inject the context's response format into the completion args,
        # whichever representation the caller supplied.
        if not completion_args and self.output_format:
            completion_args = CompletionArgs(response_format=self.response_format)
        elif isinstance(completion_args, CompletionArgs) and self.output_format:
            completion_args.response_format = self.response_format
        elif isinstance(completion_args, dict) and self.output_format:
            completion_args["response_format"] = typing.cast(
                ResponseFormatTypedDict, self.response_format.model_dump()
            )
        request_tools = []
        if isinstance(tools, list):
            for tool in tools:
                request_tools.append(typing.cast(Tools, tool))
        for tool in self.get_tools():
            request_tools.append(tool)
        return ModelRequestKwargs(
            model=self.model,
            tools=request_tools,
            instructions=instructions,
            completion_args=completion_args,
        )

    @property
    def response_format(self) -> ResponseFormat:
        """The ``ResponseFormat`` derived from ``output_format``.

        Raises:
            RunException: If no output format is set on the context.
        """
        if not self.output_format:
            raise RunException("No response format exist for the current RunContext.")
        return response_format_from_pydantic_model(self.output_format)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
async def _validate_run(
    *,
    beta_client: "Beta",
    run_ctx: RunContext,
    inputs: Union[ConversationInputs, ConversationInputsTypedDict],
    instructions: OptionalNullable[str] = UNSET,
    tools: OptionalNullable[Union[list[Tools], list[ToolsTypedDict]]] = UNSET,
    completion_args: OptionalNullable[
        Union[CompletionArgs, CompletionArgsTypedDict]
    ] = UNSET,
) -> tuple[
    Union[AgentRequestKwargs, ModelRequestKwargs], RunResult, list[InputEntries]
]:
    """Normalise run inputs and build the request kwargs for an agent or model run."""
    entries: list[InputEntries] = []
    if isinstance(inputs, str):
        # A bare string becomes a single user message.
        entries.append(MessageInputEntry(role="user", content=inputs))
    else:
        # NOTE(review): only dict items are validated and kept; items that are
        # already model instances are silently skipped — confirm intended.
        adapter = pydantic.TypeAdapter(InputEntries)
        entries.extend(
            adapter.validate_python(item) for item in inputs if isinstance(item, dict)
        )

    run_result = RunResult(
        input_entries=entries,
        output_model=run_ctx.output_format,
        conversation_id=run_ctx.conversation_id,
    )

    request: Union[AgentRequestKwargs, ModelRequestKwargs]
    if run_ctx.agent_id:
        if tools or completion_args:
            raise RunException("Can't set tools or completion_args when using an agent")
        request = await run_ctx.prepare_agent_request(beta_client=beta_client)
    elif run_ctx.model:
        request = await run_ctx.prepare_model_request(
            instructions=instructions,
            tools=tools,
            completion_args=completion_args,
        )
    else:
        raise RunException("Either agent_id or model must be set in the run context")
    return request, run_result, entries
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import json
|
|
3
|
+
import typing
|
|
4
|
+
from typing import Union, Annotated, Optional, Literal
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pydantic import BaseModel, Discriminator, Field, Tag
|
|
7
|
+
|
|
8
|
+
from mistralai.extra.utils.response_format import pydantic_model_from_json
|
|
9
|
+
from mistralai.models import (
|
|
10
|
+
FunctionResultEntry,
|
|
11
|
+
FunctionCallEntry,
|
|
12
|
+
MessageOutputEntry,
|
|
13
|
+
AgentHandoffEntry,
|
|
14
|
+
ToolExecutionEntry,
|
|
15
|
+
MessageInputEntry,
|
|
16
|
+
AgentHandoffDoneEvent,
|
|
17
|
+
AgentHandoffStartedEvent,
|
|
18
|
+
ResponseDoneEvent,
|
|
19
|
+
ResponseErrorEvent,
|
|
20
|
+
ResponseStartedEvent,
|
|
21
|
+
FunctionCallEvent,
|
|
22
|
+
MessageOutputEvent,
|
|
23
|
+
ToolExecutionDoneEvent,
|
|
24
|
+
ToolExecutionStartedEvent,
|
|
25
|
+
ConversationEventsData,
|
|
26
|
+
MessageOutputEventContent,
|
|
27
|
+
MessageOutputEntryContent,
|
|
28
|
+
TextChunk,
|
|
29
|
+
MessageOutputContentChunks,
|
|
30
|
+
SSETypes,
|
|
31
|
+
InputEntries,
|
|
32
|
+
ToolFileChunk,
|
|
33
|
+
ToolReferenceChunk,
|
|
34
|
+
FunctionCallEntryArguments,
|
|
35
|
+
)
|
|
36
|
+
from mistralai.utils import get_discriminator
|
|
37
|
+
|
|
38
|
+
RunOutputEntries = typing.Union[
|
|
39
|
+
MessageOutputEntry,
|
|
40
|
+
FunctionCallEntry,
|
|
41
|
+
FunctionResultEntry,
|
|
42
|
+
AgentHandoffEntry,
|
|
43
|
+
ToolExecutionEntry,
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
RunEntries = typing.Union[RunOutputEntries, MessageInputEntry]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def as_text(entry: RunOutputEntries) -> str:
    """Keep only the messages and turn content into textual representation."""
    # Non-message entries contribute no text.
    if not isinstance(entry, MessageOutputEntry):
        return ""
    if isinstance(entry.content, str):
        return entry.content
    parts: list = []
    for chunk in entry.content:
        if isinstance(chunk, TextChunk):
            parts.append(chunk.text)
        elif isinstance(chunk, ToolFileChunk):
            # Files are rendered as a compact placeholder tag.
            parts.append(f"<File id={chunk.file_id} name={chunk.file_name}>")
        elif isinstance(chunk, ToolReferenceChunk):
            parts.append(f"<Reference title={chunk.title}>")
    return "".join(parts)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def reconstitute_message_content(
    chunks: list[MessageOutputEventContent],
) -> MessageOutputEntryContent:
    """Given a list of MessageOutputEventContent, recreate a normalised MessageOutputEntryContent."""
    # Fast path: a purely textual stream collapses to a single string.
    if all(isinstance(chunk, str) for chunk in chunks):
        return "".join(typing.cast(list[str], chunks))

    merged: list[MessageOutputContentChunks] = []
    for raw in chunks:
        piece = TextChunk(text=raw) if isinstance(raw, str) else raw
        # Coalesce adjacent text chunks so the normalised content has no
        # consecutive TextChunk runs.
        if (
            isinstance(piece, TextChunk)
            and merged
            and isinstance(merged[-1], TextChunk)
        ):
            merged[-1].text += piece.text
        else:
            merged.append(piece)
    return merged
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def reconstitute_function_call_args(chunks: list[str]) -> FunctionCallEntryArguments:
    """Recreates function call arguments from stream"""
    # Streamed argument fragments concatenate into the full JSON string; the
    # cast only narrows the static type, it has no runtime effect.
    joined = "".join(chunks)
    return typing.cast("FunctionCallEntryArguments", joined)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def reconstitue_entries(
    received_event_tracker: dict[int, list[ConversationEventsData]],
) -> list[RunOutputEntries]:
    """Given a list of events, recreate the corresponding entries."""
    entries: list[RunOutputEntries] = []
    # Replay the event groups in output-index order; the first event of each
    # group carries the shared metadata for the reconstructed entry.
    for _, events in sorted(received_event_tracker.items()):
        head = events[0]
        if isinstance(head, MessageOutputEvent):
            message_events = typing.cast(list[MessageOutputEvent], events)
            content = reconstitute_message_content(
                chunks=[event.content for event in message_events]
            )
            entries.append(
                MessageOutputEntry(
                    content=content,
                    created_at=head.created_at,
                    id=head.id,
                    agent_id=head.agent_id,
                    model=head.model,
                    role=head.role,
                )
            )
        elif isinstance(head, FunctionCallEvent):
            call_events = typing.cast(list[FunctionCallEvent], events)
            arguments = reconstitute_function_call_args(
                chunks=[event.arguments for event in call_events]
            )
            entries.append(
                FunctionCallEntry(
                    name=head.name,
                    arguments=arguments,
                    created_at=head.created_at,
                    id=head.id,
                    tool_call_id=head.tool_call_id,
                )
            )
        # Other event kinds carry no entry payload and are skipped.
    return entries
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass
class RunFiles:
    """A file produced during a run, held in memory."""

    # Server-side identifier of the file.
    id: str
    # Original file name.
    name: str
    # Raw file bytes.
    content: bytes
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class RunResult:
    """Aggregated outcome of a run: its inputs, produced entries, and files."""

    # Entries that were fed into the run.
    input_entries: list[InputEntries]
    # Conversation this run belongs to, if any.
    conversation_id: Optional[str] = field(default=None)
    # Entries produced by the run, in order.
    output_entries: list[RunOutputEntries] = field(default_factory=list)
    # Files produced during the run, keyed by file id.
    files: dict[str, RunFiles] = field(default_factory=dict)
    # Expected pydantic model for structured output, if one was requested.
    output_model: Optional[type[BaseModel]] = field(default=None)

    def get_file(self, file_id: str) -> Optional[RunFiles]:
        """Return a produced file by id, or None if unknown."""
        return self.files.get(file_id)

    @property
    def entries(self) -> list[RunEntries]:
        """All entries of the run: inputs followed by outputs."""
        return [*self.input_entries, *self.output_entries]

    @property
    def output_as_text(self) -> str:
        """Concatenate the textual representation of all output messages.

        Raises:
            ValueError: If the run produced no output entries.
        """
        if not self.output_entries:
            raise ValueError("No output entries were started.")
        return "\n".join(
            as_text(entry)
            for entry in self.output_entries
            if entry.type == "message.output"
        )

    @property
    def output_as_model(self) -> BaseModel:
        """Parse the textual output into the requested output model.

        Raises:
            ValueError: If no output format was set on the run.
        """
        if self.output_model is None:
            # Bug fix: the original message read "No output format was not
            # set." — a double negative.
            raise ValueError("No output format was set.")
        return pydantic_model_from_json(
            json.loads(self.output_as_text), self.output_model
        )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class FunctionResultEvent(BaseModel):
    """Synthetic SSE-style event emitted when a local function call completes."""

    id: Optional[str] = None

    type: Optional[Literal["function.result"]] = "function.result"

    # Serialized result returned by the executed function.
    result: str

    # Identifier of the tool call this result answers.
    tool_call_id: str

    # Bug fix: the original default `datetime.datetime.now(...)` was evaluated
    # once at import time, so every event instance shared the same timestamp.
    # A default_factory evaluates per instance instead.
    created_at: Optional[datetime.datetime] = Field(
        default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc)
    )

    output_index: Optional[int] = 0
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# Event type discriminator for run results: either a server-sent event type or
# the locally generated "function.result".
RunResultEventsType = typing.Union[SSETypes, Literal["function.result"]]

# Discriminated union of every event payload a run can yield, keyed by the
# payload's "type" field via the Tag annotations below.
RunResultEventsData = typing.Annotated[
    Union[
        Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")],
        Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")],
        Annotated[ResponseDoneEvent, Tag("conversation.response.done")],
        Annotated[ResponseErrorEvent, Tag("conversation.response.error")],
        Annotated[ResponseStartedEvent, Tag("conversation.response.started")],
        Annotated[FunctionCallEvent, Tag("function.call.delta")],
        Annotated[MessageOutputEvent, Tag("message.output.delta")],
        Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")],
        Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")],
        Annotated[FunctionResultEvent, Tag("function.result")],
    ],
    Discriminator(lambda m: get_discriminator(m, "type", "type")),
]


class RunResultEvents(BaseModel):
    """A single event surfaced by a run, tagged with its type."""

    # Discriminator value identifying the payload kind.
    event: RunResultEventsType

    # The event payload itself, validated against the discriminated union.
    data: RunResultEventsData