argus-cloud-optimizer 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adapters/__init__.py +0 -0
- adapters/aws/__init__.py +0 -0
- adapters/aws/adapter.py +85 -0
- adapters/aws/auth.py +57 -0
- adapters/aws/cloudtrail.py +83 -0
- adapters/aws/cloudwatch.py +732 -0
- adapters/aws/config.py +9 -0
- adapters/aws/cost_explorer.py +116 -0
- adapters/aws/resource_explorer.py +186 -0
- adapters/aws/retry.py +55 -0
- adapters/azure/__init__.py +0 -0
- adapters/azure/activity_log.py +159 -0
- adapters/azure/adapter.py +117 -0
- adapters/azure/cost_management.py +125 -0
- adapters/azure/monitor.py +311 -0
- adapters/azure/resource_graph.py +113 -0
- adapters/azure/retry.py +57 -0
- adapters/base.py +105 -0
- adapters/gcp/__init__.py +0 -0
- adapters/gcp/adapter.py +86 -0
- adapters/gcp/asset_inventory.py +116 -0
- adapters/gcp/billing.py +118 -0
- adapters/gcp/cloud_logging.py +93 -0
- adapters/gcp/cloud_monitoring.py +276 -0
- adapters/gcp/retry.py +46 -0
- ai/__init__.py +0 -0
- ai/anthropic.py +174 -0
- ai/azure_openai.py +241 -0
- ai/base.py +78 -0
- ai/bedrock.py +169 -0
- ai/vertexai.py +234 -0
- argus_cloud_optimizer-0.2.0.dist-info/METADATA +433 -0
- argus_cloud_optimizer-0.2.0.dist-info/RECORD +62 -0
- argus_cloud_optimizer-0.2.0.dist-info/WHEEL +5 -0
- argus_cloud_optimizer-0.2.0.dist-info/entry_points.txt +2 -0
- argus_cloud_optimizer-0.2.0.dist-info/licenses/LICENSE +21 -0
- argus_cloud_optimizer-0.2.0.dist-info/top_level.txt +4 -0
- core/__init__.py +0 -0
- core/__version__.py +1 -0
- core/agent/__init__.py +0 -0
- core/agent/loop.py +390 -0
- core/agent/prompts.py +317 -0
- core/config.py +235 -0
- core/log.py +69 -0
- core/models/__init__.py +0 -0
- core/models/finding.py +76 -0
- core/py.typed +0 -0
- core/reports/__init__.py +0 -0
- core/reports/comparison.py +49 -0
- core/reports/delivery.py +323 -0
- core/reports/export.py +111 -0
- core/reports/generator.py +168 -0
- core/reports/html.py +286 -0
- core/reports/multi_cloud.py +162 -0
- core/secrets.py +145 -0
- core/token_tracker.py +97 -0
- core/validation.py +214 -0
- entrypoints/__init__.py +0 -0
- entrypoints/aws_lambda.py +299 -0
- entrypoints/azure_function.py +257 -0
- entrypoints/cli.py +156 -0
- entrypoints/gcp_cloudrun.py +209 -0
ai/azure_openai.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AI provider backed by Azure OpenAI (GPT-4o).
|
|
3
|
+
|
|
4
|
+
Authentication uses DefaultAzureCredential — run:
|
|
5
|
+
az login
|
|
6
|
+
|
|
7
|
+
No API key needed when running on Azure Functions with a managed identity.
|
|
8
|
+
|
|
9
|
+
Environment variables:
|
|
10
|
+
AZURE_OPENAI_ENDPOINT Azure OpenAI resource endpoint (required)
|
|
11
|
+
e.g. https://my-resource.openai.azure.com/
|
|
12
|
+
AZURE_OPENAI_DEPLOYMENT Deployment name (default: gpt-4o)
|
|
13
|
+
AZURE_OPENAI_API_VERSION API version (default: 2024-10-21)
|
|
14
|
+
AZURE_OPENAI_API_KEY Optional — use only for local dev without az login
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import time
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import openai
|
|
24
|
+
import structlog
|
|
25
|
+
|
|
26
|
+
from ai.base import AIProvider, AIResponse, Message, Tool, ToolCall
|
|
27
|
+
|
|
28
|
+
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)

# Total number of attempts made for a rate-limited request before the
# rate-limit error is re-raised to the caller.
MAX_RETRIES = 3
# Seconds slept before the first retry; the delay doubles after each retry.
_BASE_DELAY = 1.0
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class AzureOpenAIProvider(AIProvider):
    """
    AI provider backed by Azure OpenAI (default deployment: GPT-4o).

    Authentication is decided once, in ``__init__``:

    * if an API key is available (``api_key=`` argument or the configured
      ``azure_openai_api_key``), the client uses key authentication;
    * otherwise it uses ``DefaultAzureCredential`` (managed identity /
      ``az login``) through a bearer-token provider, so no API key is needed
      when running on Azure infrastructure.
    """

    DEFAULT_DEPLOYMENT = "gpt-4o"
    DEFAULT_API_VERSION = "2024-10-21"
    DEFAULT_MAX_TOKENS = 4096
    DEFAULT_TEMPERATURE = 0.0

    def __init__(
        self,
        endpoint: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        api_key: str | None = None,
        temperature: float | None = None,
    ) -> None:
        """
        Build the Azure OpenAI client.

        Arguments left as ``None`` (or empty) are read from
        ``get_settings().ai``.

        Args:
            endpoint: Azure OpenAI resource endpoint.
            deployment: Deployment name sent as the ``model`` of each request.
            api_version: Azure OpenAI API version.
            max_tokens: Completion token limit sent with each request.
            api_key: Optional API key; when absent, DefaultAzureCredential is used.
            temperature: Sampling temperature; ``None`` means "use settings".

        Raises:
            EnvironmentError: if no endpoint is passed and none is configured.
        """
        # Imported lazily, as in the original module layout.
        from core.config import get_settings

        cfg = get_settings().ai
        self._endpoint = endpoint or cfg.azure_openai_endpoint
        if not self._endpoint:
            raise EnvironmentError(
                "AZURE_OPENAI_ENDPOINT is not set. "
                "Set it in .env or pass endpoint= explicitly. "
                "Example: https://my-resource.openai.azure.com/"
            )
        self._deployment = deployment or cfg.resolved_model("azure_openai")
        self._api_version = api_version or cfg.azure_openai_api_version
        self._max_tokens = max_tokens
        # ``is not None`` so that an explicit temperature of 0.0 is honoured.
        self._temperature = temperature if temperature is not None else cfg.temperature

        resolved_key = api_key or cfg.azure_openai_api_key
        if resolved_key:
            self._client = openai.AzureOpenAI(
                azure_endpoint=self._endpoint,
                api_key=resolved_key,
                api_version=self._api_version,
            )
            self._credential = None
        else:
            # DefaultAzureCredential: works with managed identity, az login, env vars
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider

            credential = DefaultAzureCredential()
            token_provider = get_bearer_token_provider(
                credential,
                "https://cognitiveservices.azure.com/.default",
            )
            self._client = openai.AzureOpenAI(
                azure_endpoint=self._endpoint,
                azure_ad_token_provider=token_provider,
                api_version=self._api_version,
            )
            self._credential = credential

    @classmethod
    def from_env(cls) -> "AzureOpenAIProvider":
        """Create a provider configured entirely from settings (no explicit arguments)."""
        return cls()

    def chat(
        self,
        messages: list[Message],
        tools: list[Tool],
        system_prompt: str | None = None,
    ) -> AIResponse:
        """
        Send the conversation to the deployment and return the parsed reply.

        Args:
            messages: Conversation turns in the provider-neutral format.
            tools: Tools the model may call; an empty list sends no tool config.
            system_prompt: Optional system message placed first in the request.

        Returns:
            The response converted to ``AIResponse``.
        """
        kwargs: dict[str, Any] = {
            "model": self._deployment,
            "messages": self._build_messages(messages, system_prompt),
            "max_tokens": self._max_tokens,
            "temperature": self._temperature,
        }
        if tools:
            kwargs["tools"] = [self._to_openai_tool(t) for t in tools]
            kwargs["tool_choice"] = "auto"

        response = self._call_with_retry(kwargs)
        return self._parse_response(response)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _call_with_retry(self, kwargs: dict[str, Any]) -> Any:
        """
        Call the chat-completions endpoint, retrying on rate limiting.

        A ``RateLimitError`` is retried with a doubling delay until
        ``MAX_RETRIES`` attempts have been made, after which it is re-raised.

        Raises:
            openai.RateLimitError: when the last attempt is still rate limited.
            EnvironmentError: when the service rejects the credentials.
        """
        delay = _BASE_DELAY
        for attempt in range(MAX_RETRIES):
            try:
                return self._client.chat.completions.create(**kwargs)
            except openai.RateLimitError:
                if attempt >= MAX_RETRIES - 1:
                    raise
                logger.warning(
                    "Azure OpenAI rate limited (attempt %d/%d), retrying in %.1fs",
                    attempt + 1,
                    MAX_RETRIES,
                    delay,
                )
                time.sleep(delay)
                delay *= 2
            except openai.AuthenticationError as exc:
                raise EnvironmentError(
                    "Azure OpenAI authentication failed. "
                    "Run 'az login' or set AZURE_OPENAI_API_KEY."
                ) from exc
        raise RuntimeError("Unreachable")  # pragma: no cover

    def _build_messages(
        self,
        messages: list[Message],
        system_prompt: str | None,
    ) -> list[dict[str, Any]]:
        """
        Convert provider-neutral messages into OpenAI chat message dicts.

        * A user message carrying tool results becomes one ``tool`` message
          per result (its ``text`` is not sent in that case).
        * Any other user message becomes a plain ``user`` message.
        * Every non-user message is treated as an assistant turn, with its
          tool calls serialised as ``function`` tool calls.
        """
        result: list[dict[str, Any]] = []

        if system_prompt:
            result.append({"role": "system", "content": system_prompt})

        for msg in messages:
            if msg.role == "user":
                if msg.tool_results:
                    result.extend(
                        {
                            "role": "tool",
                            "tool_call_id": tr.tool_call_id,
                            "content": tr.content,
                        }
                        for tr in msg.tool_results
                    )
                else:
                    result.append({"role": "user", "content": msg.text or ""})
                continue

            # assistant
            assistant_msg: dict[str, Any] = {
                "role": "assistant",
                "content": msg.text or "",
            }
            tool_calls_out = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.name,
                        # The API expects arguments as a JSON string, not a dict.
                        "arguments": json.dumps(tc.arguments),
                    },
                }
                for tc in msg.tool_calls
            ]
            # Only attach the key when there is at least one call.
            if tool_calls_out:
                assistant_msg["tool_calls"] = tool_calls_out
            result.append(assistant_msg)

        return result

    def _to_openai_tool(self, tool: Tool) -> dict[str, Any]:
        """Wrap a provider-neutral tool definition as an OpenAI ``function`` tool."""
        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema,
            },
        }

    @staticmethod
    def _decode_arguments(raw: str | None) -> dict[str, Any]:
        """
        Decode the JSON argument string of a tool call.

        A missing or blank string is treated as "no arguments" and yields an
        empty dict instead of failing in ``json.loads``. Non-blank input that
        is not valid JSON still raises ``json.JSONDecodeError``.
        """
        if raw is None or not raw.strip():
            return {}
        return json.loads(raw)

    def _parse_response(self, response: Any) -> AIResponse:
        """
        Convert a chat-completions response into ``AIResponse``.

        Only the first choice is read. The finish reason is normalised to the
        internal vocabulary: any tool call gives ``"tool_use"``, ``"stop"``
        becomes ``"end_turn"`` and ``"length"`` becomes ``"max_tokens"``;
        other values are passed through unchanged.
        """
        choice = response.choices[0]
        message = choice.message
        stop_reason = choice.finish_reason  # "stop" | "tool_calls" | "length"

        # Empty content is reported as None rather than "".
        text: str | None = message.content or None
        tool_calls: list[ToolCall] = [
            ToolCall(
                id=tc.id,
                name=tc.function.name,
                arguments=self._decode_arguments(tc.function.arguments),
            )
            for tc in (message.tool_calls or [])
        ]

        # Normalise to internal vocabulary
        if tool_calls:
            stop_reason = "tool_use"
        elif stop_reason == "stop":
            stop_reason = "end_turn"
        elif stop_reason == "length":
            stop_reason = "max_tokens"

        usage = getattr(response, "usage", None)
        return AIResponse(
            stop_reason=stop_reason,
            text=text,
            tool_calls=tool_calls,
            input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0,
            output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0,
        )
|
ai/base.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ToolCall:
    """
    One tool invocation requested by the AI.

    Attributes:
        id: Identifier of the call, used to pair it with its result.
        name: Name of the tool to run.
        arguments: Arguments for the tool, as a dictionary.
    """

    id: str
    name: str
    arguments: dict[str, Any]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ToolResult:
    """
    Outcome of running a tool call.

    Attributes:
        tool_call_id: Identifier of the tool call this result answers.
        content: Result payload as text.
        is_error: Whether the tool execution failed; defaults to ``False``.
    """

    tool_call_id: str
    content: str
    is_error: bool = field(default=False)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class Message:
    """
    One turn of the agent conversation.

    Attributes:
        role: ``"user"`` or ``"assistant"``.
        text: Plain text of the turn, if any.
        tool_calls: Tool invocations requested in this turn (empty by default).
        tool_results: Results of earlier tool calls (empty by default).

    Exactly one of text, tool_calls, or tool_results is expected to be
    populated.
    """

    role: str
    text: str | None = field(default=None)
    # default_factory gives each message its own list instead of a shared one.
    tool_calls: list[ToolCall] = field(default_factory=list)
    tool_results: list[ToolResult] = field(default_factory=list)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class Tool:
    """
    Description of a tool that the AI is allowed to call.

    Attributes:
        name: Tool name.
        description: Description of the tool.
        input_schema: Schema of the tool's input, as a dictionary.
    """

    name: str
    description: str
    input_schema: dict[str, Any]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class AIResponse:
    """
    Provider-neutral view of one AI reply.

    Attributes:
        stop_reason: Why generation stopped, e.g. ``"tool_use"``,
            ``"end_turn"`` or ``"max_tokens"``.
        text: Text of the reply, or ``None`` when there is none.
        tool_calls: Tool invocations requested by the reply.
        input_tokens: Input token count; defaults to 0.
        output_tokens: Output token count; defaults to 0.
    """

    stop_reason: str  # "tool_use" | "end_turn" | "max_tokens"
    text: str | None
    tool_calls: list[ToolCall]
    input_tokens: int = field(default=0)
    output_tokens: int = field(default=0)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AIProvider(ABC):
    """
    Base class for AI back ends, with one implementation per model family.

    The agent loop talks to a provider only through ``chat()`` and never
    through raw SDK methods.
    """

    @abstractmethod
    def chat(
        self,
        messages: list[Message],
        tools: list[Tool],
        system_prompt: str | None = None,
    ) -> AIResponse:
        """
        Send the conversation to the AI and return its response.

        The system prompt is a separate argument so that each provider can
        pass it the way its API expects (e.g. Anthropic has a dedicated
        system parameter).

        Args:
            messages: Conversation turns so far.
            tools: Tools the AI may call.
            system_prompt: Optional system prompt.

        Returns:
            The parsed response.
        """
|
ai/bedrock.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import boto3
|
|
7
|
+
import structlog
|
|
8
|
+
from botocore.exceptions import ClientError
|
|
9
|
+
|
|
10
|
+
from ai.base import AIProvider, AIResponse, Message, Tool, ToolCall
|
|
11
|
+
|
|
12
|
+
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)

# Total number of attempts made for a throttled request before the error is
# re-raised to the caller.
MAX_RETRIES = 3
# Seconds slept before the first retry; the delay doubles after each retry.
_BASE_DELAY = 1.0  # seconds; doubles each retry
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BedrockProvider(AIProvider):
    """
    AI provider backed by the Amazon Bedrock Converse API.

    Credentials come from the boto3 session (the execution role when running
    inside Lambda), so no API keys are needed. The model id and region are
    taken from the constructor arguments or, failing that, from settings
    (BEDROCK_MODEL_ID and BEDROCK_REGION env vars).
    """

    DEFAULT_MODEL = "anthropic.claude-sonnet-4-6"
    DEFAULT_REGION = "us-east-1"
    DEFAULT_MAX_TOKENS = 4096
    DEFAULT_TEMPERATURE = 0.0

    def __init__(
        self,
        model_id: str | None = None,
        region: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float | None = None,
        session: Any = None,
    ) -> None:
        """
        Build the ``bedrock-runtime`` client.

        Args:
            model_id: Bedrock model id; falls back to settings when omitted.
            region: AWS region; falls back to settings when omitted.
            max_tokens: ``maxTokens`` sent with every request.
            temperature: Sampling temperature; ``None`` means "use settings".
            session: Optional pre-built boto3 session; a new one is created
                when omitted.
        """
        # Imported lazily, as in the original module layout.
        from core.config import get_settings

        cfg = get_settings().ai
        self._model_id = model_id or cfg.resolved_model("bedrock")
        resolved_region = region or cfg.bedrock_region
        self._max_tokens = max_tokens
        # ``is not None`` so that an explicit temperature of 0.0 is honoured.
        self._temperature = temperature if temperature is not None else cfg.temperature

        boto_session = session or boto3.Session(region_name=resolved_region)
        self._client = boto_session.client(
            "bedrock-runtime", region_name=resolved_region
        )

    def chat(
        self,
        messages: list[Message],
        tools: list[Tool],
        system_prompt: str | None = None,
    ) -> AIResponse:
        """
        Send the conversation through the Converse API and return the parsed reply.

        Args:
            messages: Conversation turns in the provider-neutral format.
            tools: Tools the model may call; an empty list sends no ``toolConfig``.
            system_prompt: Optional system prompt, sent in the ``system`` field.

        Returns:
            The response converted to ``AIResponse``.
        """
        kwargs: dict[str, Any] = {
            "modelId": self._model_id,
            "messages": [self._to_bedrock_message(m) for m in messages],
            "inferenceConfig": {
                "maxTokens": self._max_tokens,
                "temperature": self._temperature,
            },
        }
        if system_prompt:
            kwargs["system"] = [{"text": system_prompt}]
        if tools:
            kwargs["toolConfig"] = {"tools": [self._to_bedrock_tool(t) for t in tools]}

        response = self._call_with_retry(kwargs)
        return self._parse_response(response)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _call_with_retry(self, kwargs: dict[str, Any]) -> Any:
        """
        Call ``converse``, retrying on throttling.

        A ``ClientError`` whose code is ``ThrottlingException`` is retried
        with a doubling delay until ``MAX_RETRIES`` attempts have been made.
        Any other ``ClientError``, or throttling on the last attempt, is
        re-raised unchanged.
        """
        delay = _BASE_DELAY
        for attempt in range(MAX_RETRIES):
            try:
                return self._client.converse(**kwargs)
            except ClientError as exc:
                # .get() so that an error payload without "Error"/"Code" does
                # not raise KeyError here and hide the original ClientError.
                code = exc.response.get("Error", {}).get("Code")
                if code != "ThrottlingException" or attempt >= MAX_RETRIES - 1:
                    raise
                logger.warning(
                    "Bedrock throttled (attempt %d/%d), retrying in %.1fs",
                    attempt + 1,
                    MAX_RETRIES,
                    delay,
                )
                time.sleep(delay)
                delay *= 2
        raise RuntimeError("Unreachable")  # pragma: no cover

    def _to_bedrock_message(self, msg: Message) -> dict[str, Any]:
        """
        Convert one provider-neutral message into a Converse message dict.

        * A user message carrying tool results becomes a user message with
          one ``toolResult`` block per result (its ``text`` is not sent);
          failed results are marked with ``"status": "error"``.
        * Any other user message becomes a single text block.
        * Every non-user message is treated as an assistant turn: an optional
          text block followed by one ``toolUse`` block per tool call.
        """
        if msg.role == "user":
            if not msg.tool_results:
                return {"role": "user", "content": [{"text": msg.text or ""}]}

            result_blocks: list[dict[str, Any]] = []
            for tr in msg.tool_results:
                tool_result: dict[str, Any] = {
                    "toolUseId": tr.tool_call_id,
                    "content": [{"text": tr.content}],
                }
                # The status key is only sent for failures.
                if tr.is_error:
                    tool_result["status"] = "error"
                result_blocks.append({"toolResult": tool_result})
            return {"role": "user", "content": result_blocks}

        # assistant
        content: list[dict[str, Any]] = []
        if msg.text:
            content.append({"text": msg.text})
        content.extend(
            {
                "toolUse": {
                    "toolUseId": tc.id,
                    "name": tc.name,
                    "input": tc.arguments,
                }
            }
            for tc in msg.tool_calls
        )
        return {"role": "assistant", "content": content}

    def _to_bedrock_tool(self, tool: Tool) -> dict[str, Any]:
        """Wrap a provider-neutral tool definition as a Converse ``toolSpec``."""
        return {
            "toolSpec": {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": {"json": tool.input_schema},
            }
        }

    def _parse_response(self, response: dict[str, Any]) -> AIResponse:
        """
        Convert a Converse response into ``AIResponse``.

        Tool-use blocks become ``ToolCall`` objects. All text blocks are
        concatenated in order, so a reply split over several text blocks is
        returned whole; ``text`` is ``None`` when there is no text block.
        ``stopReason`` is passed through and defaults to ``"end_turn"`` when
        it is absent.
        """
        content_blocks: list[dict[str, Any]] = (
            response.get("output", {}).get("message", {}).get("content", [])
        )
        stop_reason: str = response.get("stopReason", "end_turn")

        tool_calls: list[ToolCall] = []
        text_parts: list[str] = []

        for block in content_blocks:
            if "toolUse" in block:
                tu = block["toolUse"]
                tool_calls.append(
                    ToolCall(
                        id=tu["toolUseId"],
                        name=tu["name"],
                        # ``or {}`` also covers an explicit null input.
                        arguments=dict(tu.get("input") or {}),
                    )
                )
            elif "text" in block:
                text_parts.append(block["text"])

        text: str | None = "".join(text_parts) if text_parts else None

        usage = response.get("usage", {})
        return AIResponse(
            stop_reason=stop_reason,
            text=text,
            tool_calls=tool_calls,
            input_tokens=usage.get("inputTokens", 0),
            output_tokens=usage.get("outputTokens", 0),
        )
|