agentrun_sdk-0.1.2-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry; it is provided for informational purposes only and reflects the packages as they appear in their public registries.
Potentially problematic release. This version of agentrun-sdk might be problematic.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
agentrun_sdk/models/sagemaker.py
@@ -0,0 +1,598 @@
"""Amazon SageMaker model provider."""

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast

import boto3
from botocore.config import Config as BotocoreConfig
from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec
from .openai import OpenAIModel

T = TypeVar("T", bound=BaseModel)

logger = logging.getLogger(__name__)


@dataclass
class UsageMetadata:
    """Usage metadata for the model.

    Attributes:
        total_tokens: Total number of tokens used in the request
        completion_tokens: Number of tokens used in the completion
        prompt_tokens: Number of tokens used in the prompt
        prompt_tokens_details: Additional information about the prompt tokens (optional)
    """

    total_tokens: int
    completion_tokens: int
    prompt_tokens: int
    prompt_tokens_details: Optional[int] = 0


@dataclass
class FunctionCall:
    """Function call for the model.

    Attributes:
        name: Name of the function to call
        arguments: Arguments to pass to the function
    """

    name: Union[str, dict[Any, Any]]
    arguments: Union[str, dict[Any, Any]]

    def __init__(self, **kwargs: dict[str, str]):
        """Initialize function call.

        Args:
            **kwargs: Keyword arguments for the function call.
        """
        self.name = kwargs.get("name", "")
        self.arguments = kwargs.get("arguments", "")


@dataclass
class ToolCall:
    """Tool call for the model object.

    Attributes:
        id: Tool call ID
        type: Tool call type
        function: Tool call function
    """

    id: str
    type: Literal["function"]
    function: FunctionCall

    def __init__(self, **kwargs: dict):
        """Initialize tool call object.

        Args:
            **kwargs: Keyword arguments for the tool call.
        """
        self.id = str(kwargs.get("id", ""))
        self.type = "function"
        self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
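Editor's note: these two wrappers exist mainly to normalize the raw OpenAI-style tool-call dictionaries that arrive in streaming deltas. A minimal sketch of that behavior (the values below are invented for illustration):

```python
# Hypothetical raw delta from an OpenAI-compatible container.
raw = {"id": 42, "function": {"name": "get_weather", "arguments": '{"city": "Seville"}'}}

call = ToolCall(**raw)
assert call.id == "42"                      # the id is coerced to str
assert call.type == "function"              # the type is always "function"
assert call.function.name == "get_weather"  # the nested dict becomes a FunctionCall
```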


class SageMakerAIModel(OpenAIModel):
    """Amazon SageMaker model provider implementation."""

    client: SageMakerRuntimeClient  # type: ignore[assignment]

    class SageMakerAIPayloadSchema(TypedDict, total=False):
        """Payload schema for the Amazon SageMaker AI model.

        Attributes:
            max_tokens: Maximum number of tokens to generate in the completion
            stream: Whether to stream the response
            temperature: Sampling temperature to use for the model (optional)
            top_p: Nucleus sampling parameter (optional)
            top_k: Top-k sampling parameter (optional)
            stop: List of stop sequences to use for the model (optional)
            tool_results_as_user_messages: Convert tool results to user messages (optional)
            additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema
        """

        max_tokens: int
        stream: bool
        temperature: Optional[float]
        top_p: Optional[float]
        top_k: Optional[int]
        stop: Optional[list[str]]
        tool_results_as_user_messages: Optional[bool]
        additional_args: Optional[dict[str, Any]]

    class SageMakerAIEndpointConfig(TypedDict, total=False):
        """Configuration options for SageMaker models.

        Attributes:
            endpoint_name: The name of the SageMaker endpoint to invoke
            region_name: The AWS region hosting the endpoint
            inference_component_name: The name of the inference component to use (optional)
            target_model: The specific model to invoke on a multi-model endpoint (optional)
            target_variant: The production variant to invoke (optional)
            additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params
        """

        endpoint_name: str
        region_name: str
        inference_component_name: Optional[str]
        target_model: Optional[str]
        target_variant: Optional[str]
        additional_args: Optional[dict[str, Any]]

    def __init__(
        self,
        endpoint_config: SageMakerAIEndpointConfig,
        payload_config: SageMakerAIPayloadSchema,
        boto_session: Optional[boto3.Session] = None,
        boto_client_config: Optional[BotocoreConfig] = None,
    ):
        """Initialize provider instance.

        Args:
            endpoint_config: Endpoint configuration for SageMaker.
            payload_config: Payload configuration for the model.
            boto_session: Boto Session to use when calling the SageMaker Runtime.
            boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
        """
        payload_config.setdefault("stream", True)
        payload_config.setdefault("tool_results_as_user_messages", False)
        self.endpoint_config = dict(endpoint_config)
        self.payload_config = dict(payload_config)
        logger.debug(
            "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
        )

        region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2"
        session = boto_session or boto3.Session(region_name=str(region))

        # Add strands-agents to the request user agent
        if boto_client_config:
            existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)

            # Append 'strands-agents' to existing user_agent_extra or set it if not present
            new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents"

            client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
        else:
            client_config = BotocoreConfig(user_agent_extra="strands-agents")

        self.client = session.client(
            service_name="sagemaker-runtime",
            config=client_config,
        )
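Editor's note: a minimal construction sketch, assuming an OpenAI-compatible chat endpoint; the endpoint name and region below are placeholders, not values from the package:

```python
model = SageMakerAIModel(
    endpoint_config={
        "endpoint_name": "my-llm-endpoint",  # hypothetical endpoint
        "region_name": "us-east-1",
    },
    payload_config={
        "max_tokens": 1024,
        "temperature": 0.7,
        # "stream" defaults to True; "tool_results_as_user_messages" defaults to False
    },
)
```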

    @override
    def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None:  # type: ignore[override]
        """Update the Amazon SageMaker model configuration with the provided arguments.

        Args:
            **endpoint_config: Configuration overrides.
        """
        self.endpoint_config.update(endpoint_config)

    @override
    def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig":  # type: ignore[override]
        """Get the Amazon SageMaker model configuration.

        Returns:
            The Amazon SageMaker model configuration.
        """
        return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)

    @override
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        """Format an Amazon SageMaker chat streaming request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            An Amazon SageMaker chat streaming request.
        """
        formatted_messages = self.format_request_messages(messages, system_prompt)

        payload = {
            "messages": formatted_messages,
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": tool_spec["name"],
                        "description": tool_spec["description"],
                        "parameters": tool_spec["inputSchema"]["json"],
                    },
                }
                for tool_spec in tool_specs or []
            ],
            # Add payload configuration parameters
            **{
                k: v
                for k, v in self.payload_config.items()
                if k not in ["additional_args", "tool_results_as_user_messages"]
            },
        }

        # Remove tools and tool_choice if tools = []
        if not payload["tools"]:
            payload.pop("tools")
            payload.pop("tool_choice", None)
        else:
            # Ensure the model can use tools when available
            payload["tool_choice"] = "auto"

        for i, message in enumerate(payload["messages"]):  # type: ignore
            # Assistant message must have either content or tool_calls, but not both
            if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
                message.pop("content", None)
            if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False):
                # Convert tool message to user message and write it back into the payload
                tool_call_id = message.get("tool_call_id", "ABCDEF")
                content = message.get("content", "")
                message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"}
                payload["messages"][i] = message
            # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"]
            for c in message.get("content", []):
                if "text" in c:
                    message["content"] = [c]
                    break
            # Cast message content to string for TGI compatibility
            # message["content"] = str(message.get("content", ""))

        logger.info("payload=<%s>", json.dumps(payload, indent=2))
        # Format the request according to the SageMaker Runtime API requirements
        request = {
            "EndpointName": self.endpoint_config["endpoint_name"],
            "Body": json.dumps(payload),
            "ContentType": "application/json",
            "Accept": "application/json",
        }

        # Add optional SageMaker parameters if provided
        if self.endpoint_config.get("inference_component_name"):
            request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
        if self.endpoint_config.get("target_model"):
            request["TargetModel"] = self.endpoint_config["target_model"]
        if self.endpoint_config.get("target_variant"):
            request["TargetVariant"] = self.endpoint_config["target_variant"]

        # Add additional args if provided (additional_args is a plain dict)
        if self.endpoint_config.get("additional_args"):
            request.update(self.endpoint_config["additional_args"])

        return request
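Editor's note: the dict this returns is simply the keyword arguments for the SageMaker Runtime invoke call. Roughly, reusing the model from the construction sketch above:

```python
# "model" is the SageMakerAIModel constructed in the earlier sketch.
request = model.format_request(messages=[{"role": "user", "content": [{"text": "Hi"}]}])
# request looks roughly like:
# {
#     "EndpointName": "my-llm-endpoint",
#     "Body": '{"messages": [...], "max_tokens": 1024, "stream": true, ...}',
#     "ContentType": "application/json",
#     "Accept": "application/json",
# }
```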

    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        """Stream conversation with the SageMaker model.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Formatted message chunks from the model.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)
        logger.debug("formatted request=<%s>", request)

        logger.debug("invoking model")
        try:
            if self.payload_config.get("stream", True):
                response = self.client.invoke_endpoint_with_response_stream(**request)

                # Message start
                yield self.format_chunk({"chunk_type": "message_start"})

                # Parse the content
                finish_reason = ""
                partial_content = ""
                tool_calls: dict[int, list[Any]] = {}
                has_text_content = False
                text_content_started = False
                reasoning_content_started = False

                for event in response["Body"]:
                    chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
                    partial_content += chunk[6:] if chunk.startswith("data: ") else chunk  # TGI fix
                    logger.info("chunk=<%s>", partial_content)
                    try:
                        content = json.loads(partial_content)
                        partial_content = ""
                        choice = content["choices"][0]
                        logger.info("choice=<%s>", json.dumps(choice, indent=2))

                        # Handle text content
                        if choice["delta"].get("content", None):
                            if not text_content_started:
                                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                                text_content_started = True
                                has_text_content = True
                            yield self.format_chunk(
                                {
                                    "chunk_type": "content_delta",
                                    "data_type": "text",
                                    "data": choice["delta"]["content"],
                                }
                            )

                        # Handle reasoning content
                        if choice["delta"].get("reasoning_content", None):
                            if not reasoning_content_started:
                                yield self.format_chunk(
                                    {"chunk_type": "content_start", "data_type": "reasoning_content"}
                                )
                                reasoning_content_started = True
                            yield self.format_chunk(
                                {
                                    "chunk_type": "content_delta",
                                    "data_type": "reasoning_content",
                                    "data": choice["delta"]["reasoning_content"],
                                }
                            )

                        # Handle tool calls
                        generated_tool_calls = choice["delta"].get("tool_calls", [])
                        if not isinstance(generated_tool_calls, list):
                            generated_tool_calls = [generated_tool_calls]
                        for tool_call in generated_tool_calls:
                            tool_calls.setdefault(tool_call["index"], []).append(tool_call)

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]
                            break

                        if choice.get("usage", None):
                            yield self.format_chunk(
                                {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
                            )

                    except json.JSONDecodeError:
                        # Continue accumulating content until we have valid JSON
                        continue

                # Close reasoning content if it was started
                if reasoning_content_started:
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})

                # Close text content if it was started
                if text_content_started:
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Handle tool calling
                logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2))
                for tool_deltas in tool_calls.values():
                    if not tool_deltas[0]["function"].get("name", None):
                        raise Exception("The model did not provide a tool name.")
                    yield self.format_chunk(
                        {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
                    )
                    for tool_delta in tool_deltas:
                        yield self.format_chunk(
                            {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
                        )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

                # If no content was generated at all, ensure we have empty text content
                if not has_text_content and not tool_calls:
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Message close
                yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})

            else:
                # Not all SageMaker AI models support streaming!
                response = self.client.invoke_endpoint(**request)  # type: ignore[assignment]
                final_response_json = json.loads(response["Body"].read().decode("utf-8"))  # type: ignore[attr-defined]
                logger.info("response=<%s>", json.dumps(final_response_json, indent=2))

                # Obtain the key elements from the response
                message = final_response_json["choices"][0]["message"]
                message_stop_reason = final_response_json["choices"][0]["finish_reason"]

                # Message start
                yield self.format_chunk({"chunk_type": "message_start"})

                # Handle text
                if message.get("content", ""):
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
                    yield self.format_chunk(
                        {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]}
                    )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

                # Handle reasoning content
                if message.get("reasoning_content", None):
                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
                    yield self.format_chunk(
                        {
                            "chunk_type": "content_delta",
                            "data_type": "reasoning_content",
                            "data": message["reasoning_content"],
                        }
                    )
                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})

                # Handle the tool calling, if any
                if message.get("tool_calls", None) or message_stop_reason == "tool_calls":
                    if not isinstance(message["tool_calls"], list):
                        message["tool_calls"] = [message["tool_calls"]]
                    for tool_call in message["tool_calls"]:
                        # if arguments of tool_call is not str, cast it
                        if not isinstance(tool_call["function"]["arguments"], str):
                            tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"])
                        yield self.format_chunk(
                            {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
                        )
                        yield self.format_chunk(
                            {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
                        )
                        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
                    message_stop_reason = "tool_calls"

                # Message close
                yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason})
                # Handle usage metadata
                if final_response_json.get("usage", None):
                    yield self.format_chunk(
                        {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
                    )

        except (
            self.client.exceptions.InternalFailure,
            self.client.exceptions.ServiceUnavailable,
            self.client.exceptions.ValidationError,
            self.client.exceptions.ModelError,
            self.client.exceptions.InternalDependencyException,
            self.client.exceptions.ModelNotReadyException,
        ) as e:
            logger.error("SageMaker error: %s", str(e))
            raise e

        logger.debug("finished streaming response from model")
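Editor's note: consumption is the usual async-generator pattern. A sketch, again reusing the model from the construction example:

```python
import asyncio

async def main() -> None:
    messages = [{"role": "user", "content": [{"text": "Hello!"}]}]
    # "model" is the SageMakerAIModel constructed in the earlier sketch.
    async for event in model.stream(messages):
        print(event)  # message start, content deltas, message stop, metadata

asyncio.run(main())
```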

    @override
    @classmethod
    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
        """Format a SageMaker compatible tool message.

        Args:
            tool_result: Tool result collected from a tool execution.

        Returns:
            SageMaker compatible tool message with content as a string.
        """
        # Convert content blocks to a simple string for SageMaker compatibility
        content_parts = []
        for content in tool_result["content"]:
            if "json" in content:
                content_parts.append(json.dumps(content["json"]))
            elif "text" in content:
                content_parts.append(content["text"])
            else:
                # Handle other content types by converting to string
                content_parts.append(str(content))

        content_string = " ".join(content_parts)

        return {
            "role": "tool",
            "tool_call_id": tool_result["toolUseId"],
            "content": content_string,  # String instead of list
        }
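Editor's note: a tool result that mixes JSON and text blocks therefore collapses into one string; for example (the tool-use ID is made up):

```python
tool_result = {
    "toolUseId": "tooluse_abc123",
    "content": [{"json": {"temp_c": 21}}, {"text": "Partly cloudy"}],
}
message = SageMakerAIModel.format_request_tool_message(tool_result)
# {"role": "tool", "tool_call_id": "tooluse_abc123", "content": '{"temp_c": 21} Partly cloudy'}
```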

    @override
    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format a content block.

        Args:
            content: Message content.

        Returns:
            Formatted content block.

        Raises:
            TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
        """
        # if "text" in content and not isinstance(content["text"], str):
        #     return {"type": "text", "text": str(content["text"])}

        if "reasoningContent" in content and content["reasoningContent"]:
            return {
                "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""),
                "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""),
                "type": "thinking",
            }
        elif not content.get("reasoningContent", None):
            content.pop("reasoningContent", None)

        if "video" in content:
            return {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": content["video"]["source"]["bytes"],
                },
            }

        return super().format_request_message_content(content)
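Editor's note: the reasoningContent branch maps the SDK's reasoning block onto the "thinking" shape used by Anthropic-style chat templates. A sketch of the mapping (the signature value is invented):

```python
block = {
    "reasoningContent": {
        "reasoningText": {"text": "Check the units first.", "signature": "sig-123"}
    }
}
formatted = SageMakerAIModel.format_request_message_content(block)
# {"signature": "sig-123", "thinking": "Check the units first.", "type": "thinking"}
```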

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Model events with the last being the structured output.
        """
        # Format the request for structured output
        request = self.format_request(prompt, system_prompt=system_prompt)

        # Parse the payload to add the response format and force non-streaming mode,
        # since the non-streaming invoke_endpoint call is used below
        payload = json.loads(request["Body"])
        payload["stream"] = False
        payload["response_format"] = {
            "type": "json_schema",
            "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True},
        }
        request["Body"] = json.dumps(payload)

        try:
            # Use non-streaming mode for structured output
            response = self.client.invoke_endpoint(**request)
            final_response_json = json.loads(response["Body"].read().decode("utf-8"))

            # Extract the structured content
            message = final_response_json["choices"][0]["message"]

            if message.get("content"):
                try:
                    # Parse the JSON content and create the output model instance
                    content_data = json.loads(message["content"])
                    parsed_output = output_model(**content_data)
                    yield {"output": parsed_output}
                except (json.JSONDecodeError, TypeError, ValueError) as e:
                    raise ValueError(f"Failed to parse structured output: {e}") from e
            else:
                raise ValueError("No content found in SageMaker response")

        except (
            self.client.exceptions.InternalFailure,
            self.client.exceptions.ServiceUnavailable,
            self.client.exceptions.ValidationError,
            self.client.exceptions.ModelError,
            self.client.exceptions.InternalDependencyException,
            self.client.exceptions.ModelNotReadyException,
        ) as e:
            logger.error("SageMaker structured output error: %s", str(e))
            raise ValueError(f"SageMaker structured output error: {str(e)}") from e