agentrun-sdk 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of agentrun-sdk might be problematic.

Files changed (115)
  1. agentrun_operation_sdk/cli/__init__.py +1 -0
  2. agentrun_operation_sdk/cli/cli.py +19 -0
  3. agentrun_operation_sdk/cli/common.py +21 -0
  4. agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
  5. agentrun_operation_sdk/cli/runtime/commands.py +203 -0
  6. agentrun_operation_sdk/client/client.py +75 -0
  7. agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
  8. agentrun_operation_sdk/operations/runtime/configure.py +101 -0
  9. agentrun_operation_sdk/operations/runtime/launch.py +82 -0
  10. agentrun_operation_sdk/operations/runtime/models.py +31 -0
  11. agentrun_operation_sdk/services/runtime.py +152 -0
  12. agentrun_operation_sdk/utils/logging_config.py +72 -0
  13. agentrun_operation_sdk/utils/runtime/config.py +94 -0
  14. agentrun_operation_sdk/utils/runtime/container.py +280 -0
  15. agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
  16. agentrun_operation_sdk/utils/runtime/schema.py +56 -0
  17. agentrun_sdk/__init__.py +7 -0
  18. agentrun_sdk/agent/__init__.py +25 -0
  19. agentrun_sdk/agent/agent.py +696 -0
  20. agentrun_sdk/agent/agent_result.py +46 -0
  21. agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
  22. agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
  23. agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
  24. agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
  25. agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
  26. agentrun_sdk/agent/state.py +97 -0
  27. agentrun_sdk/event_loop/__init__.py +9 -0
  28. agentrun_sdk/event_loop/event_loop.py +499 -0
  29. agentrun_sdk/event_loop/streaming.py +319 -0
  30. agentrun_sdk/experimental/__init__.py +4 -0
  31. agentrun_sdk/experimental/hooks/__init__.py +15 -0
  32. agentrun_sdk/experimental/hooks/events.py +123 -0
  33. agentrun_sdk/handlers/__init__.py +10 -0
  34. agentrun_sdk/handlers/callback_handler.py +70 -0
  35. agentrun_sdk/hooks/__init__.py +49 -0
  36. agentrun_sdk/hooks/events.py +80 -0
  37. agentrun_sdk/hooks/registry.py +247 -0
  38. agentrun_sdk/models/__init__.py +10 -0
  39. agentrun_sdk/models/anthropic.py +432 -0
  40. agentrun_sdk/models/bedrock.py +649 -0
  41. agentrun_sdk/models/litellm.py +225 -0
  42. agentrun_sdk/models/llamaapi.py +438 -0
  43. agentrun_sdk/models/mistral.py +539 -0
  44. agentrun_sdk/models/model.py +95 -0
  45. agentrun_sdk/models/ollama.py +357 -0
  46. agentrun_sdk/models/openai.py +436 -0
  47. agentrun_sdk/models/sagemaker.py +598 -0
  48. agentrun_sdk/models/writer.py +449 -0
  49. agentrun_sdk/multiagent/__init__.py +22 -0
  50. agentrun_sdk/multiagent/a2a/__init__.py +15 -0
  51. agentrun_sdk/multiagent/a2a/executor.py +148 -0
  52. agentrun_sdk/multiagent/a2a/server.py +252 -0
  53. agentrun_sdk/multiagent/base.py +92 -0
  54. agentrun_sdk/multiagent/graph.py +555 -0
  55. agentrun_sdk/multiagent/swarm.py +656 -0
  56. agentrun_sdk/py.typed +1 -0
  57. agentrun_sdk/session/__init__.py +18 -0
  58. agentrun_sdk/session/file_session_manager.py +216 -0
  59. agentrun_sdk/session/repository_session_manager.py +152 -0
  60. agentrun_sdk/session/s3_session_manager.py +272 -0
  61. agentrun_sdk/session/session_manager.py +73 -0
  62. agentrun_sdk/session/session_repository.py +51 -0
  63. agentrun_sdk/telemetry/__init__.py +21 -0
  64. agentrun_sdk/telemetry/config.py +194 -0
  65. agentrun_sdk/telemetry/metrics.py +476 -0
  66. agentrun_sdk/telemetry/metrics_constants.py +15 -0
  67. agentrun_sdk/telemetry/tracer.py +563 -0
  68. agentrun_sdk/tools/__init__.py +17 -0
  69. agentrun_sdk/tools/decorator.py +569 -0
  70. agentrun_sdk/tools/executor.py +137 -0
  71. agentrun_sdk/tools/loader.py +152 -0
  72. agentrun_sdk/tools/mcp/__init__.py +13 -0
  73. agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
  74. agentrun_sdk/tools/mcp/mcp_client.py +423 -0
  75. agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
  76. agentrun_sdk/tools/mcp/mcp_types.py +63 -0
  77. agentrun_sdk/tools/registry.py +607 -0
  78. agentrun_sdk/tools/structured_output.py +421 -0
  79. agentrun_sdk/tools/tools.py +217 -0
  80. agentrun_sdk/tools/watcher.py +136 -0
  81. agentrun_sdk/types/__init__.py +5 -0
  82. agentrun_sdk/types/collections.py +23 -0
  83. agentrun_sdk/types/content.py +188 -0
  84. agentrun_sdk/types/event_loop.py +48 -0
  85. agentrun_sdk/types/exceptions.py +81 -0
  86. agentrun_sdk/types/guardrails.py +254 -0
  87. agentrun_sdk/types/media.py +89 -0
  88. agentrun_sdk/types/session.py +152 -0
  89. agentrun_sdk/types/streaming.py +201 -0
  90. agentrun_sdk/types/tools.py +258 -0
  91. agentrun_sdk/types/traces.py +5 -0
  92. agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
  93. agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
  94. agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
  95. agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
  96. agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
  97. agentrun_wrapper/__init__.py +11 -0
  98. agentrun_wrapper/_utils/__init__.py +6 -0
  99. agentrun_wrapper/_utils/endpoints.py +16 -0
  100. agentrun_wrapper/identity/__init__.py +5 -0
  101. agentrun_wrapper/identity/auth.py +211 -0
  102. agentrun_wrapper/memory/__init__.py +6 -0
  103. agentrun_wrapper/memory/client.py +1697 -0
  104. agentrun_wrapper/memory/constants.py +103 -0
  105. agentrun_wrapper/memory/controlplane.py +626 -0
  106. agentrun_wrapper/py.typed +1 -0
  107. agentrun_wrapper/runtime/__init__.py +13 -0
  108. agentrun_wrapper/runtime/app.py +473 -0
  109. agentrun_wrapper/runtime/context.py +34 -0
  110. agentrun_wrapper/runtime/models.py +25 -0
  111. agentrun_wrapper/services/__init__.py +1 -0
  112. agentrun_wrapper/services/identity.py +192 -0
  113. agentrun_wrapper/tools/__init__.py +6 -0
  114. agentrun_wrapper/tools/browser_client.py +325 -0
  115. agentrun_wrapper/tools/code_interpreter_client.py +186 -0
agentrun_sdk/models/mistral.py
@@ -0,0 +1,539 @@
+"""Mistral AI model provider.
+
+- Docs: https://docs.mistral.ai/
+"""
+
+import base64
+import json
+import logging
+from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union
+
+import mistralai
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ModelThrottledException
+from ..types.streaming import StopReason, StreamEvent
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class MistralModel(Model):
+    """Mistral API model provider implementation.
+
+    The implementation handles Mistral-specific features such as:
+
+    - Chat and text completions
+    - Streaming responses
+    - Tool/function calling
+    - System prompts
+    """
+
+    class MistralConfig(TypedDict, total=False):
+        """Configuration parameters for Mistral models.
+
+        Attributes:
+            model_id: Mistral model ID (e.g., "mistral-large-latest", "mistral-medium-latest").
+            max_tokens: Maximum number of tokens to generate in the response.
+            temperature: Controls randomness in generation (0.0 to 1.0).
+            top_p: Controls diversity via nucleus sampling.
+            stream: Whether to enable streaming responses.
+        """
+
+        model_id: str
+        max_tokens: Optional[int]
+        temperature: Optional[float]
+        top_p: Optional[float]
+        stream: Optional[bool]
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        *,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[MistralConfig],
+    ) -> None:
+        """Initialize provider instance.
+
+        Args:
+            api_key: Mistral API key. If not provided, will use MISTRAL_API_KEY env var.
+            client_args: Additional arguments for the Mistral client.
+            **model_config: Configuration options for the Mistral model.
+        """
+        if "temperature" in model_config and model_config["temperature"] is not None:
+            temp = model_config["temperature"]
+            if not 0.0 <= temp <= 1.0:
+                raise ValueError(f"temperature must be between 0.0 and 1.0, got {temp}")
+            # Warn if temperature is above recommended range
+            if temp > 0.7:
+                logger.warning(
+                    "temperature=%s is above the recommended range (0.0-0.7). "
+                    "High values may produce unpredictable results.",
+                    temp,
+                )
+
+        if "top_p" in model_config and model_config["top_p"] is not None:
+            top_p = model_config["top_p"]
+            if not 0.0 <= top_p <= 1.0:
+                raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}")
+
+        self.config = MistralModel.MistralConfig(**model_config)
+
+        # Set default stream to True if not specified
+        if "stream" not in self.config:
+            self.config["stream"] = True
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+        self.client_args = client_args or {}
+        if api_key:
+            self.client_args["api_key"] = api_key
+
+    @override
+    def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
+        """Update the Mistral Model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> MistralConfig:
+        """Get the Mistral model configuration.
+
+        Returns:
+            The Mistral model configuration.
+        """
+        return self.config
+
+    def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]:
+        """Format a Mistral content block.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            Mistral formatted content.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a Mistral-compatible format.
+        """
+        if "text" in content:
+            return content["text"]
+
+        if "image" in content:
+            image_data = content["image"]
+
+            if "source" in image_data:
+                image_bytes = image_data["source"]["bytes"]
+                base64_data = base64.b64encode(image_bytes).decode("utf-8")
+                format_value = image_data.get("format", "jpeg")
+                media_type = f"image/{format_value}"
+                return {"type": "image_url", "image_url": f"data:{media_type};base64,{base64_data}"}
+
+            raise TypeError("content_type=<image> | unsupported image format")
+
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+        """Format a Mistral tool call.
+
+        Args:
+            tool_use: Tool use requested by the model.
+
+        Returns:
+            Mistral formatted tool call.
+        """
+        return {
+            "function": {
+                "name": tool_use["name"],
+                "arguments": json.dumps(tool_use["input"]),
+            },
+            "id": tool_use["toolUseId"],
+            "type": "function",
+        }
+
+    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+        """Format a Mistral tool message.
+
+        Args:
+            tool_result: Tool result collected from a tool execution.
+
+        Returns:
+            Mistral formatted tool message.
+        """
+        content_parts: list[str] = []
+        for content in tool_result["content"]:
+            if "json" in content:
+                content_parts.append(json.dumps(content["json"]))
+            elif "text" in content:
+                content_parts.append(content["text"])
+
+        return {
+            "role": "tool",
+            "name": tool_result["toolUseId"].split("_")[0]
+            if "_" in tool_result["toolUseId"]
+            else tool_result["toolUseId"],
+            "content": "\n".join(content_parts),
+            "tool_call_id": tool_result["toolUseId"],
+        }
+
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a Mistral compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Mistral compatible messages array.
+        """
+        formatted_messages: list[dict[str, Any]] = []
+
+        if system_prompt:
+            formatted_messages.append({"role": "system", "content": system_prompt})
+
+        for message in messages:
+            role = message["role"]
+            contents = message["content"]
+
+            text_contents: list[str] = []
+            tool_calls: list[dict[str, Any]] = []
+            tool_messages: list[dict[str, Any]] = []
+
+            for content in contents:
+                if "text" in content:
+                    formatted_content = self._format_request_message_content(content)
+                    if isinstance(formatted_content, str):
+                        text_contents.append(formatted_content)
+                elif "toolUse" in content:
+                    tool_calls.append(self._format_request_message_tool_call(content["toolUse"]))
+                elif "toolResult" in content:
+                    tool_messages.append(self._format_request_tool_message(content["toolResult"]))
+
+            if text_contents or tool_calls:
+                formatted_message: dict[str, Any] = {
+                    "role": role,
+                    "content": " ".join(text_contents) if text_contents else "",
+                }
+
+                if tool_calls:
+                    formatted_message["tool_calls"] = tool_calls
+
+                formatted_messages.append(formatted_message)
+
+            formatted_messages.extend(tool_messages)
+
+        return formatted_messages
+
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format a Mistral chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Mistral chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to a Mistral-compatible
+                format.
+        """
+        request: dict[str, Any] = {
+            "model": self.config["model_id"],
+            "messages": self._format_request_messages(messages, system_prompt),
+        }
+
+        if "max_tokens" in self.config:
+            request["max_tokens"] = self.config["max_tokens"]
+        if "temperature" in self.config:
+            request["temperature"] = self.config["temperature"]
+        if "top_p" in self.config:
+            request["top_p"] = self.config["top_p"]
+        if "stream" in self.config:
+            request["stream"] = self.config["stream"]
+
+        if tool_specs:
+            request["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_spec["name"],
+                        "description": tool_spec["description"],
+                        "parameters": tool_spec["inputSchema"]["json"],
+                    },
+                }
+                for tool_spec in tool_specs
+            ]
+
+        return request
+
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format the Mistral response events into standardized message chunks.
+
+        Args:
+            event: A response event from the Mistral model.
+
+        Returns:
+            The formatted chunk.
+
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+        """
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                tool_call = event["data"]
+                return {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "name": tool_call.function.name,
+                                "toolUseId": tool_call.id,
+                            }
+                        }
+                    }
+                }
+
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"]}}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                reason: StopReason
+                if event["data"] == "tool_calls":
+                    reason = "tool_use"
+                elif event["data"] == "length":
+                    reason = "max_tokens"
+                else:
+                    reason = "end_turn"
+
+                return {"messageStop": {"stopReason": reason}}
+
+            case "metadata":
+                usage = event["data"]
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": usage.prompt_tokens,
+                            "outputTokens": usage.completion_tokens,
+                            "totalTokens": usage.total_tokens,
+                        },
+                        "metrics": {
+                            "latencyMs": event.get("latency_ms", 0),
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+    def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, Any]]:
+        """Handle non-streaming response from Mistral API.
+
+        Args:
+            response: The non-streaming response from Mistral.
+
+        Yields:
+            Formatted events that match the streaming format.
+        """
+        yield {"chunk_type": "message_start"}
+
+        content_started = False
+
+        if response.choices and response.choices[0].message:
+            message = response.choices[0].message
+
+            if hasattr(message, "content") and message.content:
+                if not content_started:
+                    yield {"chunk_type": "content_start", "data_type": "text"}
+                    content_started = True
+
+                yield {"chunk_type": "content_delta", "data_type": "text", "data": message.content}
+
+                yield {"chunk_type": "content_stop"}
+
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
+
+                    if hasattr(tool_call.function, "arguments"):
+                        yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_call.function.arguments}
+
+                    yield {"chunk_type": "content_stop"}
+
+            finish_reason = response.choices[0].finish_reason if response.choices[0].finish_reason else "stop"
+            yield {"chunk_type": "message_stop", "data": finish_reason}
+
+        if hasattr(response, "usage") and response.usage:
+            yield {"chunk_type": "metadata", "data": response.usage}
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the Mistral model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests.
+        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("request=<%s>", request)
+
+        logger.debug("invoking model")
+        try:
+            logger.debug("got response from model")
+            if not self.config.get("stream", True):
+                # Use non-streaming API
+                async with mistralai.Mistral(**self.client_args) as client:
+                    response = await client.chat.complete_async(**request)
+                    for event in self._handle_non_streaming_response(response):
+                        yield self.format_chunk(event)
+
+                return
+
+            # Use the streaming API
+            async with mistralai.Mistral(**self.client_args) as client:
+                stream_response = await client.chat.stream_async(**request)
+
+                yield self.format_chunk({"chunk_type": "message_start"})
+
+                content_started = False
+                tool_calls: dict[str, list[Any]] = {}
+                accumulated_text = ""
+
+                async for chunk in stream_response:
+                    if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
+                        choice = chunk.data.choices[0]
+
+                        if hasattr(choice, "delta"):
+                            delta = choice.delta
+
+                            if hasattr(delta, "content") and delta.content:
+                                if not content_started:
+                                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                                    content_started = True
+
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
+                                )
+                                accumulated_text += delta.content
+
+                            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                                for tool_call in delta.tool_calls:
+                                    tool_id = tool_call.id
+                                    tool_calls.setdefault(tool_id, []).append(tool_call)
+
+                        if hasattr(choice, "finish_reason") and choice.finish_reason:
+                            if content_started:
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+                            for tool_deltas in tool_calls.values():
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                                )
+
+                                for tool_delta in tool_deltas:
+                                    if hasattr(tool_delta.function, "arguments"):
+                                        yield self.format_chunk(
+                                            {
+                                                "chunk_type": "content_delta",
+                                                "data_type": "tool",
+                                                "data": tool_delta.function.arguments,
+                                            }
+                                        )
+
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+                            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+                    if hasattr(chunk, "usage"):
+                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+
+        except Exception as e:
+            if "rate" in str(e).lower() or "429" in str(e):
+                raise ModelThrottledException(str(e)) from e
+            raise
+
+        logger.debug("finished streaming response from model")
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            An instance of the output model with the generated data.
+
+        Raises:
+            ValueError: If the response cannot be parsed into the output model.
+        """
+        tool_spec: ToolSpec = {
+            "name": f"extract_{output_model.__name__.lower()}",
+            "description": f"Extract structured data in the format of {output_model.__name__}",
+            "inputSchema": {"json": output_model.model_json_schema()},
+        }
+
+        formatted_request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt)
+
+        formatted_request["tool_choice"] = "any"
+        formatted_request["parallel_tool_calls"] = False
+
+        async with mistralai.Mistral(**self.client_args) as client:
+            response = await client.chat.complete_async(**formatted_request)
+
+            if response.choices and response.choices[0].message.tool_calls:
+                tool_call = response.choices[0].message.tool_calls[0]
+                try:
+                    # Handle both string and dict arguments
+                    if isinstance(tool_call.function.arguments, str):
+                        arguments = json.loads(tool_call.function.arguments)
+                    else:
+                        arguments = tool_call.function.arguments
+                    yield {"output": output_model(**arguments)}
+                    return
+                except (json.JSONDecodeError, TypeError, ValueError) as e:
+                    raise ValueError(f"Failed to parse tool call arguments into model: {e}") from e
+
+            raise ValueError("No tool calls found in response")
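For orientation, a minimal usage sketch of the provider added above. It is not part of the package diff and assumes only what the file itself shows: the MistralModel constructor with MistralConfig keys, the Messages shape consumed by _format_request_messages, the StreamEvent dicts produced by format_chunk, and the MISTRAL_API_KEY fallback described in the constructor docstring.

import asyncio

from agentrun_sdk.models.mistral import MistralModel


async def main() -> None:
    # model_id, temperature, and stream are MistralConfig keys; api_key is omitted
    # here so the Mistral client falls back to the MISTRAL_API_KEY environment variable.
    model = MistralModel(model_id="mistral-large-latest", temperature=0.3, stream=True)

    # Messages shape expected by _format_request_messages: role/content dicts whose
    # content is a list of ContentBlock dicts (a single text block in this sketch).
    messages = [{"role": "user", "content": [{"text": "Write a haiku about diffs."}]}]

    # stream() yields StreamEvent dicts (messageStart, contentBlockDelta, messageStop, ...).
    async for event in model.stream(messages, system_prompt="You are a concise assistant."):
        delta = event.get("contentBlockDelta", {}).get("delta", {})
        if "text" in delta:
            print(delta["text"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())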
agentrun_sdk/models/model.py
@@ -0,0 +1,95 @@
+"""Abstract base class for Agent model providers."""
+
+import abc
+import logging
+from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
+
+from pydantic import BaseModel
+
+from ..types.content import Messages
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class Model(abc.ABC):
+    """Abstract base class for Agent model providers.
+
+    This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
+    standardized way to configure and process requests for different AI model providers.
+    """
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def update_config(self, **model_config: Any) -> None:
+        """Update the model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def get_config(self) -> Any:
+        """Return the model configuration.
+
+        Returns:
+            The model's configuration.
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Model events with the last being the structured output.
+
+        Raises:
+            ValidationException: The response format from the model does not match the output_model
+        """
+        pass
+
+    @abc.abstractmethod
+    # pragma: no cover
+    def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncIterable[StreamEvent]:
+        """Stream conversation with the model.
+
+        This method handles the full lifecycle of conversing with the model:
+
+        1. Format the messages, tool specs, and configuration into a streaming request
+        2. Send the request to the model
+        3. Yield the formatted message chunks
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests from the client.
+        """
+        pass
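To make the contract above concrete, here is a hypothetical provider skeleton (not part of the diff) that implements the four abstract methods. EchoModel and its behavior are invented purely for illustration; the event shapes mirror those emitted by format_chunk in the Mistral provider above.

from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union

from pydantic import BaseModel

from agentrun_sdk.models.model import Model
from agentrun_sdk.types.content import Messages
from agentrun_sdk.types.streaming import StreamEvent
from agentrun_sdk.types.tools import ToolSpec

T = TypeVar("T", bound=BaseModel)


class EchoModel(Model):
    """Hypothetical provider that streams the last user text block back verbatim."""

    def __init__(self, **config: Any) -> None:
        self.config = dict(config)

    def update_config(self, **model_config: Any) -> None:
        self.config.update(model_config)

    def get_config(self) -> Any:
        return self.config

    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        # Same event sequence the concrete providers emit:
        # messageStart -> contentBlockDelta -> contentBlockStop -> messageStop.
        blocks = messages[-1]["content"] if messages else []
        last_text = next((block["text"] for block in blocks if "text" in block), "")
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockDelta": {"delta": {"text": last_text}}}
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}

    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        # A real provider would request JSON matching output_model and yield
        # {"output": output_model(**data)}; model_construct() only satisfies the
        # interface here without validation.
        yield {"output": output_model.model_construct()}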