quraite 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quraite/__init__.py +3 -3
- quraite/adapters/__init__.py +134 -134
- quraite/adapters/agno_adapter.py +159 -159
- quraite/adapters/base.py +123 -123
- quraite/adapters/bedrock_agents_adapter.py +343 -343
- quraite/adapters/flowise_adapter.py +275 -275
- quraite/adapters/google_adk_adapter.py +209 -209
- quraite/adapters/http_adapter.py +239 -239
- quraite/adapters/langflow_adapter.py +192 -192
- quraite/adapters/langgraph_adapter.py +304 -304
- quraite/adapters/langgraph_server_adapter.py +252 -252
- quraite/adapters/n8n_adapter.py +220 -220
- quraite/adapters/openai_agents_adapter.py +269 -269
- quraite/adapters/pydantic_ai_adapter.py +312 -312
- quraite/adapters/smolagents_adapter.py +152 -152
- quraite/logger.py +61 -64
- quraite/schema/message.py +91 -54
- quraite/schema/response.py +16 -16
- quraite/serve/__init__.py +1 -1
- quraite/serve/cloudflared.py +210 -210
- quraite/serve/local_agent.py +360 -360
- quraite/tracing/__init__.py +24 -24
- quraite/tracing/constants.py +16 -16
- quraite/tracing/span_exporter.py +115 -115
- quraite/tracing/span_processor.py +49 -49
- quraite/tracing/tool_extractors.py +290 -290
- quraite/tracing/trace.py +564 -494
- quraite/tracing/types.py +179 -179
- quraite/tracing/utils.py +170 -170
- quraite/utils/json_utils.py +269 -269
- {quraite-0.0.2.dist-info → quraite-0.1.0.dist-info}/METADATA +9 -9
- quraite-0.1.0.dist-info/RECORD +35 -0
- {quraite-0.0.2.dist-info → quraite-0.1.0.dist-info}/WHEEL +1 -1
- quraite/traces/traces_adk_openinference.json +0 -379
- quraite/traces/traces_agno_multi_agent.json +0 -669
- quraite/traces/traces_agno_openinference.json +0 -321
- quraite/traces/traces_crewai_openinference.json +0 -155
- quraite/traces/traces_langgraph_openinference.json +0 -349
- quraite/traces/traces_langgraph_openinference_multi_agent.json +0 -2705
- quraite/traces/traces_langgraph_traceloop.json +0 -510
- quraite/traces/traces_openai_agents_multi_agent_1.json +0 -402
- quraite/traces/traces_openai_agents_openinference.json +0 -341
- quraite/traces/traces_pydantic_openinference.json +0 -286
- quraite/traces/traces_pydantic_openinference_multi_agent_1.json +0 -399
- quraite/traces/traces_pydantic_openinference_multi_agent_2.json +0 -398
- quraite/traces/traces_smol_agents_openinference.json +0 -397
- quraite/traces/traces_smol_agents_tool_calling_openinference.json +0 -704
- quraite-0.0.2.dist-info/RECORD +0 -49
|
@@ -1,343 +1,343 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Bedrock Agents Adapter
|
|
3
|
-
https://docs.aws.amazon.com/bedrock/latest/userguide/trace-events.html
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
import asyncio
|
|
7
|
-
import json
|
|
8
|
-
import os
|
|
9
|
-
import uuid
|
|
10
|
-
from typing import Any, Dict, List, Optional, Union
|
|
11
|
-
|
|
12
|
-
import boto3
|
|
13
|
-
from botocore.exceptions import ClientError
|
|
14
|
-
|
|
15
|
-
from quraite.adapters.base import BaseAdapter
|
|
16
|
-
from quraite.logger import get_logger
|
|
17
|
-
from quraite.schema.message import (
|
|
18
|
-
AgentMessage,
|
|
19
|
-
AssistantMessage,
|
|
20
|
-
MessageContentText,
|
|
21
|
-
ToolCall,
|
|
22
|
-
ToolMessage,
|
|
23
|
-
)
|
|
24
|
-
from quraite.schema.response import AgentInvocationResponse
|
|
25
|
-
|
|
26
|
-
logger = get_logger(__name__)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class BedrockAgentsAdapter(BaseAdapter):
    """
    Bedrock Agents adapter that exposes an AWS Bedrock agent through the
    standardized adapter interface and converts its trace events to
    List[AgentMessage].

    This class wraps any Bedrock Agent and provides:
    - Asynchronous invocation via ainvoke(); the blocking boto3 call runs in a
      worker thread (no synchronous invoke() is defined in this class itself)
    - Conversion of Bedrock orchestration trace events to List[AgentMessage]
    """

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        agent_id: Optional[str] = None,
        agent_alias_id: Optional[str] = None,
        region_name: Optional[str] = None,
        agent_name: str = "Bedrock Agent",
    ):
        """
        Initialize with Bedrock agent configuration.

        Args:
            aws_access_key_id: AWS access key ID (defaults to AWS_ACCESS_KEY_ID env var)
            aws_secret_access_key: AWS secret access key (defaults to AWS_SECRET_ACCESS_KEY env var)
            aws_session_token: AWS session token (defaults to AWS_SESSION_TOKEN env var)
            agent_id: Bedrock agent ID (defaults to BEDROCK_AGENT_ID env var)
            agent_alias_id: Bedrock agent alias ID (defaults to BEDROCK_AGENT_ALIAS_ID env var)
            region_name: AWS region (defaults to AWS_REGION env var)
            agent_name: Name of the agent for trajectory metadata

        Raises:
            ValueError: If agent_id, agent_alias_id or region_name is missing
                from both the arguments and the environment.
        """
        logger.debug(
            "Initializing BedrockAgentsAdapter (agent_name=%s, region=%s)",
            agent_name,
            region_name or os.getenv("AWS_REGION"),
        )
        self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
        self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
        self.region_name = region_name or os.getenv("AWS_REGION")
        self.agent_name = agent_name

        if not all([self.agent_id, self.agent_alias_id, self.region_name]):
            raise ValueError(
                "Missing required configuration. Please provide agent_id, agent_alias_id, "
                "and region_name either as parameters or environment variables."
            )

        # Initialize the Bedrock Agents runtime client. Credentials that are still
        # None after the env-var fallback are left for boto3 to resolve itself.
        self.bedrock_client = boto3.client(
            service_name="bedrock-agent-runtime",
            region_name=self.region_name,
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
        )
        logger.info(
            "BedrockAgentsAdapter initialized (agent_id=%s, alias=%s, region=%s)",
            self.agent_id,
            self.agent_alias_id,
            self.region_name,
        )

    def _convert_bedrock_traces_to_messages(
        self,
        traces: List[Dict[str, Any]],
    ) -> List[AgentMessage]:
        """
        Convert Bedrock trace events into a message trajectory.

        Only ``trace.orchestrationTrace`` payloads are handled:
        - ``modelInvocationOutput`` -> AssistantMessage with the model's text items
        - ``invocationInput`` (KNOWLEDGE_BASE / ACTION_GROUP) -> ToolCall attached
          to the trailing AssistantMessage (or a new one if there is none)
        - ``observation`` (KNOWLEDGE_BASE / ACTION_GROUP) -> ToolMessage

        The Bedrock ``traceId`` is used as the tool call id; a call and its result
        are assumed to share the traceId of their invocationInput/observation events.

        Args:
            traces: Trace parts collected from the ``invoke_agent`` event stream.

        Returns:
            The converted messages, in trace order.
        """
        logger.debug("Converting %d Bedrock trace events to messages", len(traces))
        if not traces:
            return []

        # TODO: Handle agents with only knowledge base
        # It has a modelInvocationInput with KNOWLEDGE_BASE_RESPONSE_GENERATION that
        # has a system prompt. Discuss and decide how to handle this.

        messages: List[AgentMessage] = []
        # traceId -> ToolCall name, so an observation can report the name of the
        # call it answers.
        tool_names: Dict[str, str] = {}
        for trace in traces:
            # TODO: handle other trace types as well - https://docs.aws.amazon.com/bedrock/latest/userguide/trace-events.html#trace-understand
            orchestration_trace = trace.get("trace", {}).get("orchestrationTrace", {})
            if not orchestration_trace:
                continue

            if "modelInvocationOutput" in orchestration_trace:
                assistant_message = self._model_output_to_message(
                    orchestration_trace["modelInvocationOutput"]
                )
                if assistant_message is not None:
                    messages.append(assistant_message)
            elif "invocationInput" in orchestration_trace:
                tool_call = self._invocation_input_to_tool_call(
                    orchestration_trace["invocationInput"], tool_names
                )
                if tool_call is not None:
                    self._attach_tool_call(messages, tool_call)
            elif "observation" in orchestration_trace:
                tool_message = self._observation_to_tool_message(
                    orchestration_trace["observation"], tool_names
                )
                if tool_message is not None:
                    messages.append(tool_message)

        logger.info("Converted Bedrock traces into %d messages", len(messages))
        return messages

    @staticmethod
    def _model_output_to_message(
        model_invocation_output: Dict[str, Any],
    ) -> Optional[AssistantMessage]:
        """
        Convert a ``modelInvocationOutput`` payload into an AssistantMessage.

        ``rawResponse.content`` is parsed as JSON, and every ``{"type": "text"}``
        item of its ``content`` list becomes a MessageContentText.

        Returns:
            An AssistantMessage (``content=None`` when the response had no text
            items), or None when the raw response is missing, is not valid JSON,
            or carries no ``content`` list.
        """
        raw_response_content = model_invocation_output.get("rawResponse", {}).get(
            "content", ""
        )
        if not raw_response_content:
            return None

        try:
            parsed_content = json.loads(raw_response_content)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            # Skip this event; falling through would reuse stale/unbound content.
            logger.exception("Error parsing Bedrock raw response content")
            return None

        contents = (
            parsed_content.get("content", [])
            if isinstance(parsed_content, dict)
            else []
        )
        if not isinstance(contents, list) or not contents:
            return None

        text_content = [
            MessageContentText(type="text", text=content.get("text", ""))
            for content in contents
            if isinstance(content, dict) and content.get("type") == "text"
        ]
        # TODO: Revisit this later. Ideally "tool_use" items would become ToolCalls
        # here, but their id does not come in the invocationInput, so tool calls
        # are built from invocationInput events using the traceId as the id.

        return AssistantMessage(content=text_content or None, tool_calls=None)

    @staticmethod
    def _invocation_input_to_tool_call(
        invocation_input: Dict[str, Any],
        tool_names: Dict[str, str],
    ) -> Optional[ToolCall]:
        """
        Convert an ``invocationInput`` payload into a ToolCall.

        The call's name is also recorded in ``tool_names`` under its traceId.

        Returns:
            A ToolCall for KNOWLEDGE_BASE and ACTION_GROUP invocations, otherwise None.
        """
        invocation_type = invocation_input.get("invocationType")
        trace_id = invocation_input.get("traceId", "")

        if invocation_type == "KNOWLEDGE_BASE":
            kb_input = invocation_input.get("knowledgeBaseLookupInput", {})
            name = "knowledge_base_lookup"
            arguments = {
                "text": kb_input.get("text", ""),
                "knowledgeBaseId": kb_input.get("knowledgeBaseId", ""),
            }
        elif invocation_type == "ACTION_GROUP":
            action_group_input = invocation_input.get("actionGroupInvocationInput", {})
            # Function-schema groups report "function"; OpenAPI-schema groups
            # report "apiPath" instead.
            operation = action_group_input.get("function") or action_group_input.get(
                "apiPath", ""
            )
            name = f"{action_group_input.get('actionGroupName', '')}/{operation}"
            arguments = {
                p["name"]: p.get("value")
                for p in action_group_input.get("parameters") or []
                if isinstance(p, dict) and "name" in p
            }
        else:
            return None

        tool_names[trace_id] = name
        return ToolCall(id=trace_id, name=name, arguments=arguments)

    @staticmethod
    def _attach_tool_call(messages: List[AgentMessage], tool_call: ToolCall) -> None:
        """
        Attach ``tool_call`` to the trailing AssistantMessage in ``messages``, or
        append a new AssistantMessage holding it when the list does not end with one.

        This puts the text of the preceding model output and the tool calls that
        follow it into the same message.
        """
        last_message = messages[-1] if messages else None
        if isinstance(last_message, AssistantMessage):
            if last_message.tool_calls:
                last_message.tool_calls.append(tool_call)
            else:
                last_message.tool_calls = [tool_call]
        else:
            messages.append(AssistantMessage(tool_calls=[tool_call]))

    @staticmethod
    def _observation_to_tool_message(
        observation: Dict[str, Any],
        tool_names: Dict[str, str],
    ) -> Optional[ToolMessage]:
        """
        Convert an ``observation`` payload into a ToolMessage.

        Args:
            observation: The ``orchestrationTrace.observation`` dict of one event.
            tool_names: ToolCall names recorded per traceId by earlier invocations.

        Returns:
            A ToolMessage for KNOWLEDGE_BASE and ACTION_GROUP observations, otherwise None.
        """
        observation_type = observation.get("type")
        trace_id = observation.get("traceId", "")

        if observation_type == "KNOWLEDGE_BASE":
            kb_output = observation.get("knowledgeBaseLookupOutput", {})
            tool_name = "knowledge_base_lookup"
            tool_result = json.dumps(kb_output.get("retrievedReferences", []))
        elif observation_type == "ACTION_GROUP":
            action_group_output = observation.get("actionGroupInvocationOutput", {})
            raw_text = action_group_output.get("text") or ""
            try:
                parsed = json.loads(raw_text)
            except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
                parsed = None
            # Unwrap a JSON-encoded string; otherwise keep the tool's raw text so
            # structured results stay valid JSON (str() of a parsed dict would
            # yield a Python repr instead).
            tool_result = parsed if isinstance(parsed, str) else raw_text
            # Report the same name as the ToolCall this result answers.
            tool_name = tool_names.get(trace_id, "action_group_invocation")
        else:
            return None

        return ToolMessage(
            tool_name=tool_name,
            tool_call_id=trace_id,
            content=[MessageContentText(type="text", text=str(tool_result))],
        )

    def _prepare_input(self, input_data: List[AgentMessage]) -> str:
        """
        Extract the prompt text from the last message of ``input_data``.

        Returns:
            The first non-empty text content item of the last message, or ""
            (with an error logged) when the last message is not from the user.

        Raises:
            ValueError: If ``input_data`` is empty, or the last user message has no
                content or no non-empty text content item.
        """
        logger.debug("Preparing Bedrock input from %d messages", len(input_data))
        if not input_data:
            logger.error("No input messages provided")
            raise ValueError("No input messages provided")

        last_user_message = input_data[-1]
        if last_user_message.role != "user":
            logger.error("Last message is not from user")
            return ""
        # Check if content list is not empty and has text
        if not last_user_message.content:
            logger.error("User message has no content")
            raise ValueError("User message has no content")
        # Find the first text content item
        for content_item in last_user_message.content:
            if content_item.type == "text" and content_item.text:
                logger.debug(
                    "Prepared Bedrock input (text_length=%d)", len(content_item.text)
                )
                return content_item.text
        raise ValueError("No text content found in user message")

    def _run_agent(
        self, session_id: str, prompt: str
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Run the Bedrock agent (blocking) and collect its answer and traces.

        Args:
            session_id: Unique session identifier
            prompt: Input prompt for the agent

        Returns:
            A ``(answer, traces)`` tuple. ``answer`` is the concatenation of all
            streamed ``chunk`` payloads, decoded as UTF-8. ``traces`` holds every
            ``trace`` part of the event stream. On a ClientError the error is
            logged and ``("", [])`` is returned.
        """
        answer_chunks: List[bytes] = []
        traces: List[Dict[str, Any]] = []
        try:
            logger.debug(
                "Invoking Bedrock agent (session_id=%s, prompt_length=%d)",
                session_id,
                len(prompt),
            )
            response = self.bedrock_client.invoke_agent(
                agentId=self.agent_id,
                agentAliasId=self.agent_alias_id,
                sessionId=session_id,
                inputText=prompt,
                enableTrace=True,
            )

            for event in response["completion"]:
                if "chunk" in event:
                    # The answer may arrive split over several chunks; keep them all.
                    answer_chunks.append(event["chunk"]["bytes"])
                event_trace = event.get("trace")
                if event_trace:
                    traces.append(event_trace)

        except ClientError:
            logger.exception("Error invoking Bedrock agent via Bedrock runtime")
            return "", []

        logger.info(
            "Bedrock agent invocation succeeded (session_id=%s, trace_events=%d)",
            session_id,
            len(traces),
        )
        # Decode once so a multi-byte character split across chunks stays intact.
        return b"".join(answer_chunks).decode("utf8"), traces

    async def ainvoke(
        self,
        input: List[AgentMessage],
        session_id: Union[str, None],
    ) -> AgentInvocationResponse:
        """Asynchronously invoke the Bedrock agent and convert its traces to List[AgentMessage].

        Args:
            input: List of AgentMessage objects; the prompt is taken from the last one.
            session_id: Unique session identifier; a random UUID4 is used when falsy.

        Returns:
            AgentInvocationResponse carrying the agent trajectory. An empty
            AgentInvocationResponse is returned (and the error logged) when the
            invocation or the trace conversion fails.

        Raises:
            ValueError: Propagated from input preparation when ``input`` is empty
                or its last user message has no text content.
        """
        logger.info(
            "Bedrock ainvoke called (session_id=%s, input_messages=%d)",
            session_id,
            len(input),
        )
        agent_input = self._prepare_input(input)
        session_id = session_id or str(uuid.uuid4())

        try:
            # Run the blocking boto3 call in a worker thread to keep the loop free.
            _, traces = await asyncio.to_thread(
                self._run_agent, session_id, agent_input
            )
            logger.debug(
                "Bedrock agent run returned %d trace events for session_id=%s",
                len(traces),
                session_id,
            )
        except (ClientError, ValueError, KeyError, json.JSONDecodeError):
            logger.exception("Error invoking Bedrock agent")
            return AgentInvocationResponse()

        try:
            agent_trajectory = self._convert_bedrock_traces_to_messages(traces)
            logger.info(
                "Bedrock agent produced %d trajectory messages",
                len(agent_trajectory),
            )

            return AgentInvocationResponse(
                agent_trajectory=agent_trajectory,
            )

        except (ClientError, ValueError, KeyError, json.JSONDecodeError):
            logger.exception("Error converting Bedrock traces to messages")
            return AgentInvocationResponse()
|
|
1
|
+
"""
|
|
2
|
+
Bedrock Agents Adapter
|
|
3
|
+
https://docs.aws.amazon.com/bedrock/latest/userguide/trace-events.html
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import Any, Dict, List, Optional, Union
|
|
11
|
+
|
|
12
|
+
import boto3
|
|
13
|
+
from botocore.exceptions import ClientError
|
|
14
|
+
|
|
15
|
+
from quraite.adapters.base import BaseAdapter
|
|
16
|
+
from quraite.logger import get_logger
|
|
17
|
+
from quraite.schema.message import (
|
|
18
|
+
AgentMessage,
|
|
19
|
+
AssistantMessage,
|
|
20
|
+
MessageContentText,
|
|
21
|
+
ToolCall,
|
|
22
|
+
ToolMessage,
|
|
23
|
+
)
|
|
24
|
+
from quraite.schema.response import AgentInvocationResponse
|
|
25
|
+
|
|
26
|
+
logger = get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class BedrockAgentsAdapter(BaseAdapter):
    """
    Bedrock Agents adapter that exposes an AWS Bedrock agent through the
    standardized adapter interface and converts its trace events to
    List[AgentMessage].

    This class wraps any Bedrock Agent and provides:
    - Asynchronous invocation via ainvoke(); the blocking boto3 call runs in a
      worker thread (no synchronous invoke() is defined in this class itself)
    - Conversion of Bedrock orchestration trace events to List[AgentMessage]
    """

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        agent_id: Optional[str] = None,
        agent_alias_id: Optional[str] = None,
        region_name: Optional[str] = None,
        agent_name: str = "Bedrock Agent",
    ):
        """
        Initialize with Bedrock agent configuration.

        Args:
            aws_access_key_id: AWS access key ID (defaults to AWS_ACCESS_KEY_ID env var)
            aws_secret_access_key: AWS secret access key (defaults to AWS_SECRET_ACCESS_KEY env var)
            aws_session_token: AWS session token (defaults to AWS_SESSION_TOKEN env var)
            agent_id: Bedrock agent ID (defaults to BEDROCK_AGENT_ID env var)
            agent_alias_id: Bedrock agent alias ID (defaults to BEDROCK_AGENT_ALIAS_ID env var)
            region_name: AWS region (defaults to AWS_REGION env var)
            agent_name: Name of the agent for trajectory metadata

        Raises:
            ValueError: If agent_id, agent_alias_id or region_name is missing
                from both the arguments and the environment.
        """
        logger.debug(
            "Initializing BedrockAgentsAdapter (agent_name=%s, region=%s)",
            agent_name,
            region_name or os.getenv("AWS_REGION"),
        )
        self.agent_id = agent_id or os.getenv("BEDROCK_AGENT_ID")
        self.agent_alias_id = agent_alias_id or os.getenv("BEDROCK_AGENT_ALIAS_ID")
        self.region_name = region_name or os.getenv("AWS_REGION")
        self.agent_name = agent_name

        if not all([self.agent_id, self.agent_alias_id, self.region_name]):
            raise ValueError(
                "Missing required configuration. Please provide agent_id, agent_alias_id, "
                "and region_name either as parameters or environment variables."
            )

        # Initialize the Bedrock Agents runtime client. Credentials that are still
        # None after the env-var fallback are left for boto3 to resolve itself.
        self.bedrock_client = boto3.client(
            service_name="bedrock-agent-runtime",
            region_name=self.region_name,
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
        )
        logger.info(
            "BedrockAgentsAdapter initialized (agent_id=%s, alias=%s, region=%s)",
            self.agent_id,
            self.agent_alias_id,
            self.region_name,
        )

    def _convert_bedrock_traces_to_messages(
        self,
        traces: List[Dict[str, Any]],
    ) -> List[AgentMessage]:
        """
        Convert Bedrock trace events into a message trajectory.

        Only ``trace.orchestrationTrace`` payloads are handled:
        - ``modelInvocationOutput`` -> AssistantMessage with the model's text items
        - ``invocationInput`` (KNOWLEDGE_BASE / ACTION_GROUP) -> ToolCall attached
          to the trailing AssistantMessage (or a new one if there is none)
        - ``observation`` (KNOWLEDGE_BASE / ACTION_GROUP) -> ToolMessage

        The Bedrock ``traceId`` is used as the tool call id; a call and its result
        are assumed to share the traceId of their invocationInput/observation events.

        Args:
            traces: Trace parts collected from the ``invoke_agent`` event stream.

        Returns:
            The converted messages, in trace order.
        """
        logger.debug("Converting %d Bedrock trace events to messages", len(traces))
        if not traces:
            return []

        # TODO: Handle agents with only knowledge base
        # It has a modelInvocationInput with KNOWLEDGE_BASE_RESPONSE_GENERATION that
        # has a system prompt. Discuss and decide how to handle this.

        messages: List[AgentMessage] = []
        # traceId -> ToolCall name, so an observation can report the name of the
        # call it answers.
        tool_names: Dict[str, str] = {}
        for trace in traces:
            # TODO: handle other trace types as well - https://docs.aws.amazon.com/bedrock/latest/userguide/trace-events.html#trace-understand
            orchestration_trace = trace.get("trace", {}).get("orchestrationTrace", {})
            if not orchestration_trace:
                continue

            if "modelInvocationOutput" in orchestration_trace:
                assistant_message = self._model_output_to_message(
                    orchestration_trace["modelInvocationOutput"]
                )
                if assistant_message is not None:
                    messages.append(assistant_message)
            elif "invocationInput" in orchestration_trace:
                tool_call = self._invocation_input_to_tool_call(
                    orchestration_trace["invocationInput"], tool_names
                )
                if tool_call is not None:
                    self._attach_tool_call(messages, tool_call)
            elif "observation" in orchestration_trace:
                tool_message = self._observation_to_tool_message(
                    orchestration_trace["observation"], tool_names
                )
                if tool_message is not None:
                    messages.append(tool_message)

        logger.info("Converted Bedrock traces into %d messages", len(messages))
        return messages

    @staticmethod
    def _model_output_to_message(
        model_invocation_output: Dict[str, Any],
    ) -> Optional[AssistantMessage]:
        """
        Convert a ``modelInvocationOutput`` payload into an AssistantMessage.

        ``rawResponse.content`` is parsed as JSON, and every ``{"type": "text"}``
        item of its ``content`` list becomes a MessageContentText.

        Returns:
            An AssistantMessage (``content=None`` when the response had no text
            items), or None when the raw response is missing, is not valid JSON,
            or carries no ``content`` list.
        """
        raw_response_content = model_invocation_output.get("rawResponse", {}).get(
            "content", ""
        )
        if not raw_response_content:
            return None

        try:
            parsed_content = json.loads(raw_response_content)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            # Skip this event; falling through would reuse stale/unbound content.
            logger.exception("Error parsing Bedrock raw response content")
            return None

        contents = (
            parsed_content.get("content", [])
            if isinstance(parsed_content, dict)
            else []
        )
        if not isinstance(contents, list) or not contents:
            return None

        text_content = [
            MessageContentText(type="text", text=content.get("text", ""))
            for content in contents
            if isinstance(content, dict) and content.get("type") == "text"
        ]
        # TODO: Revisit this later. Ideally "tool_use" items would become ToolCalls
        # here, but their id does not come in the invocationInput, so tool calls
        # are built from invocationInput events using the traceId as the id.

        return AssistantMessage(content=text_content or None, tool_calls=None)

    @staticmethod
    def _invocation_input_to_tool_call(
        invocation_input: Dict[str, Any],
        tool_names: Dict[str, str],
    ) -> Optional[ToolCall]:
        """
        Convert an ``invocationInput`` payload into a ToolCall.

        The call's name is also recorded in ``tool_names`` under its traceId.

        Returns:
            A ToolCall for KNOWLEDGE_BASE and ACTION_GROUP invocations, otherwise None.
        """
        invocation_type = invocation_input.get("invocationType")
        trace_id = invocation_input.get("traceId", "")

        if invocation_type == "KNOWLEDGE_BASE":
            kb_input = invocation_input.get("knowledgeBaseLookupInput", {})
            name = "knowledge_base_lookup"
            arguments = {
                "text": kb_input.get("text", ""),
                "knowledgeBaseId": kb_input.get("knowledgeBaseId", ""),
            }
        elif invocation_type == "ACTION_GROUP":
            action_group_input = invocation_input.get("actionGroupInvocationInput", {})
            # Function-schema groups report "function"; OpenAPI-schema groups
            # report "apiPath" instead.
            operation = action_group_input.get("function") or action_group_input.get(
                "apiPath", ""
            )
            name = f"{action_group_input.get('actionGroupName', '')}/{operation}"
            arguments = {
                p["name"]: p.get("value")
                for p in action_group_input.get("parameters") or []
                if isinstance(p, dict) and "name" in p
            }
        else:
            return None

        tool_names[trace_id] = name
        return ToolCall(id=trace_id, name=name, arguments=arguments)

    @staticmethod
    def _attach_tool_call(messages: List[AgentMessage], tool_call: ToolCall) -> None:
        """
        Attach ``tool_call`` to the trailing AssistantMessage in ``messages``, or
        append a new AssistantMessage holding it when the list does not end with one.

        This puts the text of the preceding model output and the tool calls that
        follow it into the same message.
        """
        last_message = messages[-1] if messages else None
        if isinstance(last_message, AssistantMessage):
            if last_message.tool_calls:
                last_message.tool_calls.append(tool_call)
            else:
                last_message.tool_calls = [tool_call]
        else:
            messages.append(AssistantMessage(tool_calls=[tool_call]))

    @staticmethod
    def _observation_to_tool_message(
        observation: Dict[str, Any],
        tool_names: Dict[str, str],
    ) -> Optional[ToolMessage]:
        """
        Convert an ``observation`` payload into a ToolMessage.

        Args:
            observation: The ``orchestrationTrace.observation`` dict of one event.
            tool_names: ToolCall names recorded per traceId by earlier invocations.

        Returns:
            A ToolMessage for KNOWLEDGE_BASE and ACTION_GROUP observations, otherwise None.
        """
        observation_type = observation.get("type")
        trace_id = observation.get("traceId", "")

        if observation_type == "KNOWLEDGE_BASE":
            kb_output = observation.get("knowledgeBaseLookupOutput", {})
            tool_name = "knowledge_base_lookup"
            tool_result = json.dumps(kb_output.get("retrievedReferences", []))
        elif observation_type == "ACTION_GROUP":
            action_group_output = observation.get("actionGroupInvocationOutput", {})
            raw_text = action_group_output.get("text") or ""
            try:
                parsed = json.loads(raw_text)
            except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
                parsed = None
            # Unwrap a JSON-encoded string; otherwise keep the tool's raw text so
            # structured results stay valid JSON (str() of a parsed dict would
            # yield a Python repr instead).
            tool_result = parsed if isinstance(parsed, str) else raw_text
            # Report the same name as the ToolCall this result answers.
            tool_name = tool_names.get(trace_id, "action_group_invocation")
        else:
            return None

        return ToolMessage(
            tool_name=tool_name,
            tool_call_id=trace_id,
            content=[MessageContentText(type="text", text=str(tool_result))],
        )

    def _prepare_input(self, input_data: List[AgentMessage]) -> str:
        """
        Extract the prompt text from the last message of ``input_data``.

        Returns:
            The first non-empty text content item of the last message, or ""
            (with an error logged) when the last message is not from the user.

        Raises:
            ValueError: If ``input_data`` is empty, or the last user message has no
                content or no non-empty text content item.
        """
        logger.debug("Preparing Bedrock input from %d messages", len(input_data))
        if not input_data:
            logger.error("No input messages provided")
            raise ValueError("No input messages provided")

        last_user_message = input_data[-1]
        if last_user_message.role != "user":
            logger.error("Last message is not from user")
            return ""
        # Check if content list is not empty and has text
        if not last_user_message.content:
            logger.error("User message has no content")
            raise ValueError("User message has no content")
        # Find the first text content item
        for content_item in last_user_message.content:
            if content_item.type == "text" and content_item.text:
                logger.debug(
                    "Prepared Bedrock input (text_length=%d)", len(content_item.text)
                )
                return content_item.text
        raise ValueError("No text content found in user message")

    def _run_agent(
        self, session_id: str, prompt: str
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Run the Bedrock agent (blocking) and collect its answer and traces.

        Args:
            session_id: Unique session identifier
            prompt: Input prompt for the agent

        Returns:
            A ``(answer, traces)`` tuple. ``answer`` is the concatenation of all
            streamed ``chunk`` payloads, decoded as UTF-8. ``traces`` holds every
            ``trace`` part of the event stream. On a ClientError the error is
            logged and ``("", [])`` is returned.
        """
        answer_chunks: List[bytes] = []
        traces: List[Dict[str, Any]] = []
        try:
            logger.debug(
                "Invoking Bedrock agent (session_id=%s, prompt_length=%d)",
                session_id,
                len(prompt),
            )
            response = self.bedrock_client.invoke_agent(
                agentId=self.agent_id,
                agentAliasId=self.agent_alias_id,
                sessionId=session_id,
                inputText=prompt,
                enableTrace=True,
            )

            for event in response["completion"]:
                if "chunk" in event:
                    # The answer may arrive split over several chunks; keep them all.
                    answer_chunks.append(event["chunk"]["bytes"])
                event_trace = event.get("trace")
                if event_trace:
                    traces.append(event_trace)

        except ClientError:
            logger.exception("Error invoking Bedrock agent via Bedrock runtime")
            return "", []

        logger.info(
            "Bedrock agent invocation succeeded (session_id=%s, trace_events=%d)",
            session_id,
            len(traces),
        )
        # Decode once so a multi-byte character split across chunks stays intact.
        return b"".join(answer_chunks).decode("utf8"), traces

    async def ainvoke(
        self,
        input: List[AgentMessage],
        session_id: Union[str, None],
    ) -> AgentInvocationResponse:
        """Asynchronously invoke the Bedrock agent and convert its traces to List[AgentMessage].

        Args:
            input: List of AgentMessage objects; the prompt is taken from the last one.
            session_id: Unique session identifier; a random UUID4 is used when falsy.

        Returns:
            AgentInvocationResponse carrying the agent trajectory. An empty
            AgentInvocationResponse is returned (and the error logged) when the
            invocation or the trace conversion fails.

        Raises:
            ValueError: Propagated from input preparation when ``input`` is empty
                or its last user message has no text content.
        """
        logger.info(
            "Bedrock ainvoke called (session_id=%s, input_messages=%d)",
            session_id,
            len(input),
        )
        agent_input = self._prepare_input(input)
        session_id = session_id or str(uuid.uuid4())

        try:
            # Run the blocking boto3 call in a worker thread to keep the loop free.
            _, traces = await asyncio.to_thread(
                self._run_agent, session_id, agent_input
            )
            logger.debug(
                "Bedrock agent run returned %d trace events for session_id=%s",
                len(traces),
                session_id,
            )
        except (ClientError, ValueError, KeyError, json.JSONDecodeError):
            logger.exception("Error invoking Bedrock agent")
            return AgentInvocationResponse()

        try:
            agent_trajectory = self._convert_bedrock_traces_to_messages(traces)
            logger.info(
                "Bedrock agent produced %d trajectory messages",
                len(agent_trajectory),
            )

            return AgentInvocationResponse(
                agent_trajectory=agent_trajectory,
            )

        except (ClientError, ValueError, KeyError, json.JSONDecodeError):
            logger.exception("Error converting Bedrock traces to messages")
            return AgentInvocationResponse()
|