agentrun-sdk 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agentrun-sdk might be problematic. Click here for more details.

Files changed (115) hide show
  1. agentrun_operation_sdk/cli/__init__.py +1 -0
  2. agentrun_operation_sdk/cli/cli.py +19 -0
  3. agentrun_operation_sdk/cli/common.py +21 -0
  4. agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
  5. agentrun_operation_sdk/cli/runtime/commands.py +203 -0
  6. agentrun_operation_sdk/client/client.py +75 -0
  7. agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
  8. agentrun_operation_sdk/operations/runtime/configure.py +101 -0
  9. agentrun_operation_sdk/operations/runtime/launch.py +82 -0
  10. agentrun_operation_sdk/operations/runtime/models.py +31 -0
  11. agentrun_operation_sdk/services/runtime.py +152 -0
  12. agentrun_operation_sdk/utils/logging_config.py +72 -0
  13. agentrun_operation_sdk/utils/runtime/config.py +94 -0
  14. agentrun_operation_sdk/utils/runtime/container.py +280 -0
  15. agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
  16. agentrun_operation_sdk/utils/runtime/schema.py +56 -0
  17. agentrun_sdk/__init__.py +7 -0
  18. agentrun_sdk/agent/__init__.py +25 -0
  19. agentrun_sdk/agent/agent.py +696 -0
  20. agentrun_sdk/agent/agent_result.py +46 -0
  21. agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
  22. agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
  23. agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
  24. agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
  25. agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
  26. agentrun_sdk/agent/state.py +97 -0
  27. agentrun_sdk/event_loop/__init__.py +9 -0
  28. agentrun_sdk/event_loop/event_loop.py +499 -0
  29. agentrun_sdk/event_loop/streaming.py +319 -0
  30. agentrun_sdk/experimental/__init__.py +4 -0
  31. agentrun_sdk/experimental/hooks/__init__.py +15 -0
  32. agentrun_sdk/experimental/hooks/events.py +123 -0
  33. agentrun_sdk/handlers/__init__.py +10 -0
  34. agentrun_sdk/handlers/callback_handler.py +70 -0
  35. agentrun_sdk/hooks/__init__.py +49 -0
  36. agentrun_sdk/hooks/events.py +80 -0
  37. agentrun_sdk/hooks/registry.py +247 -0
  38. agentrun_sdk/models/__init__.py +10 -0
  39. agentrun_sdk/models/anthropic.py +432 -0
  40. agentrun_sdk/models/bedrock.py +649 -0
  41. agentrun_sdk/models/litellm.py +225 -0
  42. agentrun_sdk/models/llamaapi.py +438 -0
  43. agentrun_sdk/models/mistral.py +539 -0
  44. agentrun_sdk/models/model.py +95 -0
  45. agentrun_sdk/models/ollama.py +357 -0
  46. agentrun_sdk/models/openai.py +436 -0
  47. agentrun_sdk/models/sagemaker.py +598 -0
  48. agentrun_sdk/models/writer.py +449 -0
  49. agentrun_sdk/multiagent/__init__.py +22 -0
  50. agentrun_sdk/multiagent/a2a/__init__.py +15 -0
  51. agentrun_sdk/multiagent/a2a/executor.py +148 -0
  52. agentrun_sdk/multiagent/a2a/server.py +252 -0
  53. agentrun_sdk/multiagent/base.py +92 -0
  54. agentrun_sdk/multiagent/graph.py +555 -0
  55. agentrun_sdk/multiagent/swarm.py +656 -0
  56. agentrun_sdk/py.typed +1 -0
  57. agentrun_sdk/session/__init__.py +18 -0
  58. agentrun_sdk/session/file_session_manager.py +216 -0
  59. agentrun_sdk/session/repository_session_manager.py +152 -0
  60. agentrun_sdk/session/s3_session_manager.py +272 -0
  61. agentrun_sdk/session/session_manager.py +73 -0
  62. agentrun_sdk/session/session_repository.py +51 -0
  63. agentrun_sdk/telemetry/__init__.py +21 -0
  64. agentrun_sdk/telemetry/config.py +194 -0
  65. agentrun_sdk/telemetry/metrics.py +476 -0
  66. agentrun_sdk/telemetry/metrics_constants.py +15 -0
  67. agentrun_sdk/telemetry/tracer.py +563 -0
  68. agentrun_sdk/tools/__init__.py +17 -0
  69. agentrun_sdk/tools/decorator.py +569 -0
  70. agentrun_sdk/tools/executor.py +137 -0
  71. agentrun_sdk/tools/loader.py +152 -0
  72. agentrun_sdk/tools/mcp/__init__.py +13 -0
  73. agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
  74. agentrun_sdk/tools/mcp/mcp_client.py +423 -0
  75. agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
  76. agentrun_sdk/tools/mcp/mcp_types.py +63 -0
  77. agentrun_sdk/tools/registry.py +607 -0
  78. agentrun_sdk/tools/structured_output.py +421 -0
  79. agentrun_sdk/tools/tools.py +217 -0
  80. agentrun_sdk/tools/watcher.py +136 -0
  81. agentrun_sdk/types/__init__.py +5 -0
  82. agentrun_sdk/types/collections.py +23 -0
  83. agentrun_sdk/types/content.py +188 -0
  84. agentrun_sdk/types/event_loop.py +48 -0
  85. agentrun_sdk/types/exceptions.py +81 -0
  86. agentrun_sdk/types/guardrails.py +254 -0
  87. agentrun_sdk/types/media.py +89 -0
  88. agentrun_sdk/types/session.py +152 -0
  89. agentrun_sdk/types/streaming.py +201 -0
  90. agentrun_sdk/types/tools.py +258 -0
  91. agentrun_sdk/types/traces.py +5 -0
  92. agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
  93. agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
  94. agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
  95. agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
  96. agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
  97. agentrun_wrapper/__init__.py +11 -0
  98. agentrun_wrapper/_utils/__init__.py +6 -0
  99. agentrun_wrapper/_utils/endpoints.py +16 -0
  100. agentrun_wrapper/identity/__init__.py +5 -0
  101. agentrun_wrapper/identity/auth.py +211 -0
  102. agentrun_wrapper/memory/__init__.py +6 -0
  103. agentrun_wrapper/memory/client.py +1697 -0
  104. agentrun_wrapper/memory/constants.py +103 -0
  105. agentrun_wrapper/memory/controlplane.py +626 -0
  106. agentrun_wrapper/py.typed +1 -0
  107. agentrun_wrapper/runtime/__init__.py +13 -0
  108. agentrun_wrapper/runtime/app.py +473 -0
  109. agentrun_wrapper/runtime/context.py +34 -0
  110. agentrun_wrapper/runtime/models.py +25 -0
  111. agentrun_wrapper/services/__init__.py +1 -0
  112. agentrun_wrapper/services/identity.py +192 -0
  113. agentrun_wrapper/tools/__init__.py +6 -0
  114. agentrun_wrapper/tools/browser_client.py +325 -0
  115. agentrun_wrapper/tools/code_interpreter_client.py +186 -0
@@ -0,0 +1,598 @@
1
+ """Amazon SageMaker model provider."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
8
+
9
+ import boto3
10
+ from botocore.config import Config as BotocoreConfig
11
+ from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
12
+ from pydantic import BaseModel
13
+ from typing_extensions import Unpack, override
14
+
15
+ from ..types.content import ContentBlock, Messages
16
+ from ..types.streaming import StreamEvent
17
+ from ..types.tools import ToolResult, ToolSpec
18
+ from .openai import OpenAIModel
19
+
20
+ T = TypeVar("T", bound=BaseModel)
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass
class UsageMetadata:
    """Usage metadata for the model.

    Instances are typically built via ``UsageMetadata(**usage)`` from the
    ``usage`` object of an OpenAI-compatible SageMaker response.

    Attributes:
        total_tokens: Total number of tokens used in the request
        completion_tokens: Number of tokens used in the completion
        prompt_tokens: Number of tokens used in the prompt
        prompt_tokens_details: Additional information about the prompt tokens (optional)
    """

    total_tokens: int
    completion_tokens: int
    prompt_tokens: int
    # NOTE(review): typed as Optional[int] here, but OpenAI-style APIs usually
    # return an object for prompt_tokens_details — confirm against the endpoint.
    prompt_tokens_details: Optional[int] = 0
40
+
41
+
42
@dataclass
class FunctionCall:
    """Function call for the model.

    The custom ``__init__`` (which replaces the dataclass-generated one)
    tolerates arbitrary extra keyword arguments coming from the raw API
    payload, keeping only ``name`` and ``arguments``.

    Attributes:
        name: Name of the function to call
        arguments: Arguments to pass to the function
    """

    name: Union[str, dict[Any, Any]]
    arguments: Union[str, dict[Any, Any]]

    def __init__(self, **kwargs: Any):
        """Initialize function call.

        Args:
            **kwargs: Keyword arguments for the function call. Only ``name``
                and ``arguments`` are used; anything else is ignored.
        """
        # Bug fix: the annotation was ``dict[str, str]``, which (per PEP 484
        # **kwargs semantics) claims every keyword VALUE is a dict — but the
        # values here are strings or dicts. ``Any`` reflects reality.
        self.name = kwargs.get("name", "")
        self.arguments = kwargs.get("arguments", "")
62
+
63
+
64
@dataclass
class ToolCall:
    """Tool call for the model object.

    The custom ``__init__`` (which replaces the dataclass-generated one)
    tolerates arbitrary extra keyword arguments from the raw API payload.

    Attributes:
        id: Tool call ID
        type: Tool call type (always the literal ``"function"``)
        function: Tool call function
    """

    id: str
    type: Literal["function"]
    function: FunctionCall

    def __init__(self, **kwargs: Any):
        """Initialize tool call object.

        Args:
            **kwargs: Keyword arguments for the tool call. ``id`` is coerced
                to ``str``; ``function`` is passed through to ``FunctionCall``.
        """
        # Bug fix: the annotation was ``dict``, which (per PEP 484 **kwargs
        # semantics) claims every keyword VALUE is a dict; ``Any`` is correct.
        self.id = str(kwargs.get("id", ""))
        self.type = "function"
        self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
87
+
88
+
89
class SageMakerAIModel(OpenAIModel):
    """Amazon SageMaker model provider implementation.

    Talks to an OpenAI-compatible chat endpoint hosted on SageMaker via the
    ``sagemaker-runtime`` boto3 client, supporting both streaming
    (``invoke_endpoint_with_response_stream``) and non-streaming
    (``invoke_endpoint``) invocation.
    """

    client: SageMakerRuntimeClient  # type: ignore[assignment]

    class SageMakerAIPayloadSchema(TypedDict, total=False):
        """Payload schema for the Amazon SageMaker AI model.

        Attributes:
            max_tokens: Maximum number of tokens to generate in the completion
            stream: Whether to stream the response
            temperature: Sampling temperature to use for the model (optional)
            top_p: Nucleus sampling parameter (optional)
            top_k: Top-k sampling parameter (optional)
            stop: List of stop sequences to use for the model (optional)
            tool_results_as_user_messages: Convert tool result to user messages (optional)
            additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema
        """

        max_tokens: int
        stream: bool
        temperature: Optional[float]
        top_p: Optional[float]
        top_k: Optional[int]
        stop: Optional[list[str]]
        tool_results_as_user_messages: Optional[bool]
        additional_args: Optional[dict[str, Any]]

    class SageMakerAIEndpointConfig(TypedDict, total=False):
        """Configuration options for SageMaker models.

        Attributes:
            endpoint_name: The name of the SageMaker endpoint to invoke
            region_name: AWS region hosting the endpoint
            inference_component_name: The name of the inference component to use
            target_model: Target model for multi-model endpoints (optional)
            target_variant: Target production variant (optional)
            additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params
        """

        endpoint_name: str
        region_name: str
        inference_component_name: Union[str, None]
        target_model: Union[Optional[str], None]
        target_variant: Union[Optional[str], None]
        additional_args: Optional[dict[str, Any]]

    def __init__(
        self,
        endpoint_config: SageMakerAIEndpointConfig,
        payload_config: SageMakerAIPayloadSchema,
        boto_session: Optional[boto3.Session] = None,
        boto_client_config: Optional[BotocoreConfig] = None,
    ):
        """Initialize provider instance.

        Args:
            endpoint_config: Endpoint configuration for SageMaker.
            payload_config: Payload configuration for the model.
            boto_session: Boto Session to use when calling the SageMaker Runtime.
            boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
        """
        payload_config.setdefault("stream", True)
        payload_config.setdefault("tool_results_as_user_messages", False)
        self.endpoint_config = dict(endpoint_config)
        self.payload_config = dict(payload_config)
        logger.debug(
            "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
        )

        # Region resolution order: explicit config > AWS_REGION env var > fallback.
        region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2"
        session = boto_session or boto3.Session(region_name=str(region))

        # Add strands-agents to the request user agent
        if boto_client_config:
            existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)

            # Append 'strands-agents' to existing user_agent_extra or set it if not present
            new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents"

            client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
        else:
            client_config = BotocoreConfig(user_agent_extra="strands-agents")

        self.client = session.client(
            service_name="sagemaker-runtime",
            config=client_config,
        )

    @override
    def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None:  # type: ignore[override]
        """Update the Amazon SageMaker model configuration with the provided arguments.

        Args:
            **endpoint_config: Configuration overrides.
        """
        self.endpoint_config.update(endpoint_config)

    @override
    def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig":  # type: ignore[override]
        """Get the Amazon SageMaker model configuration.

        Returns:
            The Amazon SageMaker model configuration.
        """
        return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)

    @override
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        """Format an Amazon SageMaker chat streaming request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            An Amazon SageMaker chat streaming request.
        """
        formatted_messages = self.format_request_messages(messages, system_prompt)

        payload = {
            "messages": formatted_messages,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": tool_spec["name"],
                        "description": tool_spec["description"],
                        "parameters": tool_spec["inputSchema"]["json"],
                    },
                }
                for tool_spec in tool_specs or []
            ],
            # Add payload configuration parameters (everything except the
            # keys that are consumed locally rather than sent to the model).
            **{
                k: v
                for k, v in self.payload_config.items()
                if k not in ["additional_args", "tool_results_as_user_messages"]
            },
        }

        # Remove tools and tool_choice if tools = []
        if not payload["tools"]:
            payload.pop("tools")
            payload.pop("tool_choice", None)
        else:
            # Ensure the model can use tools when available
            payload["tool_choice"] = "auto"

        messages_list = payload["messages"]  # type: ignore
        for index, message in enumerate(messages_list):
            # Assistant message must have either content or tool_calls, but not both
            if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
                message.pop("content", None)
            if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False):
                # Convert tool message to user message
                tool_call_id = message.get("tool_call_id", "ABCDEF")
                content = message.get("content", "")
                message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"}
                # Bug fix: rebinding the loop variable alone does not modify the
                # payload; the converted message must be written back into the list.
                messages_list[index] = message
            # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"]
            for c in message.get("content", []):
                if "text" in c:
                    message["content"] = [c]
                    break

        # Full payloads may contain user data; keep them at DEBUG level.
        logger.debug("payload=<%s>", json.dumps(payload, indent=2))
        # Format the request according to the SageMaker Runtime API requirements
        request = {
            "EndpointName": self.endpoint_config["endpoint_name"],
            "Body": json.dumps(payload),
            "ContentType": "application/json",
            "Accept": "application/json",
        }

        # Add optional SageMaker parameters if provided
        if self.endpoint_config.get("inference_component_name"):
            request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
        if self.endpoint_config.get("target_model"):
            request["TargetModel"] = self.endpoint_config["target_model"]
        if self.endpoint_config.get("target_variant"):
            request["TargetVariant"] = self.endpoint_config["target_variant"]

        # Add additional args if provided.
        # Bug fix: additional_args is a plain dict; calling ``.__dict__`` on a
        # dict raises AttributeError — update with the dict itself.
        if self.endpoint_config.get("additional_args"):
            request.update(self.endpoint_config["additional_args"])

        return request

    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        """Stream conversation with the SageMaker model.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Formatted message chunks from the model.

        Raises:
            Exception: If a streamed tool call arrives without a tool name.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)
        logger.debug("formatted request=<%s>", request)

        logger.debug("invoking model")
        try:
            if self.payload_config.get("stream", True):
                response = self.client.invoke_endpoint_with_response_stream(**request)

                # Message start
                yield self.format_chunk({"chunk_type": "message_start"})

                # Parse the content
                finish_reason = ""
                partial_content = ""
                tool_calls: dict[int, list[Any]] = {}
                has_text_content = False
                text_content_started = False
                reasoning_content_started = False

                for event in response["Body"]:
                    chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
                    # TGI prefixes SSE frames with "data: "; strip it before accumulating.
                    partial_content += chunk[6:] if chunk.startswith("data: ") else chunk  # TGI fix
                    logger.debug("chunk=<%s>", partial_content)
                    try:
                        content = json.loads(partial_content)
                        partial_content = ""
                        choice = content["choices"][0]
                        logger.debug("choice=<%s>", json.dumps(choice, indent=2))

                        # Handle text content
                        if choice["delta"].get("content", None):
                            if not text_content_started:
                                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                                text_content_started = True
                            has_text_content = True
                            yield self.format_chunk(
                                {
                                    "chunk_type": "content_delta",
                                    "data_type": "text",
                                    "data": choice["delta"]["content"],
                                }
                            )

                        # Handle reasoning content
                        if choice["delta"].get("reasoning_content", None):
                            if not reasoning_content_started:
                                yield self.format_chunk(
                                    {"chunk_type": "content_start", "data_type": "reasoning_content"}
                                )
                                reasoning_content_started = True
                            yield self.format_chunk(
                                {
                                    "chunk_type": "content_delta",
                                    "data_type": "reasoning_content",
                                    "data": choice["delta"]["reasoning_content"],
                                }
                            )

                        # Handle tool calls: deltas for the same call share an
                        # "index" and are accumulated for emission after the loop.
                        generated_tool_calls = choice["delta"].get("tool_calls", [])
                        if not isinstance(generated_tool_calls, list):
                            generated_tool_calls = [generated_tool_calls]
                        for tool_call in generated_tool_calls:
                            tool_calls.setdefault(tool_call["index"], []).append(tool_call)

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]
                            break

                        if choice.get("usage", None):
                            yield self.format_chunk(
                                {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
                            )

                    except json.JSONDecodeError:
                        # Continue accumulating content until we have valid JSON
                        continue

                # Close reasoning content if it was started
                if reasoning_content_started:
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})

                # Close text content if it was started
                if text_content_started:
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Handle tool calling
                logger.debug("tool_calls=<%s>", json.dumps(tool_calls, indent=2))
                for tool_deltas in tool_calls.values():
                    if not tool_deltas[0]["function"].get("name", None):
                        raise Exception("The model did not provide a tool name.")
                    yield self.format_chunk(
                        {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
                    )
                    for tool_delta in tool_deltas:
                        yield self.format_chunk(
                            {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
                        )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

                # If no content was generated at all, ensure we have empty text content
                if not has_text_content and not tool_calls:
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Message close
                yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})

            else:
                # Not all SageMaker AI models support streaming!
                response = self.client.invoke_endpoint(**request)  # type: ignore[assignment]
                final_response_json = json.loads(response["Body"].read().decode("utf-8"))  # type: ignore[attr-defined]
                logger.debug("response=<%s>", json.dumps(final_response_json, indent=2))

                # Obtain the key elements from the response
                message = final_response_json["choices"][0]["message"]
                message_stop_reason = final_response_json["choices"][0]["finish_reason"]

                # Message start
                yield self.format_chunk({"chunk_type": "message_start"})

                # Handle text
                if message.get("content", ""):
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                    yield self.format_chunk(
                        {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]}
                    )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Handle reasoning content
                if message.get("reasoning_content", None):
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
                    yield self.format_chunk(
                        {
                            "chunk_type": "content_delta",
                            "data_type": "reasoning_content",
                            "data": message["reasoning_content"],
                        }
                    )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})

                # Handle the tool calling, if any
                if message.get("tool_calls", None) or message_stop_reason == "tool_calls":
                    # Bug fix: guard with .get — finish_reason can be "tool_calls"
                    # while the "tool_calls" key is absent, which used to KeyError.
                    raw_tool_calls = message.get("tool_calls") or []
                    if not isinstance(raw_tool_calls, list):
                        raw_tool_calls = [raw_tool_calls]
                    for tool_call in raw_tool_calls:
                        # if arguments of tool_call is not str, cast it
                        if not isinstance(tool_call["function"]["arguments"], str):
                            tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"])
                        yield self.format_chunk(
                            {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
                        )
                        yield self.format_chunk(
                            {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
                        )
                        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
                    message_stop_reason = "tool_calls"

                # Message close
                yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason})
                # Handle usage metadata
                if final_response_json.get("usage", None):
                    yield self.format_chunk(
                        {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
                    )
        except (
            self.client.exceptions.InternalFailure,
            self.client.exceptions.ServiceUnavailable,
            self.client.exceptions.ValidationError,
            self.client.exceptions.ModelError,
            self.client.exceptions.InternalDependencyException,
            self.client.exceptions.ModelNotReadyException,
        ) as e:
            logger.error("SageMaker error: %s", str(e))
            raise e

        logger.debug("finished streaming response from model")

    @override
    @classmethod
    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
        """Format a SageMaker compatible tool message.

        Args:
            tool_result: Tool result collected from a tool execution.

        Returns:
            SageMaker compatible tool message with content as a string.
        """
        # Convert content blocks to a simple string for SageMaker compatibility
        content_parts = []
        for content in tool_result["content"]:
            if "json" in content:
                content_parts.append(json.dumps(content["json"]))
            elif "text" in content:
                content_parts.append(content["text"])
            else:
                # Handle other content types by converting to string
                content_parts.append(str(content))

        content_string = " ".join(content_parts)

        return {
            "role": "tool",
            "tool_call_id": tool_result["toolUseId"],
            "content": content_string,  # String instead of list
        }

    @override
    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format a content block.

        Args:
            content: Message content.

        Returns:
            Formatted content block.

        Raises:
            TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
        """
        if "reasoningContent" in content and content["reasoningContent"]:
            return {
                "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""),
                "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""),
                "type": "thinking",
            }
        elif not content.get("reasoningContent", None):
            # NOTE: mutates the caller's content block to drop an empty key
            # before delegating to the OpenAI formatter.
            content.pop("reasoningContent", None)

        if "video" in content:
            return {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": content["video"]["source"]["bytes"],
                },
            }

        return super().format_request_message_content(content)

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Model events with the last being the structured output.

        Raises:
            ValueError: If the response has no content, the content cannot be
                parsed into ``output_model``, or the SageMaker call fails.
        """
        # Format the request for structured output
        request = self.format_request(prompt, system_prompt=system_prompt)

        # Parse the payload to add response format (OpenAI-style json_schema mode)
        payload = json.loads(request["Body"])
        payload["response_format"] = {
            "type": "json_schema",
            "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True},
        }
        request["Body"] = json.dumps(payload)

        try:
            # Use non-streaming mode for structured output
            response = self.client.invoke_endpoint(**request)
            final_response_json = json.loads(response["Body"].read().decode("utf-8"))

            # Extract the structured content
            message = final_response_json["choices"][0]["message"]

            if message.get("content"):
                try:
                    # Parse the JSON content and create the output model instance
                    content_data = json.loads(message["content"])
                    parsed_output = output_model(**content_data)
                    yield {"output": parsed_output}
                except (json.JSONDecodeError, TypeError, ValueError) as e:
                    raise ValueError(f"Failed to parse structured output: {e}") from e
            else:
                raise ValueError("No content found in SageMaker response")

        except (
            self.client.exceptions.InternalFailure,
            self.client.exceptions.ServiceUnavailable,
            self.client.exceptions.ValidationError,
            self.client.exceptions.ModelError,
            self.client.exceptions.InternalDependencyException,
            self.client.exceptions.ModelNotReadyException,
        ) as e:
            logger.error("SageMaker structured output error: %s", str(e))
            raise ValueError(f"SageMaker structured output error: {str(e)}") from e