agentrun_sdk-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of agentrun-sdk might be problematic.

Files changed (115)
  1. agentrun_operation_sdk/cli/__init__.py +1 -0
  2. agentrun_operation_sdk/cli/cli.py +19 -0
  3. agentrun_operation_sdk/cli/common.py +21 -0
  4. agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
  5. agentrun_operation_sdk/cli/runtime/commands.py +203 -0
  6. agentrun_operation_sdk/client/client.py +75 -0
  7. agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
  8. agentrun_operation_sdk/operations/runtime/configure.py +101 -0
  9. agentrun_operation_sdk/operations/runtime/launch.py +82 -0
  10. agentrun_operation_sdk/operations/runtime/models.py +31 -0
  11. agentrun_operation_sdk/services/runtime.py +152 -0
  12. agentrun_operation_sdk/utils/logging_config.py +72 -0
  13. agentrun_operation_sdk/utils/runtime/config.py +94 -0
  14. agentrun_operation_sdk/utils/runtime/container.py +280 -0
  15. agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
  16. agentrun_operation_sdk/utils/runtime/schema.py +56 -0
  17. agentrun_sdk/__init__.py +7 -0
  18. agentrun_sdk/agent/__init__.py +25 -0
  19. agentrun_sdk/agent/agent.py +696 -0
  20. agentrun_sdk/agent/agent_result.py +46 -0
  21. agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
  22. agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
  23. agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
  24. agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
  25. agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
  26. agentrun_sdk/agent/state.py +97 -0
  27. agentrun_sdk/event_loop/__init__.py +9 -0
  28. agentrun_sdk/event_loop/event_loop.py +499 -0
  29. agentrun_sdk/event_loop/streaming.py +319 -0
  30. agentrun_sdk/experimental/__init__.py +4 -0
  31. agentrun_sdk/experimental/hooks/__init__.py +15 -0
  32. agentrun_sdk/experimental/hooks/events.py +123 -0
  33. agentrun_sdk/handlers/__init__.py +10 -0
  34. agentrun_sdk/handlers/callback_handler.py +70 -0
  35. agentrun_sdk/hooks/__init__.py +49 -0
  36. agentrun_sdk/hooks/events.py +80 -0
  37. agentrun_sdk/hooks/registry.py +247 -0
  38. agentrun_sdk/models/__init__.py +10 -0
  39. agentrun_sdk/models/anthropic.py +432 -0
  40. agentrun_sdk/models/bedrock.py +649 -0
  41. agentrun_sdk/models/litellm.py +225 -0
  42. agentrun_sdk/models/llamaapi.py +438 -0
  43. agentrun_sdk/models/mistral.py +539 -0
  44. agentrun_sdk/models/model.py +95 -0
  45. agentrun_sdk/models/ollama.py +357 -0
  46. agentrun_sdk/models/openai.py +436 -0
  47. agentrun_sdk/models/sagemaker.py +598 -0
  48. agentrun_sdk/models/writer.py +449 -0
  49. agentrun_sdk/multiagent/__init__.py +22 -0
  50. agentrun_sdk/multiagent/a2a/__init__.py +15 -0
  51. agentrun_sdk/multiagent/a2a/executor.py +148 -0
  52. agentrun_sdk/multiagent/a2a/server.py +252 -0
  53. agentrun_sdk/multiagent/base.py +92 -0
  54. agentrun_sdk/multiagent/graph.py +555 -0
  55. agentrun_sdk/multiagent/swarm.py +656 -0
  56. agentrun_sdk/py.typed +1 -0
  57. agentrun_sdk/session/__init__.py +18 -0
  58. agentrun_sdk/session/file_session_manager.py +216 -0
  59. agentrun_sdk/session/repository_session_manager.py +152 -0
  60. agentrun_sdk/session/s3_session_manager.py +272 -0
  61. agentrun_sdk/session/session_manager.py +73 -0
  62. agentrun_sdk/session/session_repository.py +51 -0
  63. agentrun_sdk/telemetry/__init__.py +21 -0
  64. agentrun_sdk/telemetry/config.py +194 -0
  65. agentrun_sdk/telemetry/metrics.py +476 -0
  66. agentrun_sdk/telemetry/metrics_constants.py +15 -0
  67. agentrun_sdk/telemetry/tracer.py +563 -0
  68. agentrun_sdk/tools/__init__.py +17 -0
  69. agentrun_sdk/tools/decorator.py +569 -0
  70. agentrun_sdk/tools/executor.py +137 -0
  71. agentrun_sdk/tools/loader.py +152 -0
  72. agentrun_sdk/tools/mcp/__init__.py +13 -0
  73. agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
  74. agentrun_sdk/tools/mcp/mcp_client.py +423 -0
  75. agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
  76. agentrun_sdk/tools/mcp/mcp_types.py +63 -0
  77. agentrun_sdk/tools/registry.py +607 -0
  78. agentrun_sdk/tools/structured_output.py +421 -0
  79. agentrun_sdk/tools/tools.py +217 -0
  80. agentrun_sdk/tools/watcher.py +136 -0
  81. agentrun_sdk/types/__init__.py +5 -0
  82. agentrun_sdk/types/collections.py +23 -0
  83. agentrun_sdk/types/content.py +188 -0
  84. agentrun_sdk/types/event_loop.py +48 -0
  85. agentrun_sdk/types/exceptions.py +81 -0
  86. agentrun_sdk/types/guardrails.py +254 -0
  87. agentrun_sdk/types/media.py +89 -0
  88. agentrun_sdk/types/session.py +152 -0
  89. agentrun_sdk/types/streaming.py +201 -0
  90. agentrun_sdk/types/tools.py +258 -0
  91. agentrun_sdk/types/traces.py +5 -0
  92. agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
  93. agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
  94. agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
  95. agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
  96. agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
  97. agentrun_wrapper/__init__.py +11 -0
  98. agentrun_wrapper/_utils/__init__.py +6 -0
  99. agentrun_wrapper/_utils/endpoints.py +16 -0
  100. agentrun_wrapper/identity/__init__.py +5 -0
  101. agentrun_wrapper/identity/auth.py +211 -0
  102. agentrun_wrapper/memory/__init__.py +6 -0
  103. agentrun_wrapper/memory/client.py +1697 -0
  104. agentrun_wrapper/memory/constants.py +103 -0
  105. agentrun_wrapper/memory/controlplane.py +626 -0
  106. agentrun_wrapper/py.typed +1 -0
  107. agentrun_wrapper/runtime/__init__.py +13 -0
  108. agentrun_wrapper/runtime/app.py +473 -0
  109. agentrun_wrapper/runtime/context.py +34 -0
  110. agentrun_wrapper/runtime/models.py +25 -0
  111. agentrun_wrapper/services/__init__.py +1 -0
  112. agentrun_wrapper/services/identity.py +192 -0
  113. agentrun_wrapper/tools/__init__.py +6 -0
  114. agentrun_wrapper/tools/browser_client.py +325 -0
  115. agentrun_wrapper/tools/code_interpreter_client.py +186 -0

agentrun_sdk/models/litellm.py
@@ -0,0 +1,225 @@
+"""LiteLLM model provider.
+
+- Docs: https://docs.litellm.ai/
+"""
+
+import json
+import logging
+from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast
+
+import litellm
+from litellm.utils import supports_response_schema
+from pydantic import BaseModel
+from typing_extensions import Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+from .openai import OpenAIModel
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class LiteLLMModel(OpenAIModel):
+    """LiteLLM model provider implementation."""
+
+    class LiteLLMConfig(TypedDict, total=False):
+        """Configuration options for LiteLLM models.
+
+        Attributes:
+            model_id: Model ID (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet").
+                For a complete list of supported models, see https://docs.litellm.ai/docs/providers.
+            params: Model parameters (e.g., max_tokens).
+                For a complete list of supported parameters, see
+                https://docs.litellm.ai/docs/completion/input#input-params-1.
+        """
+
+        model_id: str
+        params: Optional[dict[str, Any]]
+
+    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
+        """Initialize provider instance.
+
+        Args:
+            client_args: Arguments for the LiteLLM client.
+                For a complete list of supported arguments, see
+                https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
+            **model_config: Configuration options for the LiteLLM model.
+        """
+        self.client_args = client_args or {}
+        self.config = dict(model_config)
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+    @override
+    def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
+        """Update the LiteLLM model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> LiteLLMConfig:
+        """Get the LiteLLM model configuration.
+
+        Returns:
+            The LiteLLM model configuration.
+        """
+        return cast(LiteLLMModel.LiteLLMConfig, self.config)
+
+    @override
+    @classmethod
+    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+        """Format a LiteLLM content block.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            LiteLLM formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
+        """
+        if "reasoningContent" in content:
+            return {
+                "signature": content["reasoningContent"]["reasoningText"]["signature"],
+                "thinking": content["reasoningContent"]["reasoningText"]["text"],
+                "type": "thinking",
+            }
+
+        if "video" in content:
+            return {
+                "type": "video_url",
+                "video_url": {
+                    "detail": "auto",
+                    "url": content["video"]["source"]["bytes"],
+                },
+            }
+
+        return super().format_request_message_content(content)
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the LiteLLM model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("request=<%s>", request)
+
+        logger.debug("invoking model")
+        response = await litellm.acompletion(**self.client_args, **request)
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+
+        tool_calls: dict[int, list[Any]] = {}
+
+        async for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            if choice.delta.content:
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                )
+
+            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                yield self.format_chunk(
+                    {
+                        "chunk_type": "content_delta",
+                        "data_type": "reasoning_content",
+                        "data": choice.delta.reasoning_content,
+                    }
+                )
+
+            for tool_call in choice.delta.tool_calls or []:
+                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+            if choice.finish_reason:
+                break
+
+        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+        for tool_deltas in tool_calls.values():
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        async for event in response:
+            _ = event
+
+        if event.usage:
+            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Model events with the last being the structured output.
+        """
+        response = await litellm.acompletion(
+            **self.client_args,
+            model=self.get_config()["model_id"],
+            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+            response_format=output_model,
+        )
+
+        if not supports_response_schema(self.get_config()["model_id"]):
+            raise ValueError("Model does not support response_format")
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+
+        # Find the first choice with tool_calls
+        for choice in response.choices:
+            if choice.finish_reason == "tool_calls":
+                try:
+                    # Parse the tool call content as JSON
+                    tool_call_data = json.loads(choice.message.content)
+                    # Instantiate the output model with the parsed data
+                    yield {"output": output_model(**tool_call_data)}
+                    return
+                except (json.JSONDecodeError, TypeError, ValueError) as e:
+                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+        # If no tool_calls found, raise an error
+        raise ValueError("No tool_calls found in response")
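
For orientation, a minimal usage sketch of the LiteLLM provider added above. This is not part of the package diff: the content-block message shape mirrors the one this module formats ({"role": ..., "content": [{"text": ...}]}), while the model_id, params, and API key values are illustrative assumptions.

    # Minimal usage sketch (not part of the wheel). model_id/params/api_key are placeholders.
    import asyncio

    from agentrun_sdk.models.litellm import LiteLLMModel

    model = LiteLLMModel(
        client_args={"api_key": "<provider-api-key>"},  # forwarded to litellm.acompletion
        model_id="openai/gpt-4o",                       # any LiteLLM-routed model id
        params={"max_tokens": 1024},                    # provider params, per the LiteLLMConfig docstring
    )

    async def main() -> None:
        messages = [{"role": "user", "content": [{"text": "Say hello in one sentence."}]}]
        async for event in model.stream(messages):
            print(event)  # standardized StreamEvent chunks (messageStart, contentBlockDelta, ...)

    asyncio.run(main())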

agentrun_sdk/models/llamaapi.py
@@ -0,0 +1,438 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+"""Llama API model provider.
+
+- Docs: https://llama.developer.meta.com/
+"""
+
+import base64
+import json
+import logging
+import mimetypes
+from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast
+
+import llama_api_client
+from llama_api_client import LlamaAPIClient
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ModelThrottledException
+from ..types.streaming import StreamEvent, Usage
+from ..types.tools import ToolResult, ToolSpec, ToolUse
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class LlamaAPIModel(Model):
+    """Llama API model provider implementation."""
+
+    class LlamaConfig(TypedDict, total=False):
+        """Configuration options for Llama API models.
+
+        Attributes:
+            model_id: Model ID (e.g., "Llama-4-Maverick-17B-128E-Instruct-FP8").
+            repetition_penalty: Repetition penalty.
+            temperature: Temperature.
+            top_p: Top-p.
+            max_completion_tokens: Maximum completion tokens.
+            top_k: Top-k.
+        """
+
+        model_id: str
+        repetition_penalty: Optional[float]
+        temperature: Optional[float]
+        top_p: Optional[float]
+        max_completion_tokens: Optional[int]
+        top_k: Optional[int]
+
+    def __init__(
+        self,
+        *,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[LlamaConfig],
+    ) -> None:
+        """Initialize provider instance.
+
+        Args:
+            client_args: Arguments for the Llama API client.
+            **model_config: Configuration options for the Llama API model.
+        """
+        self.config = LlamaAPIModel.LlamaConfig(**model_config)
+        logger.debug("config=<%s> | initializing", self.config)
+
+        if not client_args:
+            self.client = LlamaAPIClient()
+        else:
+            self.client = LlamaAPIClient(**client_args)
+
+    @override
+    def update_config(self, **model_config: Unpack[LlamaConfig]) -> None:  # type: ignore
+        """Update the Llama API Model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> LlamaConfig:
+        """Get the Llama API model configuration.
+
+        Returns:
+            The Llama API model configuration.
+        """
+        return self.config
+
+    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+        """Format a LlamaAPI content block.
+
+        - NOTE: "reasoningContent" and "video" are not supported currently.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            LlamaAPI formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format.
+        """
+        if "image" in content:
+            mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
+            image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
+
+            return {
+                "image_url": {
+                    "url": f"data:{mime_type};base64,{image_data}",
+                },
+                "type": "image_url",
+            }
+
+        if "text" in content:
+            return {"text": content["text"], "type": "text"}
+
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+
+    def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
+        """Format a Llama API tool call.
+
+        Args:
+            tool_use: Tool use requested by the model.
+
+        Returns:
+            Llama API formatted tool call.
+        """
+        return {
+            "function": {
+                "arguments": json.dumps(tool_use["input"]),
+                "name": tool_use["name"],
+            },
+            "id": tool_use["toolUseId"],
+        }
+
+    def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
+        """Format a Llama API tool message.
+
+        Args:
+            tool_result: Tool result collected from a tool execution.
+
+        Returns:
+            Llama API formatted tool message.
+        """
+        contents = cast(
+            list[ContentBlock],
+            [
+                {"text": json.dumps(content["json"])} if "json" in content else content
+                for content in tool_result["content"]
+            ],
+        )
+
+        return {
+            "role": "tool",
+            "tool_call_id": tool_result["toolUseId"],
+            "content": [self._format_request_message_content(content) for content in contents],
+        }
+
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a LlamaAPI-compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A LlamaAPI-compatible messages array.
+        """
+        formatted_messages: list[dict[str, Any]]
+        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
+
+        for message in messages:
+            contents = message["content"]
+
+            formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
+            formatted_contents = [
+                self._format_request_message_content(content)
+                for content in contents
+                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
+            ]
+            formatted_tool_calls = [
+                self._format_request_message_tool_call(content["toolUse"])
+                for content in contents
+                if "toolUse" in content
+            ]
+            formatted_tool_messages = [
+                self._format_request_tool_message(content["toolResult"])
+                for content in contents
+                if "toolResult" in content
+            ]
+
+            if message["role"] == "assistant":
+                formatted_contents = formatted_contents[0] if formatted_contents else ""
+
+            formatted_message = {
+                "role": message["role"],
+                "content": formatted_contents if len(formatted_contents) > 0 else "",
+                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
+            }
+            formatted_messages.append(formatted_message)
+            formatted_messages.extend(formatted_tool_messages)
+
+        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
+
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format a Llama API chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A Llama API chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to a LlamaAPI-compatible
+                format.
+        """
+        request = {
+            "messages": self._format_request_messages(messages, system_prompt),
+            "model": self.config["model_id"],
+            "stream": True,
+            "tools": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_spec["name"],
+                        "description": tool_spec["description"],
+                        "parameters": tool_spec["inputSchema"]["json"],
+                    },
+                }
+                for tool_spec in tool_specs or []
+            ],
+        }
+        if "temperature" in self.config:
+            request["temperature"] = self.config["temperature"]
+        if "top_p" in self.config:
+            request["top_p"] = self.config["top_p"]
+        if "repetition_penalty" in self.config:
+            request["repetition_penalty"] = self.config["repetition_penalty"]
+        if "max_completion_tokens" in self.config:
+            request["max_completion_tokens"] = self.config["max_completion_tokens"]
+        if "top_k" in self.config:
+            request["top_k"] = self.config["top_k"]
+
+        return request
+
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format the Llama API model response events into standardized message chunks.
+
+        Args:
+            event: A response event from the model.
+
+        Returns:
+            The formatted chunk.
+        """
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                return {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "name": event["data"].function.name,
+                                "toolUseId": event["data"].id,
+                            }
+                        }
+                    }
+                }
+
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                match event["data"]:
+                    case "tool_calls":
+                        return {"messageStop": {"stopReason": "tool_use"}}
+                    case "length":
+                        return {"messageStop": {"stopReason": "max_tokens"}}
+                    case _:
+                        return {"messageStop": {"stopReason": "end_turn"}}
+
+            case "metadata":
+                usage = {}
+                for metrics in event["data"]:
+                    if metrics.metric == "num_prompt_tokens":
+                        usage["inputTokens"] = metrics.value
+                    elif metrics.metric == "num_completion_tokens":
+                        usage["outputTokens"] = metrics.value
+                    elif metrics.metric == "num_total_tokens":
+                        usage["totalTokens"] = metrics.value
+
+                usage_type = Usage(
+                    inputTokens=usage["inputTokens"],
+                    outputTokens=usage["outputTokens"],
+                    totalTokens=usage["totalTokens"],
+                )
+                return {
+                    "metadata": {
+                        "usage": usage_type,
+                        "metrics": {
+                            "latencyMs": 0,  # TODO
+                        },
+                    },
+                }
+
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the LlamaAPI model.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Formatted message chunks from the model.
+
+        Raises:
+            ModelThrottledException: When the model service is throttling requests from the client.
+        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("request=<%s>", request)
+
+        logger.debug("invoking model")
+        try:
+            response = self.client.chat.completions.create(**request)
+        except llama_api_client.RateLimitError as e:
+            raise ModelThrottledException(str(e)) from e
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        stop_reason = None
+        tool_calls: dict[Any, list[Any]] = {}
+        curr_tool_call_id = None
+
+        metrics_event = None
+        for chunk in response:
+            if chunk.event.event_type == "start":
+                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+            elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
+                )
+            else:
+                if chunk.event.delta.type == "tool_call":
+                    if chunk.event.delta.id:
+                        curr_tool_call_id = chunk.event.delta.id
+
+                    if curr_tool_call_id not in tool_calls:
+                        tool_calls[curr_tool_call_id] = []
+                    tool_calls[curr_tool_call_id].append(chunk.event.delta)
+                elif chunk.event.event_type == "metrics":
+                    metrics_event = chunk.event.metrics
+                else:
+                    yield self.format_chunk(chunk)
+
+            if stop_reason is None:
+                stop_reason = chunk.event.stop_reason
+
+        # stopped generation
+        if stop_reason:
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+        for tool_deltas in tool_calls.values():
+            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_start})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": stop_reason})
+
+        # we may have a metrics event here
+        if metrics_event:
+            yield self.format_chunk({"chunk_type": "metadata", "data": metrics_event})
+
+        logger.debug("finished streaming response from model")
+
+    @override
+    def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Model events with the last being the structured output.
+
+        Raises:
+            NotImplementedError: Structured output is not currently supported for LlamaAPI models.
+        """
+        # response_format: ResponseFormat = {
+        #     "type": "json_schema",
+        #     "json_schema": {
+        #         "name": output_model.__name__,
+        #         "schema": output_model.model_json_schema(),
+        #     },
+        # }
+        # response = self.client.chat.completions.create(
+        #     model=self.config["model_id"],
+        #     messages=self.format_request(prompt)["messages"],
+        #     response_format=response_format,
+        # )
+        raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.")
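
As a small companion sketch (not part of the diff), the standardized chunk mapping above can be exercised directly for the dict-only cases; tool and metadata chunks require provider SDK objects and are omitted, and the api_key and model_id values are placeholders.

    # Sketch of LlamaAPIModel.format_chunk behavior for the plain-dict chunk types shown above.
    from agentrun_sdk.models.llamaapi import LlamaAPIModel

    model = LlamaAPIModel(
        client_args={"api_key": "<llama-api-key>"},          # placeholder credentials
        model_id="Llama-4-Maverick-17B-128E-Instruct-FP8",   # illustrative model id
    )

    assert model.format_chunk({"chunk_type": "message_start"}) == {"messageStart": {"role": "assistant"}}
    assert model.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": "Hi"}) == {
        "contentBlockDelta": {"delta": {"text": "Hi"}}
    }
    # Finish reasons other than "tool_calls" and "length" fall through to "end_turn".
    assert model.format_chunk({"chunk_type": "message_stop", "data": "stop"}) == {
        "messageStop": {"stopReason": "end_turn"}
    }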