agentrun-sdk 0.1.2 (agentrun_sdk-0.1.2-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of agentrun-sdk has been flagged, and further details are available on the registry page.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
agentrun_sdk/models/mistral.py
@@ -0,0 +1,539 @@
+"""Mistral AI model provider.
+
+- Docs: https://docs.mistral.ai/
+"""
+
+import base64
+import json
+import logging
+from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union
+
+import mistralai
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ModelThrottledException
+from ..types.streaming import StopReason, StreamEvent
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class MistralModel(Model):
+    """Mistral API model provider implementation.
+
+    The implementation handles Mistral-specific features such as:
+
+    - Chat and text completions
+    - Streaming responses
+    - Tool/function calling
+    - System prompts
+    """
+
+    class MistralConfig(TypedDict, total=False):
+        """Configuration parameters for Mistral models.
+
+        Attributes:
+            model_id: Mistral model ID (e.g., "mistral-large-latest", "mistral-medium-latest").
+            max_tokens: Maximum number of tokens to generate in the response.
+            temperature: Controls randomness in generation (0.0 to 1.0).
+            top_p: Controls diversity via nucleus sampling.
+            stream: Whether to enable streaming responses.
+        """
+
+        model_id: str
+        max_tokens: Optional[int]
+        temperature: Optional[float]
+        top_p: Optional[float]
+        stream: Optional[bool]
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        *,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[MistralConfig],
+    ) -> None:
+        """Initialize provider instance.
+
+        Args:
+            api_key: Mistral API key. If not provided, will use MISTRAL_API_KEY env var.
+            client_args: Additional arguments for the Mistral client.
+            **model_config: Configuration options for the Mistral model.
+        """
+        if "temperature" in model_config and model_config["temperature"] is not None:
+            temp = model_config["temperature"]
+            if not 0.0 <= temp <= 1.0:
+                raise ValueError(f"temperature must be between 0.0 and 1.0, got {temp}")
+            # Warn if temperature is above recommended range
+            if temp > 0.7:
+                logger.warning(
+                    "temperature=%s is above the recommended range (0.0-0.7). "
+                    "High values may produce unpredictable results.",
+                    temp,
+                )
+
+        if "top_p" in model_config and model_config["top_p"] is not None:
+            top_p = model_config["top_p"]
+            if not 0.0 <= top_p <= 1.0:
+                raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}")
+
+        self.config = MistralModel.MistralConfig(**model_config)
+
+        # Set default stream to True if not specified
+        if "stream" not in self.config:
+            self.config["stream"] = True
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+        self.client_args = client_args or {}
+        if api_key:
+            self.client_args["api_key"] = api_key
+
+    @override
+    def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
+        """Update the Mistral Model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> MistralConfig:
+        """Get the Mistral model configuration.
+
+        Returns:
+            The Mistral model configuration.
+        """
+        return self.config
+
+    def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]:
+        """Format a Mistral content block.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            Mistral formatted content.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a Mistral-compatible format.
+        """
+        if "text" in content:
+            return content["text"]
+
+        if "image" in content:
+            image_data = content["image"]
+
+            if "source" in image_data:
+                image_bytes = image_data["source"]["bytes"]
+                base64_data = base64.b64encode(image_bytes).decode("utf-8")
+                format_value = image_data.get("format", "jpeg")
+                media_type = f"image/{format_value}"
+                return {"type": "image_url", "image_url": f"data:{media_type};base64,{base64_data}"}
+
+            raise TypeError("content_type=<image> | unsupported image format")
+
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+        """Format a Mistral tool call.
+
+        Args:
+            tool_use: Tool use requested by the model.
+
+        Returns:
+            Mistral formatted tool call.
+        """
+        return {
+            "function": {
+                "name": tool_use["name"],
+                "arguments": json.dumps(tool_use["input"]),
+            },
+            "id": tool_use["toolUseId"],
+            "type": "function",
+        }
+
+    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+        """Format a Mistral tool message.
+
+        Args:
+            tool_result: Tool result collected from a tool execution.
+
+        Returns:
+            Mistral formatted tool message.
+        """
+        content_parts: list[str] = []
+        for content in tool_result["content"]:
+            if "json" in content:
+                content_parts.append(json.dumps(content["json"]))
+            elif "text" in content:
+                content_parts.append(content["text"])
+
+        return {
+            "role": "tool",
+            "name": tool_result["toolUseId"].split("_")[0]
+            if "_" in tool_result["toolUseId"]
+            else tool_result["toolUseId"],
+            "content": "\n".join(content_parts),
+            "tool_call_id": tool_result["toolUseId"],
+        }
+
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a Mistral compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Mistral compatible messages array.
+        """
+        formatted_messages: list[dict[str, Any]] = []
+
+        if system_prompt:
+            formatted_messages.append({"role": "system", "content": system_prompt})
+
+        for message in messages:
+            role = message["role"]
+            contents = message["content"]
+
+            text_contents: list[str] = []
+            tool_calls: list[dict[str, Any]] = []
+            tool_messages: list[dict[str, Any]] = []
+
+            for content in contents:
+                if "text" in content:
+                    formatted_content = self._format_request_message_content(content)
+                    if isinstance(formatted_content, str):
+                        text_contents.append(formatted_content)
+                elif "toolUse" in content:
+                    tool_calls.append(self._format_request_message_tool_call(content["toolUse"]))
+                elif "toolResult" in content:
+                    tool_messages.append(self._format_request_tool_message(content["toolResult"]))
+
+            if text_contents or tool_calls:
+                formatted_message: dict[str, Any] = {
+                    "role": role,
+                    "content": " ".join(text_contents) if text_contents else "",
+                }
+
+                if tool_calls:
+                    formatted_message["tool_calls"] = tool_calls
+
+                formatted_messages.append(formatted_message)
+
+            formatted_messages.extend(tool_messages)
+
+        return formatted_messages
+
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format a Mistral chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Mistral chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to a Mistral-compatible
+                format.
+        """
+        request: dict[str, Any] = {
+            "model": self.config["model_id"],
+            "messages": self._format_request_messages(messages, system_prompt),
+        }
+
+        if "max_tokens" in self.config:
+            request["max_tokens"] = self.config["max_tokens"]
+        if "temperature" in self.config:
+            request["temperature"] = self.config["temperature"]
+        if "top_p" in self.config:
+            request["top_p"] = self.config["top_p"]
+        if "stream" in self.config:
+            request["stream"] = self.config["stream"]
+
+        if tool_specs:
+            request["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_spec["name"],
+                        "description": tool_spec["description"],
+                        "parameters": tool_spec["inputSchema"]["json"],
+                    },
+                }
+                for tool_spec in tool_specs
+            ]
+
+        return request
+
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format the Mistral response events into standardized message chunks.
+
+        Args:
+            event: A response event from the Mistral model.
+
+        Returns:
+            The formatted chunk.
+
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+        """
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                tool_call = event["data"]
+                return {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "name": tool_call.function.name,
+                                "toolUseId": tool_call.id,
+                            }
+                        }
+                    }
+                }
+
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"]}}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                reason: StopReason
+                if event["data"] == "tool_calls":
+                    reason = "tool_use"
+                elif event["data"] == "length":
+                    reason = "max_tokens"
+                else:
+                    reason = "end_turn"
+
+                return {"messageStop": {"stopReason": reason}}
+
+            case "metadata":
+                usage = event["data"]
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": usage.prompt_tokens,
+                            "outputTokens": usage.completion_tokens,
+                            "totalTokens": usage.total_tokens,
+                        },
+                        "metrics": {
+                            "latencyMs": event.get("latency_ms", 0),
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+    def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, Any]]:
+        """Handle non-streaming response from Mistral API.
+
+        Args:
+            response: The non-streaming response from Mistral.
+
+        Yields:
+            Formatted events that match the streaming format.
+        """
+        yield {"chunk_type": "message_start"}
+
+        content_started = False
+
+        if response.choices and response.choices[0].message:
+            message = response.choices[0].message
+
+            if hasattr(message, "content") and message.content:
+                if not content_started:
+                    yield {"chunk_type": "content_start", "data_type": "text"}
+                    content_started = True
+
+                yield {"chunk_type": "content_delta", "data_type": "text", "data": message.content}
+
+                yield {"chunk_type": "content_stop"}
+
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
+
+                    if hasattr(tool_call.function, "arguments"):
+                        yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_call.function.arguments}
+
+                    yield {"chunk_type": "content_stop"}
+
+        finish_reason = response.choices[0].finish_reason if response.choices[0].finish_reason else "stop"
+        yield {"chunk_type": "message_stop", "data": finish_reason}
+
+        if hasattr(response, "usage") and response.usage:
+            yield {"chunk_type": "metadata", "data": response.usage}
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the Mistral model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests.
+        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("request=<%s>", request)
+
+        logger.debug("invoking model")
+        try:
+            logger.debug("got response from model")
+            if not self.config.get("stream", True):
+                # Use non-streaming API
+                async with mistralai.Mistral(**self.client_args) as client:
+                    response = await client.chat.complete_async(**request)
+                    for event in self._handle_non_streaming_response(response):
+                        yield self.format_chunk(event)
+
+                return
+
+            # Use the streaming API
+            async with mistralai.Mistral(**self.client_args) as client:
+                stream_response = await client.chat.stream_async(**request)
+
+                yield self.format_chunk({"chunk_type": "message_start"})
+
+                content_started = False
+                tool_calls: dict[str, list[Any]] = {}
+                accumulated_text = ""
+
+                async for chunk in stream_response:
+                    if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
+                        choice = chunk.data.choices[0]
+
+                        if hasattr(choice, "delta"):
+                            delta = choice.delta
+
+                            if hasattr(delta, "content") and delta.content:
+                                if not content_started:
+                                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                                    content_started = True
+
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
+                                )
+                                accumulated_text += delta.content
+
+                            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                                for tool_call in delta.tool_calls:
+                                    tool_id = tool_call.id
+                                    tool_calls.setdefault(tool_id, []).append(tool_call)
+
+                        if hasattr(choice, "finish_reason") and choice.finish_reason:
+                            if content_started:
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+                            for tool_deltas in tool_calls.values():
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                                )
+
+                                for tool_delta in tool_deltas:
+                                    if hasattr(tool_delta.function, "arguments"):
+                                        yield self.format_chunk(
+                                            {
+                                                "chunk_type": "content_delta",
+                                                "data_type": "tool",
+                                                "data": tool_delta.function.arguments,
+                                            }
+                                        )
+
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+                            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+                    if hasattr(chunk, "usage"):
+                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+
+        except Exception as e:
+            if "rate" in str(e).lower() or "429" in str(e):
+                raise ModelThrottledException(str(e)) from e
+            raise
+
+        logger.debug("finished streaming response from model")
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            An instance of the output model with the generated data.
+
+        Raises:
+            ValueError: If the response cannot be parsed into the output model.
+        """
+        tool_spec: ToolSpec = {
+            "name": f"extract_{output_model.__name__.lower()}",
+            "description": f"Extract structured data in the format of {output_model.__name__}",
+            "inputSchema": {"json": output_model.model_json_schema()},
+        }
+
+        formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt)
+
+        formatted_request["tool_choice"] = "any"
+        formatted_request["parallel_tool_calls"] = False
+
+        async with mistralai.Mistral(**self.client_args) as client:
+            response = await client.chat.complete_async(**formatted_request)
+
+        if response.choices and response.choices[0].message.tool_calls:
+            tool_call = response.choices[0].message.tool_calls[0]
+            try:
+                # Handle both string and dict arguments
+                if isinstance(tool_call.function.arguments, str):
+                    arguments = json.loads(tool_call.function.arguments)
+                else:
+                    arguments = tool_call.function.arguments
+                yield {"output": output_model(**arguments)}
+                return
+            except (json.JSONDecodeError, TypeError, ValueError) as e:
+                raise ValueError(f"Failed to parse tool call arguments into model: {e}") from e
+
+        raise ValueError("No tool calls found in response")
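For orientation, here is a minimal usage sketch of the provider above. It assumes the wheel is installed so the module is importable as agentrun_sdk.models.mistral and that MISTRAL_API_KEY is set in the environment; the message shape follows the content-block format consumed by _format_request_messages, and the event keys match those produced by format_chunk.

import asyncio

from agentrun_sdk.models.mistral import MistralModel


async def main() -> None:
    # Configuration keys mirror MistralConfig: model_id, max_tokens, temperature, top_p, stream.
    model = MistralModel(model_id="mistral-large-latest", max_tokens=256, temperature=0.3)

    # Messages use the role/content-block shape that _format_request_messages expects.
    messages = [{"role": "user", "content": [{"text": "Say hello in one sentence."}]}]

    # stream() yields standardized events: messageStart, contentBlockDelta, messageStop, metadata.
    async for event in model.stream(messages):
        if "contentBlockDelta" in event:
            print(event["contentBlockDelta"]["delta"].get("text", ""), end="")


asyncio.run(main())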
agentrun_sdk/models/model.py
@@ -0,0 +1,95 @@
+"""Abstract base class for Agent model providers."""
+
+import abc
+import logging
+from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
+
+from pydantic import BaseModel
+
+from ..types.content import Messages
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class Model(abc.ABC):
+    """Abstract base class for Agent model providers.
+
+    This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
+    standardized way to configure and process requests for different AI model providers.
+    """
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def update_config(self, **model_config: Any) -> None:
+        """Update the model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def get_config(self) -> Any:
+        """Return the model configuration.
+
+        Returns:
+            The model's configuration.
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Model events with the last being the structured output.
+
+        Raises:
+            ValidationException: The response format from the model does not match the output_model
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncIterable[StreamEvent]:
+        """Stream conversation with the model.
+
+        This method handles the full lifecycle of conversing with the model:
+
+        1. Format the messages, tool specs, and configuration into a streaming request
+        2. Send the request to the model
+        3. Yield the formatted message chunks
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests from the client.
+        """
+        pass
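To make the contract above concrete, here is a minimal sketch of a custom provider built on this abstract base class. EchoModel is illustrative only and is not part of the package; the import paths assume the installed agentrun_sdk layout shown in the file list, and the event shapes mirror those used by the Mistral provider above.

from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union

from pydantic import BaseModel

from agentrun_sdk.models.model import Model
from agentrun_sdk.types.content import Messages
from agentrun_sdk.types.streaming import StreamEvent
from agentrun_sdk.types.tools import ToolSpec

T = TypeVar("T", bound=BaseModel)


class EchoModel(Model):
    """Toy provider that echoes the last user text back as one assistant message."""

    def __init__(self, **config: Any) -> None:
        self.config = config

    def update_config(self, **model_config: Any) -> None:
        self.config.update(model_config)

    def get_config(self) -> Any:
        return self.config

    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        # Emit the standard event sequence: start, one text block, stop.
        last_text = messages[-1]["content"][0].get("text", "") if messages else ""
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}
        yield {"contentBlockDelta": {"delta": {"text": last_text}}}
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}

    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        # A real provider would request JSON from the model; this toy constructs an unvalidated instance.
        yield {"output": output_model.model_construct()}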