agentrun-sdk 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agentrun-sdk might be problematic.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
agentrun_sdk/models/writer.py
@@ -0,0 +1,449 @@
+"""Writer model provider.
+
+- Docs: https://dev.writer.com/home/introduction
+"""
+
+import base64
+import json
+import logging
+import mimetypes
+from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast
+
+import writerai
+from pydantic import BaseModel
+from typing_extensions import Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ModelThrottledException
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class WriterModel(Model):
+    """Writer API model provider implementation."""
+
+    class WriterConfig(TypedDict, total=False):
+        """Configuration options for the Writer API.
+
+        Attributes:
+            model_id: Model name to use (e.g. palmyra-x5, palmyra-x4, etc.).
+            max_tokens: Maximum number of tokens to generate.
+            stop: Default stop sequences.
+            stream_options: Additional options for streaming.
+            temperature: What sampling temperature to use.
+            top_p: Threshold for 'nucleus sampling'.
+        """
+
+        model_id: str
+        max_tokens: Optional[int]
+        stop: Optional[Union[str, List[str]]]
+        stream_options: Dict[str, Any]
+        temperature: Optional[float]
+        top_p: Optional[float]
+
+    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]):
+        """Initialize provider instance.
+
+        Args:
+            client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.).
+            **model_config: Configuration options for the Writer model.
+        """
+        self.config = WriterModel.WriterConfig(**model_config)
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+        client_args = client_args or {}
+        self.client = writerai.AsyncClient(**client_args)
+
+    @override
+    def update_config(self, **model_config: Unpack[WriterConfig]) -> None:  # type: ignore[override]
+        """Update the Writer model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> WriterConfig:
+        """Get the Writer model configuration.
+
+        Returns:
+            The Writer model configuration.
+        """
+        return self.config
+
+    def _format_request_message_contents_vision(self, contents: list[ContentBlock]) -> list[dict[str, Any]]:
+        def _format_content_vision(content: ContentBlock) -> dict[str, Any]:
+            """Format a Writer content block for a Palmyra V5 request.
+
+            - NOTE: "reasoningContent", "document" and "video" are not supported currently.
+
+            Args:
+                content: Message content.
+
+            Returns:
+                Writer formatted content block for models that support the vision content format.
+
+            Raises:
+                TypeError: If the content block type cannot be converted to a Writer-compatible format.
+            """
+            if "text" in content:
+                return {"text": content["text"], "type": "text"}
+
+            if "image" in content:
+                mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
+                image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
+
+                return {
+                    "image_url": {
+                        "url": f"data:{mime_type};base64,{image_data}",
+                    },
+                    "type": "image_url",
+                }
+
+            raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+        return [
+            _format_content_vision(content)
+            for content in contents
+            if not any(block_type in content for block_type in ["toolResult", "toolUse"])
+        ]
+
+    def _format_request_message_contents(self, contents: list[ContentBlock]) -> str:
+        def _format_content(content: ContentBlock) -> str:
+            """Format a Writer content block for Palmyra models other than V5.
+
+            - NOTE: "reasoningContent", "document", "video" and "image" are not supported currently.
+
+            Args:
+                content: Message content.
+
+            Returns:
+                Writer formatted content block.
+
+            Raises:
+                TypeError: If the content block type cannot be converted to a Writer-compatible format.
+            """
+            if "text" in content:
+                return content["text"]
+
+            raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+        content_blocks = list(
+            filter(
+                lambda content: content.get("text")
+                and not any(block_type in content for block_type in ["toolResult", "toolUse"]),
+                contents,
+            )
+        )
+
+        if len(content_blocks) > 1:
+            raise ValueError(
+                f"Model with name {self.get_config().get('model_id', 'N/A')} doesn't support multiple contents"
+            )
+        elif len(content_blocks) == 1:
+            return _format_content(content_blocks[0])
+        else:
+            return ""
+
+    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+        """Format a Writer tool call.
+
+        Args:
+            tool_use: Tool use requested by the model.
+
+        Returns:
+            Writer formatted tool call.
+        """
+        return {
+            "function": {
+                "arguments": json.dumps(tool_use["input"]),
+                "name": tool_use["name"],
+            },
+            "id": tool_use["toolUseId"],
+            "type": "function",
+        }
+
+    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+        """Format a Writer tool message.
+
+        Args:
+            tool_result: Tool result collected from a tool execution.
+
+        Returns:
+            Writer formatted tool message.
+        """
+        contents = cast(
+            list[ContentBlock],
+            [
+                {"text": json.dumps(content["json"])} if "json" in content else content
+                for content in tool_result["content"]
+            ],
+        )
+
+        if self.get_config().get("model_id", "") == "palmyra-x5":
+            formatted_contents = self._format_request_message_contents_vision(contents)
+        else:
+            formatted_contents = self._format_request_message_contents(contents)  # type: ignore [assignment]
+
+        return {
+            "role": "tool",
+            "tool_call_id": tool_result["toolUseId"],
+            "content": formatted_contents,
+        }
+
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a Writer-compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Writer-compatible messages array.
+        """
+        formatted_messages: list[dict[str, Any]]
+        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
+
+        for message in messages:
+            contents = message["content"]
+
+            # Only Palmyra V5 supports multiple content blocks; other models accept only '{"content": "text_content"}'.
+            if self.get_config().get("model_id", "") == "palmyra-x5":
+                formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents)
+            else:
+                formatted_contents = self._format_request_message_contents(contents)
+
+            formatted_tool_calls = [
+                self._format_request_message_tool_call(content["toolUse"])
+                for content in contents
+                if "toolUse" in content
+            ]
+            formatted_tool_messages = [
+                self._format_request_tool_message(content["toolResult"])
+                for content in contents
+                if "toolResult" in content
+            ]
+
+            formatted_message = {
+                "role": message["role"],
+                "content": formatted_contents if len(formatted_contents) > 0 else "",
+                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
+            }
+            formatted_messages.append(formatted_message)
+            formatted_messages.extend(formatted_tool_messages)
+
+        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
+
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> Any:
+        """Format a streaming request to the underlying model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            The formatted request.
+        """
+        request = {
+            **{k: v for k, v in self.config.items()},
+            "messages": self._format_request_messages(messages, system_prompt),
+            "stream": True,
+        }
+        try:
+            request["model"] = request.pop(
+                "model_id"
+            )  # For consistency with other providers, WriterConfig uses 'model_id', but the Writer API expects 'model'.
+        except KeyError as e:
+            raise KeyError("Please specify a model ID. Use 'model_id' keyword argument.") from e
+
+        # The Writer API doesn't accept an empty 'tools' attribute.
+        if tool_specs:
+            request["tools"] = [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_spec["name"],
+                        "description": tool_spec["description"],
+                        "parameters": tool_spec["inputSchema"]["json"],
+                    },
+                }
+                for tool_spec in tool_specs
+            ]
+
+        return request
+
+    def format_chunk(self, event: Any) -> StreamEvent:
+        """Format the model response events into standardized message chunks.
+
+        Args:
+            event: A response event from the model.
+
+        Returns:
+            The formatted chunk.
+        """
+        match event.get("chunk_type", ""):
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_block_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                return {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "name": event["data"].function.name,
+                                "toolUseId": event["data"].id,
+                            }
+                        }
+                    }
+                }
+
+            case "content_block_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+
+            case "content_block_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                match event["data"]:
+                    case "tool_calls":
+                        return {"messageStop": {"stopReason": "tool_use"}}
+                    case "length":
+                        return {"messageStop": {"stopReason": "max_tokens"}}
+                    case _:
+                        return {"messageStop": {"stopReason": "end_turn"}}
+
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["data"].prompt_tokens if event["data"] else 0,
+                            "outputTokens": event["data"].completion_tokens if event["data"] else 0,
+                            "totalTokens": event["data"].total_tokens if event["data"] else 0,
+                        },  # If the 'stream_options' param is unset, no usage metadata is provided,
+                        # so expected fields default to zero rather than raising errors.
+                        "metrics": {
+                            "latencyMs": 0,  # Palmyra models don't provide 'latency' metadata.
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the Writer model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests from the client.
+        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("request=<%s>", request)
+
+        logger.debug("invoking model")
+        try:
+            response = await self.client.chat.chat(**request)
+        except writerai.RateLimitError as e:
+            raise ModelThrottledException(str(e)) from e
+
+        yield self.format_chunk({"chunk_type": "message_start"})
+        yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})
+
+        tool_calls: dict[int, list[Any]] = {}
+
+        async for chunk in response:
+            if not getattr(chunk, "choices", None):
+                continue
+            choice = chunk.choices[0]
+
+            if choice.delta.content:
+                yield self.format_chunk(
+                    {"chunk_type": "content_block_delta", "data_type": "text", "data": choice.delta.content}
+                )
+
+            for tool_call in choice.delta.tool_calls or []:
+                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+            if choice.finish_reason:
+                break
+
+        yield self.format_chunk({"chunk_type": "content_block_stop", "data_type": "text"})
+
+        for tool_deltas in tool_calls.values():
+            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
+            yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "tool", "data": tool_start})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_block_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_block_stop", "data_type": "tool"})
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+        # Iterate to the end of the stream to fetch the trailing metadata chunk.
+        async for chunk in response:
+            _ = chunk
+
+        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+
+        logger.debug("finished streaming response from model")
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+        """
+        formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=system_prompt)
+        formatted_request["response_format"] = {
+            "type": "json_schema",
+            "json_schema": {"schema": output_model.model_json_schema()},
+        }
+        formatted_request["stream"] = False
+        formatted_request.pop("stream_options", None)
+
+        response = await self.client.chat.chat(**formatted_request)
+
+        try:
+            content = response.choices[0].message.content.strip()
+            yield {"output": output_model.model_validate_json(content)}
+        except Exception as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
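The hunk above implements the provider against the package's shared Model interface: format_request builds the Writer chat payload, stream drives the async chat call, and format_chunk normalizes events. A minimal usage sketch under stated assumptions: the API key value and prompt are placeholders, and the message shape follows the role/ContentBlock structure consumed by _format_request_messages above.

import asyncio

from agentrun_sdk.models.writer import WriterModel

async def main() -> None:
    model = WriterModel(
        client_args={"api_key": "YOUR_WRITER_API_KEY"},  # placeholder; forwarded to writerai.AsyncClient
        model_id="palmyra-x5",  # required; format_request renames it to 'model'
        max_tokens=512,
    )
    # The role/content shape matches what _format_request_messages expects.
    messages = [{"role": "user", "content": [{"text": "Say hello in one sentence."}]}]
    async for event in model.stream(messages, system_prompt="You are concise."):
        print(event)  # normalized chunks: messageStart, contentBlockDelta, messageStop, metadata

asyncio.run(main())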
agentrun_sdk/multiagent/__init__.py
@@ -0,0 +1,22 @@
+"""Multiagent capabilities for Strands Agents.
+
+This module provides support for multiagent systems, including agent-to-agent (A2A)
+communication protocols and coordination mechanisms.
+
+Submodules:
+    a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables
+        standardized communication between agents.
+"""
+
+from .base import MultiAgentBase, MultiAgentResult
+from .graph import GraphBuilder, GraphResult
+from .swarm import Swarm, SwarmResult
+
+__all__ = [
+    "GraphBuilder",
+    "GraphResult",
+    "MultiAgentBase",
+    "MultiAgentResult",
+    "Swarm",
+    "SwarmResult",
+]
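As the __all__ list above shows, the multiagent primitives are re-exported at the package level, so the intended import style is a single-namespace sketch like:

# Names taken directly from the __all__ list above.
from agentrun_sdk.multiagent import GraphBuilder, MultiAgentResult, Swarm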
agentrun_sdk/multiagent/a2a/__init__.py
@@ -0,0 +1,15 @@
+"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents.
+
+This module provides classes and utilities for enabling Strands Agents to communicate
+with other agents using the Agent-to-Agent (A2A) protocol.
+
+Docs: https://google-a2a.github.io/A2A/latest/
+
+Classes:
+    StrandsA2AExecutor: A wrapper that adapts a Strands Agent to be A2A-compatible.
+"""
+
+from .executor import StrandsA2AExecutor
+from .server import A2AServer
+
+__all__ = ["A2AServer", "StrandsA2AExecutor"]
agentrun_sdk/multiagent/a2a/executor.py
@@ -0,0 +1,148 @@
+"""Strands Agent executor for the A2A protocol.
+
+This module provides the StrandsA2AExecutor class, which adapts a Strands Agent
+to be used as an executor in the A2A protocol. It handles the execution of agent
+requests and the conversion of Strands Agent streamed responses to A2A events.
+
+The A2A AgentExecutor ensures clients receive responses for synchronous and
+streamed requests to the A2AServer.
+"""
+
+import logging
+from typing import Any
+
+from a2a.server.agent_execution import AgentExecutor, RequestContext
+from a2a.server.events import EventQueue
+from a2a.server.tasks import TaskUpdater
+from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
+from a2a.utils import new_agent_text_message, new_task
+from a2a.utils.errors import ServerError
+
+from ...agent.agent import Agent as SAAgent
+from ...agent.agent import AgentResult as SAAgentResult
+
+logger = logging.getLogger(__name__)
+
+
+class StrandsA2AExecutor(AgentExecutor):
+    """Executor that adapts a Strands Agent to the A2A protocol.
+
+    This executor uses streaming mode to handle the execution of agent requests
+    and converts Strands Agent responses to A2A protocol events.
+    """
+
+    def __init__(self, agent: SAAgent):
+        """Initialize a StrandsA2AExecutor.
+
+        Args:
+            agent: The Strands Agent instance to adapt to the A2A protocol.
+        """
+        self.agent = agent
+
+    async def execute(
+        self,
+        context: RequestContext,
+        event_queue: EventQueue,
+    ) -> None:
+        """Execute a request using the Strands Agent and send the response as A2A events.
+
+        This method executes the user's input using the Strands Agent in streaming mode
+        and converts the agent's response to A2A events.
+
+        Args:
+            context: The A2A request context, containing the user's input and task metadata.
+            event_queue: The A2A event queue used to send response events back to the client.
+
+        Raises:
+            ServerError: If an error occurs during agent execution.
+        """
+        task = context.current_task
+        if not task:
+            task = new_task(context.message)  # type: ignore
+            await event_queue.enqueue_event(task)
+
+        updater = TaskUpdater(event_queue, task.id, task.context_id)
+
+        try:
+            await self._execute_streaming(context, updater)
+        except Exception as e:
+            raise ServerError(error=InternalError()) from e
+
+    async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None:
+        """Execute request in streaming mode.
+
+        Streams the agent's response in real-time, sending incremental updates
+        as they become available from the agent.
+
+        Args:
+            context: The A2A request context, containing the user's input and other metadata.
+            updater: The task updater for managing task state and sending updates.
+        """
+        logger.info("Executing request in streaming mode")
+        user_input = context.get_user_input()
+        try:
+            async for event in self.agent.stream_async(user_input):
+                await self._handle_streaming_event(event, updater)
+        except Exception:
+            logger.exception("Error in streaming execution")
+            raise
+
+    async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None:
+        """Handle a single streaming event from the Strands Agent.
+
+        Processes streaming events from the agent, converting data chunks to A2A
+        task updates and handling the final result when streaming is complete.
+
+        Args:
+            event: The streaming event from the agent, containing either 'data' for
+                incremental content or 'result' for the final response.
+            updater: The task updater for managing task state and sending updates.
+        """
+        logger.debug("Streaming event: %s", event)
+        if "data" in event:
+            if text_content := event["data"]:
+                await updater.update_status(
+                    TaskState.working,
+                    new_agent_text_message(
+                        text_content,
+                        updater.context_id,
+                        updater.task_id,
+                    ),
+                )
+        elif "result" in event:
+            await self._handle_agent_result(event["result"], updater)
+
+    async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None:
+        """Handle the final result from the Strands Agent.
+
+        Processes the agent's final result, extracts text content from the response,
+        and adds it as an artifact to the task before marking the task as complete.
+
+        Args:
+            result: The agent result object containing the final response, or None if no result.
+            updater: The task updater for managing task state and adding the final artifact.
+        """
+        if final_content := str(result):
+            await updater.add_artifact(
+                [Part(root=TextPart(text=final_content))],
+                name="agent_response",
+            )
+        await updater.complete()
+
+    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
+        """Cancel an ongoing execution.
+
+        This method is called when a request cancellation is requested. Currently,
+        cancellation is not supported by the Strands Agent executor, so this method
+        always raises an UnsupportedOperationError.
+
+        Args:
+            context: The A2A request context.
+            event_queue: The A2A event queue.
+
+        Raises:
+            ServerError: Always raised with an UnsupportedOperationError, as cancellation
+                is not currently supported.
+        """
+        logger.warning("Cancellation requested but not supported")
+        raise ServerError(error=UnsupportedOperationError())
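A hedged wiring sketch for the executor above. Only StrandsA2AExecutor(agent) is defined in this diff; the Agent construction below is a hypothetical placeholder, since the Agent constructor's signature is not shown here.

from agentrun_sdk.agent.agent import Agent
from agentrun_sdk.multiagent.a2a import StrandsA2AExecutor

agent = Agent()  # hypothetical construction; real arguments live in agentrun_sdk/agent/agent.py
executor = StrandsA2AExecutor(agent)

# An A2A server drives the executor: execute() streams agent output as
# TaskState.working status updates, then attaches the final text as an
# "agent_response" artifact and marks the task complete.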