agentrun_sdk-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of agentrun-sdk might be problematic.

Files changed (115)
  1. agentrun_operation_sdk/cli/__init__.py +1 -0
  2. agentrun_operation_sdk/cli/cli.py +19 -0
  3. agentrun_operation_sdk/cli/common.py +21 -0
  4. agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
  5. agentrun_operation_sdk/cli/runtime/commands.py +203 -0
  6. agentrun_operation_sdk/client/client.py +75 -0
  7. agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
  8. agentrun_operation_sdk/operations/runtime/configure.py +101 -0
  9. agentrun_operation_sdk/operations/runtime/launch.py +82 -0
  10. agentrun_operation_sdk/operations/runtime/models.py +31 -0
  11. agentrun_operation_sdk/services/runtime.py +152 -0
  12. agentrun_operation_sdk/utils/logging_config.py +72 -0
  13. agentrun_operation_sdk/utils/runtime/config.py +94 -0
  14. agentrun_operation_sdk/utils/runtime/container.py +280 -0
  15. agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
  16. agentrun_operation_sdk/utils/runtime/schema.py +56 -0
  17. agentrun_sdk/__init__.py +7 -0
  18. agentrun_sdk/agent/__init__.py +25 -0
  19. agentrun_sdk/agent/agent.py +696 -0
  20. agentrun_sdk/agent/agent_result.py +46 -0
  21. agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
  22. agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
  23. agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
  24. agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
  25. agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
  26. agentrun_sdk/agent/state.py +97 -0
  27. agentrun_sdk/event_loop/__init__.py +9 -0
  28. agentrun_sdk/event_loop/event_loop.py +499 -0
  29. agentrun_sdk/event_loop/streaming.py +319 -0
  30. agentrun_sdk/experimental/__init__.py +4 -0
  31. agentrun_sdk/experimental/hooks/__init__.py +15 -0
  32. agentrun_sdk/experimental/hooks/events.py +123 -0
  33. agentrun_sdk/handlers/__init__.py +10 -0
  34. agentrun_sdk/handlers/callback_handler.py +70 -0
  35. agentrun_sdk/hooks/__init__.py +49 -0
  36. agentrun_sdk/hooks/events.py +80 -0
  37. agentrun_sdk/hooks/registry.py +247 -0
  38. agentrun_sdk/models/__init__.py +10 -0
  39. agentrun_sdk/models/anthropic.py +432 -0
  40. agentrun_sdk/models/bedrock.py +649 -0
  41. agentrun_sdk/models/litellm.py +225 -0
  42. agentrun_sdk/models/llamaapi.py +438 -0
  43. agentrun_sdk/models/mistral.py +539 -0
  44. agentrun_sdk/models/model.py +95 -0
  45. agentrun_sdk/models/ollama.py +357 -0
  46. agentrun_sdk/models/openai.py +436 -0
  47. agentrun_sdk/models/sagemaker.py +598 -0
  48. agentrun_sdk/models/writer.py +449 -0
  49. agentrun_sdk/multiagent/__init__.py +22 -0
  50. agentrun_sdk/multiagent/a2a/__init__.py +15 -0
  51. agentrun_sdk/multiagent/a2a/executor.py +148 -0
  52. agentrun_sdk/multiagent/a2a/server.py +252 -0
  53. agentrun_sdk/multiagent/base.py +92 -0
  54. agentrun_sdk/multiagent/graph.py +555 -0
  55. agentrun_sdk/multiagent/swarm.py +656 -0
  56. agentrun_sdk/py.typed +1 -0
  57. agentrun_sdk/session/__init__.py +18 -0
  58. agentrun_sdk/session/file_session_manager.py +216 -0
  59. agentrun_sdk/session/repository_session_manager.py +152 -0
  60. agentrun_sdk/session/s3_session_manager.py +272 -0
  61. agentrun_sdk/session/session_manager.py +73 -0
  62. agentrun_sdk/session/session_repository.py +51 -0
  63. agentrun_sdk/telemetry/__init__.py +21 -0
  64. agentrun_sdk/telemetry/config.py +194 -0
  65. agentrun_sdk/telemetry/metrics.py +476 -0
  66. agentrun_sdk/telemetry/metrics_constants.py +15 -0
  67. agentrun_sdk/telemetry/tracer.py +563 -0
  68. agentrun_sdk/tools/__init__.py +17 -0
  69. agentrun_sdk/tools/decorator.py +569 -0
  70. agentrun_sdk/tools/executor.py +137 -0
  71. agentrun_sdk/tools/loader.py +152 -0
  72. agentrun_sdk/tools/mcp/__init__.py +13 -0
  73. agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
  74. agentrun_sdk/tools/mcp/mcp_client.py +423 -0
  75. agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
  76. agentrun_sdk/tools/mcp/mcp_types.py +63 -0
  77. agentrun_sdk/tools/registry.py +607 -0
  78. agentrun_sdk/tools/structured_output.py +421 -0
  79. agentrun_sdk/tools/tools.py +217 -0
  80. agentrun_sdk/tools/watcher.py +136 -0
  81. agentrun_sdk/types/__init__.py +5 -0
  82. agentrun_sdk/types/collections.py +23 -0
  83. agentrun_sdk/types/content.py +188 -0
  84. agentrun_sdk/types/event_loop.py +48 -0
  85. agentrun_sdk/types/exceptions.py +81 -0
  86. agentrun_sdk/types/guardrails.py +254 -0
  87. agentrun_sdk/types/media.py +89 -0
  88. agentrun_sdk/types/session.py +152 -0
  89. agentrun_sdk/types/streaming.py +201 -0
  90. agentrun_sdk/types/tools.py +258 -0
  91. agentrun_sdk/types/traces.py +5 -0
  92. agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
  93. agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
  94. agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
  95. agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
  96. agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
  97. agentrun_wrapper/__init__.py +11 -0
  98. agentrun_wrapper/_utils/__init__.py +6 -0
  99. agentrun_wrapper/_utils/endpoints.py +16 -0
  100. agentrun_wrapper/identity/__init__.py +5 -0
  101. agentrun_wrapper/identity/auth.py +211 -0
  102. agentrun_wrapper/memory/__init__.py +6 -0
  103. agentrun_wrapper/memory/client.py +1697 -0
  104. agentrun_wrapper/memory/constants.py +103 -0
  105. agentrun_wrapper/memory/controlplane.py +626 -0
  106. agentrun_wrapper/py.typed +1 -0
  107. agentrun_wrapper/runtime/__init__.py +13 -0
  108. agentrun_wrapper/runtime/app.py +473 -0
  109. agentrun_wrapper/runtime/context.py +34 -0
  110. agentrun_wrapper/runtime/models.py +25 -0
  111. agentrun_wrapper/services/__init__.py +1 -0
  112. agentrun_wrapper/services/identity.py +192 -0
  113. agentrun_wrapper/tools/__init__.py +6 -0
  114. agentrun_wrapper/tools/browser_client.py +325 -0
  115. agentrun_wrapper/tools/code_interpreter_client.py +186 -0
@@ -0,0 +1,449 @@
+ """Writer model provider.
+
+ - Docs: https://dev.writer.com/home/introduction
+ """
+
+ import base64
+ import json
+ import logging
+ import mimetypes
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast
+
+ import writerai
+ from pydantic import BaseModel
+ from typing_extensions import Unpack, override
+
+ from ..types.content import ContentBlock, Messages
+ from ..types.exceptions import ModelThrottledException
+ from ..types.streaming import StreamEvent
+ from ..types.tools import ToolResult, ToolSpec, ToolUse
+ from .model import Model
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class WriterModel(Model):
+     """Writer API model provider implementation."""
+
+     class WriterConfig(TypedDict, total=False):
+         """Configuration options for the Writer API.
+
+         Attributes:
+             model_id: Model name to use (e.g. palmyra-x5, palmyra-x4, etc.).
+             max_tokens: Maximum number of tokens to generate.
+             stop: Default stop sequences.
+             stream_options: Additional options for streaming.
+             temperature: What sampling temperature to use.
+             top_p: Threshold for nucleus sampling.
+         """
+
+         model_id: str
+         max_tokens: Optional[int]
+         stop: Optional[Union[str, List[str]]]
+         stream_options: Dict[str, Any]
+         temperature: Optional[float]
+         top_p: Optional[float]
+
+     def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]):
+         """Initialize provider instance.
+
+         Args:
+             client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.).
+             **model_config: Configuration options for the Writer model.
+         """
+         self.config = WriterModel.WriterConfig(**model_config)
+
+         logger.debug("config=<%s> | initializing", self.config)
+
+         client_args = client_args or {}
+         self.client = writerai.AsyncClient(**client_args)
+
+     @override
+     def update_config(self, **model_config: Unpack[WriterConfig]) -> None:  # type: ignore[override]
+         """Update the Writer model configuration with the provided arguments.
+
+         Args:
+             **model_config: Configuration overrides.
+         """
+         self.config.update(model_config)
+
+     @override
+     def get_config(self) -> WriterConfig:
+         """Get the Writer model configuration.
+
+         Returns:
+             The Writer model configuration.
+         """
+         return self.config
+
+     def _format_request_message_contents_vision(self, contents: list[ContentBlock]) -> list[dict[str, Any]]:
+         def _format_content_vision(content: ContentBlock) -> dict[str, Any]:
+             """Format a Writer content block for a palmyra-x5 (vision) request.
+
+             - NOTE: "reasoningContent", "document" and "video" are not currently supported.
+
+             Args:
+                 content: Message content.
+
+             Returns:
+                 Writer-formatted content block for models that support the vision content format.
+
+             Raises:
+                 TypeError: If the content block type cannot be converted to a Writer-compatible format.
+             """
+             if "text" in content:
+                 return {"text": content["text"], "type": "text"}
+
+             if "image" in content:
+                 mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
+                 image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
+
+                 return {
+                     "image_url": {
+                         "url": f"data:{mime_type};base64,{image_data}",
+                     },
+                     "type": "image_url",
+                 }
+
+             raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+         return [
+             _format_content_vision(content)
+             for content in contents
+             if not any(block_type in content for block_type in ["toolResult", "toolUse"])
+         ]
+
+     def _format_request_message_contents(self, contents: list[ContentBlock]) -> str:
+         def _format_content(content: ContentBlock) -> str:
+             """Format a Writer content block for Palmyra models other than palmyra-x5.
+
+             - NOTE: "reasoningContent", "document", "video" and "image" are not currently supported.
+
+             Args:
+                 content: Message content.
+
+             Returns:
+                 Writer-formatted content block.
+
+             Raises:
+                 TypeError: If the content block type cannot be converted to a Writer-compatible format.
+             """
+             if "text" in content:
+                 return content["text"]
+
+             raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+         content_blocks = list(
+             filter(
+                 lambda content: content.get("text")
+                 and not any(block_type in content for block_type in ["toolResult", "toolUse"]),
+                 contents,
+             )
+         )
+
+         if len(content_blocks) > 1:
+             raise ValueError(
+                 f"Model with name {self.get_config().get('model_id', 'N/A')} doesn't support multiple contents"
+             )
+         elif len(content_blocks) == 1:
+             return _format_content(content_blocks[0])
+         else:
+             return ""
+
+     def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+         """Format a Writer tool call.
+
+         Args:
+             tool_use: Tool use requested by the model.
+
+         Returns:
+             Writer-formatted tool call.
+         """
+         return {
+             "function": {
+                 "arguments": json.dumps(tool_use["input"]),
+                 "name": tool_use["name"],
+             },
+             "id": tool_use["toolUseId"],
+             "type": "function",
+         }
+
+     def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+         """Format a Writer tool message.
+
+         Args:
+             tool_result: Tool result collected from a tool execution.
+
+         Returns:
+             Writer-formatted tool message.
+         """
+         contents = cast(
+             list[ContentBlock],
+             [
+                 {"text": json.dumps(content["json"])} if "json" in content else content
+                 for content in tool_result["content"]
+             ],
+         )
+
+         if self.get_config().get("model_id", "") == "palmyra-x5":
+             formatted_contents = self._format_request_message_contents_vision(contents)
+         else:
+             formatted_contents = self._format_request_message_contents(contents)  # type: ignore [assignment]
+
+         return {
+             "role": "tool",
+             "tool_call_id": tool_result["toolUseId"],
+             "content": formatted_contents,
+         }
+
+     def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+         """Format a Writer-compatible messages array.
+
+         Args:
+             messages: List of message objects to be processed by the model.
+             system_prompt: System prompt to provide context to the model.
+
+         Returns:
+             Writer-compatible messages array.
+         """
+         formatted_messages: list[dict[str, Any]]
+         formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
+
+         for message in messages:
+             contents = message["content"]
+
+             # Only palmyra-x5 supports multiple content blocks. Other models accept only '{"content": "text_content"}'.
+             if self.get_config().get("model_id", "") == "palmyra-x5":
+                 formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents)
+             else:
+                 formatted_contents = self._format_request_message_contents(contents)
+
+             formatted_tool_calls = [
+                 self._format_request_message_tool_call(content["toolUse"])
+                 for content in contents
+                 if "toolUse" in content
+             ]
+             formatted_tool_messages = [
+                 self._format_request_tool_message(content["toolResult"])
+                 for content in contents
+                 if "toolResult" in content
+             ]
+
+             formatted_message = {
+                 "role": message["role"],
+                 "content": formatted_contents if len(formatted_contents) > 0 else "",
+                 **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
+             }
+             formatted_messages.append(formatted_message)
+             formatted_messages.extend(formatted_tool_messages)
+
+         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
+
+     def format_request(
+         self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+     ) -> Any:
+         """Format a streaming request to the underlying model.
+
+         Args:
+             messages: List of message objects to be processed by the model.
+             tool_specs: List of tool specifications to make available to the model.
+             system_prompt: System prompt to provide context to the model.
+
+         Returns:
+             The formatted request.
+         """
+         request = {
+             **{k: v for k, v in self.config.items()},
+             "messages": self._format_request_messages(messages, system_prompt),
+             "stream": True,
+         }
+         try:
+             # For consistency with other providers, WriterConfig uses 'model_id', but the Writer API expects 'model'.
+             request["model"] = request.pop("model_id")
+         except KeyError as e:
+             raise KeyError("Please specify a model ID. Use 'model_id' keyword argument.") from e
+
+         # The Writer API doesn't accept an empty 'tools' attribute.
+         if tool_specs:
+             request["tools"] = [
+                 {
+                     "type": "function",
+                     "function": {
+                         "name": tool_spec["name"],
+                         "description": tool_spec["description"],
+                         "parameters": tool_spec["inputSchema"]["json"],
+                     },
+                 }
+                 for tool_spec in tool_specs
+             ]
+
+         return request
+
+     def format_chunk(self, event: Any) -> StreamEvent:
+         """Format the model response events into standardized message chunks.
+
+         Args:
+             event: A response event from the model.
+
+         Returns:
+             The formatted chunk.
+         """
+         match event.get("chunk_type", ""):
+             case "message_start":
+                 return {"messageStart": {"role": "assistant"}}
+
+             case "content_block_start":
+                 if event["data_type"] == "text":
+                     return {"contentBlockStart": {"start": {}}}
+
+                 return {
+                     "contentBlockStart": {
+                         "start": {
+                             "toolUse": {
+                                 "name": event["data"].function.name,
+                                 "toolUseId": event["data"].id,
+                             }
+                         }
+                     }
+                 }
+
+             case "content_block_delta":
+                 if event["data_type"] == "text":
+                     return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                 return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+
+             case "content_block_stop":
+                 return {"contentBlockStop": {}}
+
+             case "message_stop":
+                 match event["data"]:
+                     case "tool_calls":
+                         return {"messageStop": {"stopReason": "tool_use"}}
+                     case "length":
+                         return {"messageStop": {"stopReason": "max_tokens"}}
+                     case _:
+                         return {"messageStop": {"stopReason": "end_turn"}}
+
+             case "metadata":
+                 # If the 'stream_options' param is unset, empty metadata is provided. To avoid errors,
+                 # the expected usage fields fall back to zero.
+                 return {
+                     "metadata": {
+                         "usage": {
+                             "inputTokens": event["data"].prompt_tokens if event["data"] else 0,
+                             "outputTokens": event["data"].completion_tokens if event["data"] else 0,
+                             "totalTokens": event["data"].total_tokens if event["data"] else 0,
+                         },
+                         "metrics": {
+                             "latencyMs": 0,  # Palmyra models don't provide 'latency' metadata.
+                         },
+                     },
+                 }
+
+             case _:
+                 raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+     @override
+     async def stream(
+         self,
+         messages: Messages,
+         tool_specs: Optional[list[ToolSpec]] = None,
+         system_prompt: Optional[str] = None,
+         **kwargs: Any,
+     ) -> AsyncGenerator[StreamEvent, None]:
+         """Stream conversation with the Writer model.
+
+         Args:
+             messages: List of message objects to be processed by the model.
+             tool_specs: List of tool specifications to make available to the model.
+             system_prompt: System prompt to provide context to the model.
+             **kwargs: Additional keyword arguments for future extensibility.
+
+         Yields:
+             Formatted message chunks from the model.
+
+         Raises:
+             ModelThrottledException: When the model service is throttling requests from the client.
+         """
+         logger.debug("formatting request")
+         request = self.format_request(messages, tool_specs, system_prompt)
+         logger.debug("request=<%s>", request)
+
+         logger.debug("invoking model")
+         try:
+             response = await self.client.chat.chat(**request)
+         except writerai.RateLimitError as e:
+             raise ModelThrottledException(str(e)) from e
+
+         yield self.format_chunk({"chunk_type": "message_start"})
+         yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})
+
+         tool_calls: dict[int, list[Any]] = {}
+
+         async for chunk in response:
+             if not getattr(chunk, "choices", None):
+                 continue
+             choice = chunk.choices[0]
+
+             if choice.delta.content:
+                 yield self.format_chunk(
+                     {"chunk_type": "content_block_delta", "data_type": "text", "data": choice.delta.content}
+                 )
+
+             for tool_call in choice.delta.tool_calls or []:
+                 tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+             if choice.finish_reason:
+                 break
+
+         yield self.format_chunk({"chunk_type": "content_block_stop", "data_type": "text"})
+
+         for tool_deltas in tool_calls.values():
+             tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
+             yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "tool", "data": tool_start})
+
+             for tool_delta in tool_deltas:
+                 yield self.format_chunk({"chunk_type": "content_block_delta", "data_type": "tool", "data": tool_delta})
+
+             yield self.format_chunk({"chunk_type": "content_block_stop", "data_type": "tool"})
+
+         yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+         # Iterating until the end to fetch the metadata chunk
+         async for chunk in response:
+             _ = chunk
+
+         yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+
+         logger.debug("finished streaming response from model")
+
+     @override
+     async def structured_output(
+         self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+         """Get structured output from the model.
+
+         Args:
+             output_model: The output model to use for the agent.
+             prompt: The prompt messages to use for the agent.
+             system_prompt: System prompt to provide context to the model.
+             **kwargs: Additional keyword arguments for future extensibility.
+         """
+         formatted_request = self.format_request(messages=prompt, tool_specs=None, system_prompt=system_prompt)
+         formatted_request["response_format"] = {
+             "type": "json_schema",
+             "json_schema": {"schema": output_model.model_json_schema()},
+         }
+         formatted_request["stream"] = False
+         formatted_request.pop("stream_options", None)
+
+         response = await self.client.chat.chat(**formatted_request)
+
+         try:
+             content = response.choices[0].message.content.strip()
+             yield {"output": output_model.model_validate_json(content)}
+         except Exception as e:
+             raise ValueError(f"Failed to parse or load content into model: {e}") from e
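Taken together, `format_request` builds the Writer chat payload, `stream` drains the streaming response, and `format_chunk` normalizes each event into the SDK's `StreamEvent` shape. A minimal usage sketch follows, assuming a valid Writer API key and the `{"role": ..., "content": [{"text": ...}]}` message shape this module consumes; the `include_usage` stream option is an assumption about how to obtain the trailing usage chunk:

```python
# Minimal usage sketch for WriterModel (constructor and stream() signatures
# are confirmed above; the api_key value and include_usage flag are assumptions).
import asyncio

from agentrun_sdk.models.writer import WriterModel


async def main() -> None:
    model = WriterModel(
        client_args={"api_key": "YOUR_WRITER_API_KEY"},  # forwarded to writerai.AsyncClient
        model_id="palmyra-x5",
        max_tokens=512,
        stream_options={"include_usage": True},  # assumed flag for the trailing usage chunk
    )

    # Message shape mirrors the ContentBlock dicts this module consumes.
    messages = [{"role": "user", "content": [{"text": "Say hello in one sentence."}]}]

    async for event in model.stream(messages):
        # stream() yields normalized StreamEvents: messageStart, contentBlockDelta, ...
        if "contentBlockDelta" in event:
            print(event["contentBlockDelta"]["delta"].get("text", ""), end="")


asyncio.run(main())
```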
@@ -0,0 +1,22 @@
+ """Multiagent capabilities for Strands Agents.
+
+ This module provides support for multiagent systems, including agent-to-agent (A2A)
+ communication protocols and coordination mechanisms.
+
+ Submodules:
+     a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables
+         standardized communication between agents.
+ """
+
+ from .base import MultiAgentBase, MultiAgentResult
+ from .graph import GraphBuilder, GraphResult
+ from .swarm import Swarm, SwarmResult
+
+ __all__ = [
+     "GraphBuilder",
+     "GraphResult",
+     "MultiAgentBase",
+     "MultiAgentResult",
+     "Swarm",
+     "SwarmResult",
+ ]
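The exports suggest two coordination styles: deterministic graphs (`GraphBuilder`/`GraphResult`) and self-organizing swarms (`Swarm`/`SwarmResult`). Below is a hypothetical sketch of the graph style; the builder methods (`add_node`, `add_edge`, `build`) and the `Agent` keyword arguments are assumptions based on the exported names, not signatures confirmed by this diff:

```python
# Hypothetical sketch only: method names and Agent kwargs are assumed.
from agentrun_sdk.agent import Agent  # assumed import path
from agentrun_sdk.multiagent import GraphBuilder

researcher = Agent(name="researcher")
reviewer = Agent(name="reviewer")

builder = GraphBuilder()
builder.add_node(researcher, "research")
builder.add_node(reviewer, "review")
builder.add_edge("research", "review")  # run "review" after "research"

graph = builder.build()
result = graph("Summarize the changes in agentrun-sdk 0.1.2")  # a GraphResult
```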
@@ -0,0 +1,15 @@
+ """Agent-to-Agent (A2A) communication protocol implementation for Strands Agents.
+
+ This module provides classes and utilities for enabling Strands Agents to communicate
+ with other agents using the Agent-to-Agent (A2A) protocol.
+
+ Docs: https://google-a2a.github.io/A2A/latest/
+
+ Classes:
+     A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible.
+ """
+
+ from .executor import StrandsA2AExecutor
+ from .server import A2AServer
+
+ __all__ = ["A2AServer", "StrandsA2AExecutor"]
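For the common case, `A2AServer` presumably wraps an agent and serves it over HTTP, while `StrandsA2AExecutor` is the lower-level adapter (shown in the next file). A hypothetical sketch; only the class names come from this diff, and the `A2AServer` constructor arguments and `serve()` entry point are assumptions:

```python
# Hypothetical sketch: A2AServer's constructor and serve() are assumed.
from agentrun_sdk.agent import Agent  # assumed import path
from agentrun_sdk.multiagent.a2a import A2AServer

agent = Agent(name="helper")
server = A2AServer(agent=agent)            # assumed constructor
server.serve(host="127.0.0.1", port=9000)  # assumed entry point
```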
@@ -0,0 +1,148 @@
+ """Strands Agent executor for the A2A protocol.
+
+ This module provides the StrandsA2AExecutor class, which adapts a Strands Agent
+ to be used as an executor in the A2A protocol. It handles the execution of agent
+ requests and the conversion of Strands Agent streamed responses to A2A events.
+
+ The A2A AgentExecutor ensures clients receive responses for synchronous and
+ streamed requests to the A2AServer.
+ """
+
+ import logging
+ from typing import Any
+
+ from a2a.server.agent_execution import AgentExecutor, RequestContext
+ from a2a.server.events import EventQueue
+ from a2a.server.tasks import TaskUpdater
+ from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
+ from a2a.utils import new_agent_text_message, new_task
+ from a2a.utils.errors import ServerError
+
+ from ...agent.agent import Agent as SAAgent
+ from ...agent.agent import AgentResult as SAAgentResult
+
+ logger = logging.getLogger(__name__)
+
+
+ class StrandsA2AExecutor(AgentExecutor):
+     """Executor that adapts a Strands Agent to the A2A protocol.
+
+     This executor uses streaming mode to handle the execution of agent requests
+     and converts Strands Agent responses to A2A protocol events.
+     """
+
+     def __init__(self, agent: SAAgent):
+         """Initialize a StrandsA2AExecutor.
+
+         Args:
+             agent: The Strands Agent instance to adapt to the A2A protocol.
+         """
+         self.agent = agent
+
+     async def execute(
+         self,
+         context: RequestContext,
+         event_queue: EventQueue,
+     ) -> None:
+         """Execute a request using the Strands Agent and send the response as A2A events.
+
+         This method executes the user's input using the Strands Agent in streaming mode
+         and converts the agent's response to A2A events.
+
+         Args:
+             context: The A2A request context, containing the user's input and task metadata.
+             event_queue: The A2A event queue used to send response events back to the client.
+
+         Raises:
+             ServerError: If an error occurs during agent execution.
+         """
+         task = context.current_task
+         if not task:
+             task = new_task(context.message)  # type: ignore
+             await event_queue.enqueue_event(task)
+
+         updater = TaskUpdater(event_queue, task.id, task.context_id)
+
+         try:
+             await self._execute_streaming(context, updater)
+         except Exception as e:
+             raise ServerError(error=InternalError()) from e
+
71
+ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None:
72
+ """Execute request in streaming mode.
73
+
74
+ Streams the agent's response in real-time, sending incremental updates
75
+ as they become available from the agent.
76
+
77
+ Args:
78
+ context: The A2A request context, containing the user's input and other metadata.
79
+ updater: The task updater for managing task state and sending updates.
80
+ """
81
+ logger.info("Executing request in streaming mode")
82
+ user_input = context.get_user_input()
83
+ try:
84
+ async for event in self.agent.stream_async(user_input):
85
+ await self._handle_streaming_event(event, updater)
86
+ except Exception:
87
+ logger.exception("Error in streaming execution")
88
+ raise
89
+
90
+ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None:
91
+ """Handle a single streaming event from the Strands Agent.
92
+
93
+ Processes streaming events from the agent, converting data chunks to A2A
94
+ task updates and handling the final result when streaming is complete.
95
+
96
+ Args:
97
+ event: The streaming event from the agent, containing either 'data' for
98
+ incremental content or 'result' for the final response.
99
+ updater: The task updater for managing task state and sending updates.
100
+ """
101
+ logger.debug("Streaming event: %s", event)
102
+ if "data" in event:
103
+ if text_content := event["data"]:
104
+ await updater.update_status(
105
+ TaskState.working,
106
+ new_agent_text_message(
107
+ text_content,
108
+ updater.context_id,
109
+ updater.task_id,
110
+ ),
111
+ )
112
+ elif "result" in event:
113
+ await self._handle_agent_result(event["result"], updater)
114
+
115
+ async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None:
116
+ """Handle the final result from the Strands Agent.
117
+
118
+ Processes the agent's final result, extracts text content from the response,
119
+ and adds it as an artifact to the task before marking the task as complete.
120
+
121
+ Args:
122
+ result: The agent result object containing the final response, or None if no result.
123
+ updater: The task updater for managing task state and adding the final artifact.
124
+ """
125
+ if final_content := str(result):
126
+ await updater.add_artifact(
127
+ [Part(root=TextPart(text=final_content))],
128
+ name="agent_response",
129
+ )
130
+ await updater.complete()
131
+
132
+ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
133
+ """Cancel an ongoing execution.
134
+
135
+ This method is called when a request cancellation is requested. Currently,
136
+ cancellation is not supported by the Strands Agent executor, so this method
137
+ always raises an UnsupportedOperationError.
138
+
139
+ Args:
140
+ context: The A2A request context.
141
+ event_queue: The A2A event queue.
142
+
143
+ Raises:
144
+ ServerError: Always raised with an UnsupportedOperationError, as cancellation
145
+ is not currently supported.
146
+ """
147
+ logger.warning("Cancellation requested but not supported")
148
+ raise ServerError(error=UnsupportedOperationError())
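The executor only needs an agent instance; `StrandsA2AExecutor(agent)` is the one signature this file confirms. Everything else below is a hypothetical wiring sketch against a typical `a2a-sdk` layout (which `A2AServer` likely handles for you); the handler, task store, and `AgentCard` fields are assumptions:

```python
# Hypothetical wiring sketch. Only StrandsA2AExecutor(agent) is confirmed by
# this diff; the a2a-sdk imports and AgentCard field set assume a typical layout.
import uvicorn
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCapabilities, AgentCard

from agentrun_sdk.agent import Agent  # assumed import path
from agentrun_sdk.multiagent.a2a import StrandsA2AExecutor

handler = DefaultRequestHandler(
    agent_executor=StrandsA2AExecutor(Agent()),  # Agent() kwargs omitted
    task_store=InMemoryTaskStore(),
)
card = AgentCard(  # field set assumed; adjust to your a2a-sdk version
    name="helper",
    description="Strands Agent exposed over A2A",
    url="http://127.0.0.1:9000/",
    version="0.1.2",
    capabilities=AgentCapabilities(streaming=True),
    default_input_modes=["text"],
    default_output_modes=["text"],
    skills=[],
)
app = A2AStarletteApplication(agent_card=card, http_handler=handler).build()
uvicorn.run(app, host="127.0.0.1", port=9000)
```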