agentrun-sdk 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agentrun-sdk might be problematic. Click here for more details.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""LiteLLM model provider.
|
|
2
|
+
|
|
3
|
+
- Docs: https://docs.litellm.ai/
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast
|
|
9
|
+
|
|
10
|
+
import litellm
|
|
11
|
+
from litellm.utils import supports_response_schema
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from typing_extensions import Unpack, override
|
|
14
|
+
|
|
15
|
+
from ..types.content import ContentBlock, Messages
|
|
16
|
+
from ..types.streaming import StreamEvent
|
|
17
|
+
from ..types.tools import ToolSpec
|
|
18
|
+
from .openai import OpenAIModel
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)

# Pydantic model type accepted by `structured_output` (via `output_model`) and yielded back as its output.
T = TypeVar("T", bound=BaseModel)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LiteLLMModel(OpenAIModel):
    """LiteLLM model provider implementation.

    Request formatting (`format_request`) and chunk formatting (`format_chunk`) are inherited from `OpenAIModel`;
    requests are sent through `litellm.acompletion`.
    """

    class LiteLLMConfig(TypedDict, total=False):
        """Configuration options for LiteLLM models.

        Attributes:
            model_id: Model ID (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet").
                For a complete list of supported models, see https://docs.litellm.ai/docs/providers.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://docs.litellm.ai/docs/completion/input#input-params-1.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the LiteLLM client. They are forwarded as keyword arguments to every
                `litellm.acompletion` call made by this instance.
                For a complete list of supported arguments, see
                https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
            **model_config: Configuration options for the LiteLLM model.
        """
        self.client_args = client_args or {}
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

    @override
    def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
        """Update the LiteLLM model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> LiteLLMConfig:
        """Get the LiteLLM model configuration.

        Returns:
            The LiteLLM model configuration.
        """
        return cast(LiteLLMModel.LiteLLMConfig, self.config)

    @override
    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format a LiteLLM content block.

        Reasoning and video blocks are handled here; every other block type is delegated to
        `OpenAIModel.format_request_message_content`.

        Args:
            content: Message content.

        Returns:
            LiteLLM formatted content block.

        Raises:
            TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
        """
        if "reasoningContent" in content:
            reasoning_text = content["reasoningContent"]["reasoningText"]
            thinking_block: dict[str, Any] = {
                "thinking": reasoning_text["text"],
                "type": "thinking",
            }
            # The signature key may be absent: reasoning captured by `stream` arrives as plain
            # `reasoning_content` deltas with no signature, so indexing it unconditionally raised KeyError
            # when that message was sent back on the next turn.
            if "signature" in reasoning_text:
                thinking_block["signature"] = reasoning_text["signature"]
            return thinking_block

        if "video" in content:
            # NOTE(review): the raw source bytes are passed as the URL value; confirm LiteLLM accepts this
            # rather than requiring a URL or base64 data URI.
            return {
                "type": "video_url",
                "video_url": {
                    "detail": "auto",
                    "url": content["video"]["source"]["bytes"],
                },
            }

        return super().format_request_message_content(content)

    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        """Stream conversation with the LiteLLM model.

        Text and reasoning deltas are yielded as they arrive. Tool-call deltas are buffered per tool-call index
        and emitted after the text block closes. The usage payload, if any, is read from the last event of the
        stream.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Formatted message chunks from the model.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)
        logger.debug("request=<%s>", request)

        logger.debug("invoking model")
        response = await litellm.acompletion(**self.client_args, **request)

        logger.debug("got response from model")
        yield self.format_chunk({"chunk_type": "message_start"})
        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

        tool_calls: dict[int, list[Any]] = {}
        # Tracked explicitly so an empty stream (or one whose events all lack choices) does not leave these
        # names unbound when the stop/usage chunks are emitted below.
        finish_reason: Optional[str] = None
        last_event: Any = None

        async for event in response:
            last_event = event

            # Defensive: skip events with empty or missing choices
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield self.format_chunk(
                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
                )

            reasoning_content = getattr(choice.delta, "reasoning_content", None)
            if reasoning_content:
                yield self.format_chunk(
                    {
                        "chunk_type": "content_delta",
                        "data_type": "reasoning_content",
                        "data": reasoning_content,
                    }
                )

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                finish_reason = choice.finish_reason
                break

        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

        for tool_deltas in tool_calls.values():
            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

            for tool_delta in tool_deltas:
                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})

        # Skip remaining events as we don't have use for anything except the final usage payload
        async for event in response:
            last_event = event

        # Not every chunk object exposes `usage`, so read it defensively.
        usage = getattr(last_event, "usage", None)
        if usage:
            yield self.format_chunk({"chunk_type": "metadata", "data": usage})

        logger.debug("finished streaming response from model")

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Model events with the last being the structured output.

        Raises:
            ValueError: If the model does not support `response_format`, the response has multiple choices, no
                choice finished with `tool_calls`, or the content cannot be parsed into `output_model`.
        """
        model_id = self.get_config()["model_id"]

        # Check capability before calling the provider so unsupported models fail without issuing a request.
        if not supports_response_schema(model_id):
            raise ValueError("Model does not support response_format")

        response = await litellm.acompletion(
            **self.client_args,
            model=model_id,
            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
            response_format=output_model,
        )

        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the response.")

        # Find the first choice with tool_calls
        for choice in response.choices:
            if choice.finish_reason == "tool_calls":
                try:
                    # Parse the message content as JSON and validate it against the output model.
                    tool_call_data = json.loads(choice.message.content)
                    output = output_model(**tool_call_data)
                except (json.JSONDecodeError, TypeError, ValueError) as e:
                    raise ValueError(f"Failed to parse or load content into model: {e}") from e

                # Yield outside the try so errors raised by the consumer are not re-labelled as parse failures.
                yield {"output": output}
                return

        # If no tool_calls found, raise an error
        raise ValueError("No tool_calls found in response")
|
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
2
|
+
"""Llama API model provider.
|
|
3
|
+
|
|
4
|
+
- Docs: https://llama.developer.meta.com/
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import mimetypes
|
|
11
|
+
from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast
|
|
12
|
+
|
|
13
|
+
import llama_api_client
|
|
14
|
+
from llama_api_client import LlamaAPIClient
|
|
15
|
+
from pydantic import BaseModel
|
|
16
|
+
from typing_extensions import TypedDict, Unpack, override
|
|
17
|
+
|
|
18
|
+
from ..types.content import ContentBlock, Messages
|
|
19
|
+
from ..types.exceptions import ModelThrottledException
|
|
20
|
+
from ..types.streaming import StreamEvent, Usage
|
|
21
|
+
from ..types.tools import ToolResult, ToolSpec, ToolUse
|
|
22
|
+
from .model import Model
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)

# Pydantic model type accepted by `structured_output` (via `output_model`).
T = TypeVar("T", bound=BaseModel)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LlamaAPIModel(Model):
|
|
30
|
+
"""Llama API model provider implementation."""
|
|
31
|
+
|
|
32
|
+
class LlamaConfig(TypedDict, total=False):
|
|
33
|
+
"""Configuration options for Llama API models.
|
|
34
|
+
|
|
35
|
+
Attributes:
|
|
36
|
+
model_id: Model ID (e.g., "Llama-4-Maverick-17B-128E-Instruct-FP8").
|
|
37
|
+
repetition_penalty: Repetition penalty.
|
|
38
|
+
temperature: Temperature.
|
|
39
|
+
top_p: Top-p.
|
|
40
|
+
max_completion_tokens: Maximum completion tokens.
|
|
41
|
+
top_k: Top-k.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
model_id: str
|
|
45
|
+
repetition_penalty: Optional[float]
|
|
46
|
+
temperature: Optional[float]
|
|
47
|
+
top_p: Optional[float]
|
|
48
|
+
max_completion_tokens: Optional[int]
|
|
49
|
+
top_k: Optional[int]
|
|
50
|
+
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
*,
|
|
54
|
+
client_args: Optional[dict[str, Any]] = None,
|
|
55
|
+
**model_config: Unpack[LlamaConfig],
|
|
56
|
+
) -> None:
|
|
57
|
+
"""Initialize provider instance.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
client_args: Arguments for the Llama API client.
|
|
61
|
+
**model_config: Configuration options for the Llama API model.
|
|
62
|
+
"""
|
|
63
|
+
self.config = LlamaAPIModel.LlamaConfig(**model_config)
|
|
64
|
+
logger.debug("config=<%s> | initializing", self.config)
|
|
65
|
+
|
|
66
|
+
if not client_args:
|
|
67
|
+
self.client = LlamaAPIClient()
|
|
68
|
+
else:
|
|
69
|
+
self.client = LlamaAPIClient(**client_args)
|
|
70
|
+
|
|
71
|
+
@override
|
|
72
|
+
def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: ignore
|
|
73
|
+
"""Update the Llama API Model configuration with the provided arguments.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
**model_config: Configuration overrides.
|
|
77
|
+
"""
|
|
78
|
+
self.config.update(model_config)
|
|
79
|
+
|
|
80
|
+
@override
|
|
81
|
+
def get_config(self) -> LlamaConfig:
|
|
82
|
+
"""Get the Llama API model configuration.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
The Llama API model configuration.
|
|
86
|
+
"""
|
|
87
|
+
return self.config
|
|
88
|
+
|
|
89
|
+
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
|
|
90
|
+
"""Format a LlamaAPI content block.
|
|
91
|
+
|
|
92
|
+
- NOTE: "reasoningContent" and "video" are not supported currently.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
content: Message content.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
LllamaAPI formatted content block.
|
|
99
|
+
|
|
100
|
+
Raises:
|
|
101
|
+
TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format.
|
|
102
|
+
"""
|
|
103
|
+
if "image" in content:
|
|
104
|
+
mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
|
|
105
|
+
image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
|
|
106
|
+
|
|
107
|
+
return {
|
|
108
|
+
"image_url": {
|
|
109
|
+
"url": f"data:{mime_type};base64,{image_data}",
|
|
110
|
+
},
|
|
111
|
+
"type": "image_url",
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
if "text" in content:
|
|
115
|
+
return {"text": content["text"], "type": "text"}
|
|
116
|
+
|
|
117
|
+
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
|
|
118
|
+
|
|
119
|
+
def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
|
|
120
|
+
"""Format a Llama API tool call.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
tool_use: Tool use requested by the model.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
Llama API formatted tool call.
|
|
127
|
+
"""
|
|
128
|
+
return {
|
|
129
|
+
"function": {
|
|
130
|
+
"arguments": json.dumps(tool_use["input"]),
|
|
131
|
+
"name": tool_use["name"],
|
|
132
|
+
},
|
|
133
|
+
"id": tool_use["toolUseId"],
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
|
|
137
|
+
"""Format a Llama API tool message.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
tool_result: Tool result collected from a tool execution.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
Llama API formatted tool message.
|
|
144
|
+
"""
|
|
145
|
+
contents = cast(
|
|
146
|
+
list[ContentBlock],
|
|
147
|
+
[
|
|
148
|
+
{"text": json.dumps(content["json"])} if "json" in content else content
|
|
149
|
+
for content in tool_result["content"]
|
|
150
|
+
],
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
return {
|
|
154
|
+
"role": "tool",
|
|
155
|
+
"tool_call_id": tool_result["toolUseId"],
|
|
156
|
+
"content": [self._format_request_message_content(content) for content in contents],
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
|
|
160
|
+
"""Format a LlamaAPI compatible messages array.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
messages: List of message objects to be processed by the model.
|
|
164
|
+
system_prompt: System prompt to provide context to the model.
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
An LlamaAPI compatible messages array.
|
|
168
|
+
"""
|
|
169
|
+
formatted_messages: list[dict[str, Any]]
|
|
170
|
+
formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
|
|
171
|
+
|
|
172
|
+
for message in messages:
|
|
173
|
+
contents = message["content"]
|
|
174
|
+
|
|
175
|
+
formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
|
|
176
|
+
formatted_contents = [
|
|
177
|
+
self._format_request_message_content(content)
|
|
178
|
+
for content in contents
|
|
179
|
+
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
|
|
180
|
+
]
|
|
181
|
+
formatted_tool_calls = [
|
|
182
|
+
self._format_request_message_tool_call(content["toolUse"])
|
|
183
|
+
for content in contents
|
|
184
|
+
if "toolUse" in content
|
|
185
|
+
]
|
|
186
|
+
formatted_tool_messages = [
|
|
187
|
+
self._format_request_tool_message(content["toolResult"])
|
|
188
|
+
for content in contents
|
|
189
|
+
if "toolResult" in content
|
|
190
|
+
]
|
|
191
|
+
|
|
192
|
+
if message["role"] == "assistant":
|
|
193
|
+
formatted_contents = formatted_contents[0] if formatted_contents else ""
|
|
194
|
+
|
|
195
|
+
formatted_message = {
|
|
196
|
+
"role": message["role"],
|
|
197
|
+
"content": formatted_contents if len(formatted_contents) > 0 else "",
|
|
198
|
+
**({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
|
|
199
|
+
}
|
|
200
|
+
formatted_messages.append(formatted_message)
|
|
201
|
+
formatted_messages.extend(formatted_tool_messages)
|
|
202
|
+
|
|
203
|
+
return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
|
|
204
|
+
|
|
205
|
+
def format_request(
|
|
206
|
+
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
|
|
207
|
+
) -> dict[str, Any]:
|
|
208
|
+
"""Format a Llama API chat streaming request.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
messages: List of message objects to be processed by the model.
|
|
212
|
+
tool_specs: List of tool specifications to make available to the model.
|
|
213
|
+
system_prompt: System prompt to provide context to the model.
|
|
214
|
+
|
|
215
|
+
Returns:
|
|
216
|
+
An Llama API chat streaming request.
|
|
217
|
+
|
|
218
|
+
Raises:
|
|
219
|
+
TypeError: If a message contains a content block type that cannot be converted to a LlamaAPI-compatible
|
|
220
|
+
format.
|
|
221
|
+
"""
|
|
222
|
+
request = {
|
|
223
|
+
"messages": self._format_request_messages(messages, system_prompt),
|
|
224
|
+
"model": self.config["model_id"],
|
|
225
|
+
"stream": True,
|
|
226
|
+
"tools": [
|
|
227
|
+
{
|
|
228
|
+
"type": "function",
|
|
229
|
+
"function": {
|
|
230
|
+
"name": tool_spec["name"],
|
|
231
|
+
"description": tool_spec["description"],
|
|
232
|
+
"parameters": tool_spec["inputSchema"]["json"],
|
|
233
|
+
},
|
|
234
|
+
}
|
|
235
|
+
for tool_spec in tool_specs or []
|
|
236
|
+
],
|
|
237
|
+
}
|
|
238
|
+
if "temperature" in self.config:
|
|
239
|
+
request["temperature"] = self.config["temperature"]
|
|
240
|
+
if "top_p" in self.config:
|
|
241
|
+
request["top_p"] = self.config["top_p"]
|
|
242
|
+
if "repetition_penalty" in self.config:
|
|
243
|
+
request["repetition_penalty"] = self.config["repetition_penalty"]
|
|
244
|
+
if "max_completion_tokens" in self.config:
|
|
245
|
+
request["max_completion_tokens"] = self.config["max_completion_tokens"]
|
|
246
|
+
if "top_k" in self.config:
|
|
247
|
+
request["top_k"] = self.config["top_k"]
|
|
248
|
+
|
|
249
|
+
return request
|
|
250
|
+
|
|
251
|
+
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
|
|
252
|
+
"""Format the Llama API model response events into standardized message chunks.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
event: A response event from the model.
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
The formatted chunk.
|
|
259
|
+
"""
|
|
260
|
+
match event["chunk_type"]:
|
|
261
|
+
case "message_start":
|
|
262
|
+
return {"messageStart": {"role": "assistant"}}
|
|
263
|
+
|
|
264
|
+
case "content_start":
|
|
265
|
+
if event["data_type"] == "text":
|
|
266
|
+
return {"contentBlockStart": {"start": {}}}
|
|
267
|
+
|
|
268
|
+
return {
|
|
269
|
+
"contentBlockStart": {
|
|
270
|
+
"start": {
|
|
271
|
+
"toolUse": {
|
|
272
|
+
"name": event["data"].function.name,
|
|
273
|
+
"toolUseId": event["data"].id,
|
|
274
|
+
}
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
case "content_delta":
|
|
280
|
+
if event["data_type"] == "text":
|
|
281
|
+
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
|
|
282
|
+
|
|
283
|
+
return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
|
|
284
|
+
|
|
285
|
+
case "content_stop":
|
|
286
|
+
return {"contentBlockStop": {}}
|
|
287
|
+
|
|
288
|
+
case "message_stop":
|
|
289
|
+
match event["data"]:
|
|
290
|
+
case "tool_calls":
|
|
291
|
+
return {"messageStop": {"stopReason": "tool_use"}}
|
|
292
|
+
case "length":
|
|
293
|
+
return {"messageStop": {"stopReason": "max_tokens"}}
|
|
294
|
+
case _:
|
|
295
|
+
return {"messageStop": {"stopReason": "end_turn"}}
|
|
296
|
+
|
|
297
|
+
case "metadata":
|
|
298
|
+
usage = {}
|
|
299
|
+
for metrics in event["data"]:
|
|
300
|
+
if metrics.metric == "num_prompt_tokens":
|
|
301
|
+
usage["inputTokens"] = metrics.value
|
|
302
|
+
elif metrics.metric == "num_completion_tokens":
|
|
303
|
+
usage["outputTokens"] = metrics.value
|
|
304
|
+
elif metrics.metric == "num_total_tokens":
|
|
305
|
+
usage["totalTokens"] = metrics.value
|
|
306
|
+
|
|
307
|
+
usage_type = Usage(
|
|
308
|
+
inputTokens=usage["inputTokens"],
|
|
309
|
+
outputTokens=usage["outputTokens"],
|
|
310
|
+
totalTokens=usage["totalTokens"],
|
|
311
|
+
)
|
|
312
|
+
return {
|
|
313
|
+
"metadata": {
|
|
314
|
+
"usage": usage_type,
|
|
315
|
+
"metrics": {
|
|
316
|
+
"latencyMs": 0, # TODO
|
|
317
|
+
},
|
|
318
|
+
},
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
case _:
|
|
322
|
+
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
|
|
323
|
+
|
|
324
|
+
@override
|
|
325
|
+
async def stream(
|
|
326
|
+
self,
|
|
327
|
+
messages: Messages,
|
|
328
|
+
tool_specs: Optional[list[ToolSpec]] = None,
|
|
329
|
+
system_prompt: Optional[str] = None,
|
|
330
|
+
**kwargs: Any,
|
|
331
|
+
) -> AsyncGenerator[StreamEvent, None]:
|
|
332
|
+
"""Stream conversation with the LlamaAPI model.
|
|
333
|
+
|
|
334
|
+
Args:
|
|
335
|
+
messages: List of message objects to be processed by the model.
|
|
336
|
+
tool_specs: List of tool specifications to make available to the model.
|
|
337
|
+
system_prompt: System prompt to provide context to the model.
|
|
338
|
+
**kwargs: Additional keyword arguments for future extensibility.
|
|
339
|
+
|
|
340
|
+
Yields:
|
|
341
|
+
Formatted message chunks from the model.
|
|
342
|
+
|
|
343
|
+
Raises:
|
|
344
|
+
ModelThrottledException: When the model service is throttling requests from the client.
|
|
345
|
+
"""
|
|
346
|
+
logger.debug("formatting request")
|
|
347
|
+
request = self.format_request(messages, tool_specs, system_prompt)
|
|
348
|
+
logger.debug("request=<%s>", request)
|
|
349
|
+
|
|
350
|
+
logger.debug("invoking model")
|
|
351
|
+
try:
|
|
352
|
+
response = self.client.chat.completions.create(**request)
|
|
353
|
+
except llama_api_client.RateLimitError as e:
|
|
354
|
+
raise ModelThrottledException(str(e)) from e
|
|
355
|
+
|
|
356
|
+
logger.debug("got response from model")
|
|
357
|
+
yield self.format_chunk({"chunk_type": "message_start"})
|
|
358
|
+
|
|
359
|
+
stop_reason = None
|
|
360
|
+
tool_calls: dict[Any, list[Any]] = {}
|
|
361
|
+
curr_tool_call_id = None
|
|
362
|
+
|
|
363
|
+
metrics_event = None
|
|
364
|
+
for chunk in response:
|
|
365
|
+
if chunk.event.event_type == "start":
|
|
366
|
+
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
|
|
367
|
+
elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
|
|
368
|
+
yield self.format_chunk(
|
|
369
|
+
{"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
|
|
370
|
+
)
|
|
371
|
+
else:
|
|
372
|
+
if chunk.event.delta.type == "tool_call":
|
|
373
|
+
if chunk.event.delta.id:
|
|
374
|
+
curr_tool_call_id = chunk.event.delta.id
|
|
375
|
+
|
|
376
|
+
if curr_tool_call_id not in tool_calls:
|
|
377
|
+
tool_calls[curr_tool_call_id] = []
|
|
378
|
+
tool_calls[curr_tool_call_id].append(chunk.event.delta)
|
|
379
|
+
elif chunk.event.event_type == "metrics":
|
|
380
|
+
metrics_event = chunk.event.metrics
|
|
381
|
+
else:
|
|
382
|
+
yield self.format_chunk(chunk)
|
|
383
|
+
|
|
384
|
+
if stop_reason is None:
|
|
385
|
+
stop_reason = chunk.event.stop_reason
|
|
386
|
+
|
|
387
|
+
# stopped generation
|
|
388
|
+
if stop_reason:
|
|
389
|
+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
|
|
390
|
+
|
|
391
|
+
for tool_deltas in tool_calls.values():
|
|
392
|
+
tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
|
|
393
|
+
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_start})
|
|
394
|
+
|
|
395
|
+
for tool_delta in tool_deltas:
|
|
396
|
+
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
|
|
397
|
+
|
|
398
|
+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
|
|
399
|
+
|
|
400
|
+
yield self.format_chunk({"chunk_type": "message_stop", "data": stop_reason})
|
|
401
|
+
|
|
402
|
+
# we may have a metrics event here
|
|
403
|
+
if metrics_event:
|
|
404
|
+
yield self.format_chunk({"chunk_type": "metadata", "data": metrics_event})
|
|
405
|
+
|
|
406
|
+
logger.debug("finished streaming response from model")
|
|
407
|
+
|
|
408
|
+
@override
|
|
409
|
+
def structured_output(
|
|
410
|
+
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
|
|
411
|
+
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
|
|
412
|
+
"""Get structured output from the model.
|
|
413
|
+
|
|
414
|
+
Args:
|
|
415
|
+
output_model: The output model to use for the agent.
|
|
416
|
+
prompt: The prompt messages to use for the agent.
|
|
417
|
+
system_prompt: System prompt to provide context to the model.
|
|
418
|
+
**kwargs: Additional keyword arguments for future extensibility.
|
|
419
|
+
|
|
420
|
+
Yields:
|
|
421
|
+
Model events with the last being the structured output.
|
|
422
|
+
|
|
423
|
+
Raises:
|
|
424
|
+
NotImplementedError: Structured output is not currently supported for LlamaAPI models.
|
|
425
|
+
"""
|
|
426
|
+
# response_format: ResponseFormat = {
|
|
427
|
+
# "type": "json_schema",
|
|
428
|
+
# "json_schema": {
|
|
429
|
+
# "name": output_model.__name__,
|
|
430
|
+
# "schema": output_model.model_json_schema(),
|
|
431
|
+
# },
|
|
432
|
+
# }
|
|
433
|
+
# response = self.client.chat.completions.create(
|
|
434
|
+
# model=self.config["model_id"],
|
|
435
|
+
# messages=self.format_request(prompt)["messages"],
|
|
436
|
+
# response_format=response_format,
|
|
437
|
+
# )
|
|
438
|
+
raise NotImplementedError("Strands sdk-python does not implement this in the Llama API Preview.")
|