quraite 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quraite/__init__.py +3 -3
- quraite/adapters/__init__.py +134 -134
- quraite/adapters/agno_adapter.py +157 -159
- quraite/adapters/base.py +123 -123
- quraite/adapters/bedrock_agents_adapter.py +343 -343
- quraite/adapters/flowise_adapter.py +275 -275
- quraite/adapters/google_adk_adapter.py +211 -209
- quraite/adapters/http_adapter.py +255 -239
- quraite/adapters/{langgraph_adapter.py → langchain_adapter.py} +305 -304
- quraite/adapters/{langgraph_server_adapter.py → langchain_server_adapter.py} +252 -252
- quraite/adapters/langflow_adapter.py +192 -192
- quraite/adapters/n8n_adapter.py +220 -220
- quraite/adapters/openai_agents_adapter.py +267 -269
- quraite/adapters/pydantic_ai_adapter.py +307 -312
- quraite/adapters/smolagents_adapter.py +148 -152
- quraite/logger.py +61 -61
- quraite/schema/message.py +91 -91
- quraite/schema/response.py +16 -16
- quraite/serve/__init__.py +1 -1
- quraite/serve/cloudflared.py +210 -210
- quraite/serve/local_agent.py +360 -360
- quraite/traces/traces_adk_openinference.json +379 -0
- quraite/traces/traces_agno_multi_agent.json +669 -0
- quraite/traces/traces_agno_openinference.json +321 -0
- quraite/traces/traces_crewai_openinference.json +155 -0
- quraite/traces/traces_langgraph_openinference.json +349 -0
- quraite/traces/traces_langgraph_openinference_multi_agent.json +2705 -0
- quraite/traces/traces_langgraph_traceloop.json +510 -0
- quraite/traces/traces_openai_agents_multi_agent_1.json +402 -0
- quraite/traces/traces_openai_agents_openinference.json +341 -0
- quraite/traces/traces_pydantic_openinference.json +286 -0
- quraite/traces/traces_pydantic_openinference_multi_agent_1.json +399 -0
- quraite/traces/traces_pydantic_openinference_multi_agent_2.json +398 -0
- quraite/traces/traces_smol_agents_openinference.json +397 -0
- quraite/traces/traces_smol_agents_tool_calling_openinference.json +704 -0
- quraite/tracing/__init__.py +25 -24
- quraite/tracing/constants.py +15 -16
- quraite/tracing/span_exporter.py +101 -115
- quraite/tracing/span_processor.py +47 -49
- quraite/tracing/tool_extractors.py +309 -290
- quraite/tracing/trace.py +564 -564
- quraite/tracing/types.py +179 -179
- quraite/tracing/utils.py +170 -170
- quraite/utils/json_utils.py +269 -269
- quraite-0.1.2.dist-info/METADATA +386 -0
- quraite-0.1.2.dist-info/RECORD +49 -0
- {quraite-0.1.0.dist-info → quraite-0.1.2.dist-info}/WHEEL +1 -1
- quraite-0.1.0.dist-info/METADATA +0 -44
- quraite-0.1.0.dist-info/RECORD +0 -35
|
@@ -1,252 +1,252 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from typing import Annotated, Any, List, Optional, Union
|
|
4
|
-
|
|
5
|
-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
6
|
-
from langgraph.pregel.remote import RemoteGraph
|
|
7
|
-
from langgraph_sdk import get_client, get_sync_client
|
|
8
|
-
from pydantic import Discriminator
|
|
9
|
-
|
|
10
|
-
from quraite.adapters.base import BaseAdapter
|
|
11
|
-
from quraite.logger import get_logger
|
|
12
|
-
from quraite.schema.message import AgentMessage, AssistantMessage, MessageContentText
|
|
13
|
-
from quraite.schema.message import SystemMessage as QuraiteSystemMessage
|
|
14
|
-
from quraite.schema.message import ToolCall, ToolMessage, UserMessage
|
|
15
|
-
from quraite.schema.response import AgentInvocationResponse
|
|
16
|
-
|
|
17
|
-
LangchainMessage = Annotated[
|
|
18
|
-
Union[HumanMessage, SystemMessage, AIMessage, ToolMessage],
|
|
19
|
-
Discriminator(discriminator="type"),
|
|
20
|
-
]
|
|
21
|
-
|
|
22
|
-
logger = get_logger(__name__)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class
|
|
26
|
-
"""Remote
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
base_url: The base URL of the
|
|
30
|
-
assistant_id: The ID of the assistant to invoke
|
|
31
|
-
**kwargs: Additional keyword arguments passed directly to
|
|
32
|
-
langgraph_sdk.get_client() and get_sync_client().
|
|
33
|
-
Common options include:
|
|
34
|
-
- api_key: API key for authentication
|
|
35
|
-
- headers: Additional HTTP headers
|
|
36
|
-
- timeout: Request timeout configuration
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
def __init__(
|
|
40
|
-
self,
|
|
41
|
-
*,
|
|
42
|
-
base_url: str,
|
|
43
|
-
assistant_id: Optional[str] = None,
|
|
44
|
-
graph_name: Optional[str] = None,
|
|
45
|
-
**kwargs,
|
|
46
|
-
) -> None:
|
|
47
|
-
self.base_url = base_url
|
|
48
|
-
self.assistant_id = assistant_id
|
|
49
|
-
self.graph_name = graph_name
|
|
50
|
-
|
|
51
|
-
logger.debug(
|
|
52
|
-
"Initializing
|
|
53
|
-
base_url,
|
|
54
|
-
assistant_id,
|
|
55
|
-
graph_name,
|
|
56
|
-
)
|
|
57
|
-
try:
|
|
58
|
-
sync_client = get_sync_client(url=self.base_url, **kwargs)
|
|
59
|
-
async_client = get_client(url=self.base_url, **kwargs)
|
|
60
|
-
if self.assistant_id:
|
|
61
|
-
self.remote_graph = RemoteGraph(
|
|
62
|
-
self.assistant_id,
|
|
63
|
-
url=self.base_url,
|
|
64
|
-
sync_client=sync_client,
|
|
65
|
-
client=async_client,
|
|
66
|
-
)
|
|
67
|
-
else:
|
|
68
|
-
self.remote_graph = RemoteGraph(
|
|
69
|
-
self.graph_name,
|
|
70
|
-
url=self.base_url,
|
|
71
|
-
sync_client=sync_client,
|
|
72
|
-
client=async_client,
|
|
73
|
-
)
|
|
74
|
-
except Exception as exc:
|
|
75
|
-
raise RuntimeError(
|
|
76
|
-
f"Failed to initialize
|
|
77
|
-
)
|
|
78
|
-
logger.info(
|
|
79
|
-
"
|
|
80
|
-
self.assistant_id,
|
|
81
|
-
self.graph_name,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
def _prepare_input(self, input: List[AgentMessage]) -> Any:
|
|
85
|
-
"""
|
|
86
|
-
Prepare input for
|
|
87
|
-
|
|
88
|
-
Args:
|
|
89
|
-
input: List[AgentMessage] containing user_message
|
|
90
|
-
|
|
91
|
-
Returns:
|
|
92
|
-
Dict with messages list containing user_message
|
|
93
|
-
"""
|
|
94
|
-
logger.debug("Preparing
|
|
95
|
-
if not input or input[-1].role != "user":
|
|
96
|
-
logger.error("
|
|
97
|
-
raise ValueError("No user message found in the input")
|
|
98
|
-
|
|
99
|
-
last_user_message = input[-1]
|
|
100
|
-
# Check if content list is not empty and has text
|
|
101
|
-
if not last_user_message.content:
|
|
102
|
-
logger.error("
|
|
103
|
-
raise ValueError("User message has no content")
|
|
104
|
-
|
|
105
|
-
# Find the first text content item
|
|
106
|
-
text_content = None
|
|
107
|
-
for content_item in last_user_message.content:
|
|
108
|
-
if content_item.type == "text" and content_item.text:
|
|
109
|
-
text_content = content_item.text
|
|
110
|
-
break
|
|
111
|
-
|
|
112
|
-
if not text_content:
|
|
113
|
-
logger.error("
|
|
114
|
-
raise ValueError("No text content found in user message")
|
|
115
|
-
|
|
116
|
-
logger.debug(
|
|
117
|
-
"Prepared
|
|
118
|
-
)
|
|
119
|
-
return {"messages": [HumanMessage(content=text_content).model_dump()]}
|
|
120
|
-
|
|
121
|
-
def _convert_langchain_messages_to_quraite_messages(
|
|
122
|
-
self,
|
|
123
|
-
messages: List[dict],
|
|
124
|
-
) -> List[AgentMessage]:
|
|
125
|
-
logger.debug(
|
|
126
|
-
"Converting %d
|
|
127
|
-
)
|
|
128
|
-
converted_messages: List[AgentMessage] = []
|
|
129
|
-
|
|
130
|
-
for msg in messages:
|
|
131
|
-
if msg.get("type") == "system":
|
|
132
|
-
converted_messages.append(
|
|
133
|
-
QuraiteSystemMessage(
|
|
134
|
-
content=[
|
|
135
|
-
MessageContentText(type="text", text=msg.get("content", ""))
|
|
136
|
-
],
|
|
137
|
-
)
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
elif msg.get("type") == "human":
|
|
141
|
-
converted_messages.append(
|
|
142
|
-
UserMessage(
|
|
143
|
-
content=[
|
|
144
|
-
MessageContentText(type="text", text=msg.get("content", ""))
|
|
145
|
-
],
|
|
146
|
-
)
|
|
147
|
-
)
|
|
148
|
-
|
|
149
|
-
elif msg.get("type") == "ai":
|
|
150
|
-
text_content: List[MessageContentText] = []
|
|
151
|
-
tool_calls_list: List[ToolCall] = []
|
|
152
|
-
|
|
153
|
-
# Extract text content - sometimes it's a string, sometimes a list of dicts
|
|
154
|
-
content = msg.get("content")
|
|
155
|
-
if isinstance(content, str) and content:
|
|
156
|
-
text_content.append(MessageContentText(type="text", text=content))
|
|
157
|
-
elif isinstance(content, list):
|
|
158
|
-
for content_item in content:
|
|
159
|
-
if isinstance(content_item, dict):
|
|
160
|
-
if content_item.get("type") == "text" and content_item.get(
|
|
161
|
-
"text"
|
|
162
|
-
):
|
|
163
|
-
text_content.append(
|
|
164
|
-
MessageContentText(
|
|
165
|
-
type="text", text=content_item.get("text")
|
|
166
|
-
)
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
# Extract tool calls if present
|
|
170
|
-
if msg.get("tool_calls"):
|
|
171
|
-
for tool_call in msg.get("tool_calls"):
|
|
172
|
-
if isinstance(tool_call, dict):
|
|
173
|
-
tool_calls_list.append(
|
|
174
|
-
ToolCall(
|
|
175
|
-
id=tool_call.get("id", ""),
|
|
176
|
-
name=tool_call.get("name", ""),
|
|
177
|
-
arguments=tool_call.get("args", {}),
|
|
178
|
-
)
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
converted_messages.append(
|
|
182
|
-
AssistantMessage(
|
|
183
|
-
content=text_content if text_content else None,
|
|
184
|
-
tool_calls=tool_calls_list if tool_calls_list else None,
|
|
185
|
-
)
|
|
186
|
-
)
|
|
187
|
-
|
|
188
|
-
elif msg.get("type") == "tool":
|
|
189
|
-
tool_content = msg.get("content", "")
|
|
190
|
-
converted_messages.append(
|
|
191
|
-
ToolMessage(
|
|
192
|
-
tool_call_id=msg.get("tool_call_id", ""),
|
|
193
|
-
content=[
|
|
194
|
-
MessageContentText(type="text", text=str(tool_content))
|
|
195
|
-
],
|
|
196
|
-
)
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
else:
|
|
200
|
-
# Skip unsupported message types
|
|
201
|
-
continue
|
|
202
|
-
|
|
203
|
-
logger.info(
|
|
204
|
-
"
|
|
205
|
-
len(converted_messages),
|
|
206
|
-
)
|
|
207
|
-
return converted_messages
|
|
208
|
-
|
|
209
|
-
async def ainvoke(
|
|
210
|
-
self,
|
|
211
|
-
input: List[AgentMessage],
|
|
212
|
-
session_id: Annotated[Union[str, None], "Thread ID used by
|
|
213
|
-
) -> AgentInvocationResponse:
|
|
214
|
-
agent_messages = []
|
|
215
|
-
agent_input = self._prepare_input(input)
|
|
216
|
-
if session_id:
|
|
217
|
-
config = {"configurable": {"thread_id": session_id}}
|
|
218
|
-
else:
|
|
219
|
-
config = {}
|
|
220
|
-
|
|
221
|
-
try:
|
|
222
|
-
logger.info("
|
|
223
|
-
async for event in self.remote_graph.astream(agent_input, config=config):
|
|
224
|
-
for _, result in event.items():
|
|
225
|
-
if result.get("messages"):
|
|
226
|
-
logger.debug(
|
|
227
|
-
"
|
|
228
|
-
len(result.get("messages")),
|
|
229
|
-
)
|
|
230
|
-
agent_messages += result.get("messages")
|
|
231
|
-
|
|
232
|
-
except Exception as e:
|
|
233
|
-
logger.exception("Error invoking
|
|
234
|
-
raise RuntimeError(f"Error invoking
|
|
235
|
-
|
|
236
|
-
try:
|
|
237
|
-
# Convert to List[AgentMessage]
|
|
238
|
-
agent_trajectory = self._convert_langchain_messages_to_quraite_messages(
|
|
239
|
-
agent_messages
|
|
240
|
-
)
|
|
241
|
-
logger.info(
|
|
242
|
-
"
|
|
243
|
-
len(agent_trajectory),
|
|
244
|
-
)
|
|
245
|
-
|
|
246
|
-
return AgentInvocationResponse(
|
|
247
|
-
agent_trajectory=agent_trajectory,
|
|
248
|
-
)
|
|
249
|
-
|
|
250
|
-
except ValueError:
|
|
251
|
-
logger.exception("
|
|
252
|
-
return AgentInvocationResponse()
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Any, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
6
|
+
from langgraph.pregel.remote import RemoteGraph
|
|
7
|
+
from langgraph_sdk import get_client, get_sync_client
|
|
8
|
+
from pydantic import Discriminator
|
|
9
|
+
|
|
10
|
+
from quraite.adapters.base import BaseAdapter
|
|
11
|
+
from quraite.logger import get_logger
|
|
12
|
+
from quraite.schema.message import AgentMessage, AssistantMessage, MessageContentText
|
|
13
|
+
from quraite.schema.message import SystemMessage as QuraiteSystemMessage
|
|
14
|
+
from quraite.schema.message import ToolCall, ToolMessage, UserMessage
|
|
15
|
+
from quraite.schema.response import AgentInvocationResponse
|
|
16
|
+
|
|
17
|
+
# ``ToolMessage`` is rebound by ``from quraite.schema.message import ToolMessage``
# in the import block, so import LangChain's class under a private alias to
# keep this union made purely of LangChain message models.
from langchain_core.messages import ToolMessage as _LangchainToolMessage  # noqa: E402

# Discriminated union of LangChain message models; pydantic selects the
# concrete model from each message's ``type`` field.
LangchainMessage = Annotated[
    Union[HumanMessage, SystemMessage, AIMessage, _LangchainToolMessage],
    Discriminator(discriminator="type"),
]

logger = get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LangchainServerAdapter(BaseAdapter):
    """Remote LangChain server adapter based on langgraph-sdk.

    Invokes a graph hosted on a LangChain/LangGraph server through
    ``langgraph.pregel.remote.RemoteGraph`` and converts the streamed
    LangChain message dicts into quraite ``AgentMessage`` objects.

    Args:
        base_url: The base URL of the LangChain server
        assistant_id: The ID of the assistant to invoke. Takes precedence
            over ``graph_name`` when both are given.
        graph_name: Name of the graph to invoke; used only when
            ``assistant_id`` is not provided.
        **kwargs: Additional keyword arguments passed directly to
            langgraph_sdk.get_client() and get_sync_client().
            Common options include:
            - api_key: API key for authentication
            - headers: Additional HTTP headers
            - timeout: Request timeout configuration

    Raises:
        RuntimeError: If neither ``assistant_id`` nor ``graph_name`` is given,
            or if the SDK clients / RemoteGraph cannot be created.
    """

    def __init__(
        self,
        *,
        base_url: str,
        assistant_id: Optional[str] = None,
        graph_name: Optional[str] = None,
        **kwargs,
    ) -> None:
        self.base_url = base_url
        self.assistant_id = assistant_id
        self.graph_name = graph_name

        logger.debug(
            "Initializing LangchainServerAdapter (base_url=%s, assistant_id=%s, graph_name=%s)",
            base_url,
            assistant_id,
            graph_name,
        )

        # RemoteGraph takes a single positional identifier; assistant_id wins
        # over graph_name when both are set.
        graph_identifier = self.assistant_id or self.graph_name
        if not graph_identifier:
            # Fail at construction time rather than creating a RemoteGraph
            # with a ``None`` identifier.
            raise RuntimeError(
                f"Failed to initialize LangChain RemoteGraph for {self.base_url}: "
                "either assistant_id or graph_name must be provided"
            )

        try:
            sync_client = get_sync_client(url=self.base_url, **kwargs)
            async_client = get_client(url=self.base_url, **kwargs)
            self.remote_graph = RemoteGraph(
                graph_identifier,
                url=self.base_url,
                sync_client=sync_client,
                client=async_client,
            )
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize LangChain RemoteGraph for {self.base_url}: {exc}"
            ) from exc

        logger.info(
            "LangchainServerAdapter initialized (assistant_id=%s, graph_name=%s)",
            self.assistant_id,
            self.graph_name,
        )

    def _prepare_input(self, input: List[AgentMessage]) -> Any:
        """
        Prepare input for LangChain agent from List[AgentMessage].

        Only the first non-empty text item of the last message is forwarded;
        that message must be a user message.

        Args:
            input: List[AgentMessage] containing user_message

        Returns:
            Dict with messages list containing user_message, serialized with
            ``HumanMessage.model_dump()``.

        Raises:
            ValueError: If the last message is not a user message, or it has
                no (text) content.
        """
        logger.debug("Preparing Langchain server input from %d messages", len(input))
        if not input or input[-1].role != "user":
            logger.error("Langchain server input missing user message")
            raise ValueError("No user message found in the input")

        last_user_message = input[-1]
        # Check if content list is not empty and has text
        if not last_user_message.content:
            logger.error("Langchain server user message missing content")
            raise ValueError("User message has no content")

        # Find the first text content item
        text_content = None
        for content_item in last_user_message.content:
            if content_item.type == "text" and content_item.text:
                text_content = content_item.text
                break

        if not text_content:
            logger.error("Langchain server user message missing text content")
            raise ValueError("No text content found in user message")

        logger.debug(
            "Prepared Langchain server input (text_length=%d)", len(text_content)
        )
        return {"messages": [HumanMessage(content=text_content).model_dump()]}

    @staticmethod
    def _extract_text_parts(content: Any) -> List[str]:
        """Return the non-empty text segments of LangChain message content.

        ``content`` may be a plain string or a list of content blocks. String
        blocks are kept; dict blocks contribute only when ``type == "text"``
        and ``text`` is non-empty. Anything else (including ``None``) yields
        no text.
        """
        if isinstance(content, str):
            return [content] if content else []
        if not isinstance(content, list):
            return []
        parts: List[str] = []
        for block in content:
            if isinstance(block, str):
                if block:
                    parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                text = block.get("text")
                if text:
                    parts.append(text)
        return parts

    @classmethod
    def _to_text_items(cls, content: Any) -> List[MessageContentText]:
        """Wrap each text segment of ``content`` in a ``MessageContentText``."""
        return [
            MessageContentText(type="text", text=part)
            for part in cls._extract_text_parts(content)
        ]

    def _convert_langchain_messages_to_quraite_messages(
        self,
        messages: List[dict],
    ) -> List[AgentMessage]:
        """Convert serialized LangChain messages into quraite ``AgentMessage``s.

        Handles the ``system``, ``human``, ``ai`` and ``tool`` message types;
        any other type, and any entry that is not a dict, is skipped.

        Args:
            messages: Message dicts as received from the remote graph stream.

        Returns:
            The converted messages, in input order.
        """
        logger.debug(
            "Converting %d Langchain server messages to quraite format", len(messages)
        )
        converted_messages: List[AgentMessage] = []

        for msg in messages:
            if not isinstance(msg, dict):
                logger.debug(
                    "Skipping non-dict Langchain server message of type %s",
                    type(msg).__name__,
                )
                continue

            msg_type = msg.get("type")

            if msg_type in ("system", "human"):
                # Content may be a string, a list of content blocks or None;
                # always emit at least one text item, even when it is empty.
                content_items = self._to_text_items(msg.get("content")) or [
                    MessageContentText(type="text", text="")
                ]
                message_cls = (
                    QuraiteSystemMessage if msg_type == "system" else UserMessage
                )
                converted_messages.append(message_cls(content=content_items))

            elif msg_type == "ai":
                text_items = self._to_text_items(msg.get("content"))
                # ``or`` guards against keys that are present but set to None.
                tool_calls_list: List[ToolCall] = [
                    ToolCall(
                        id=tool_call.get("id") or "",
                        name=tool_call.get("name") or "",
                        arguments=tool_call.get("args") or {},
                    )
                    for tool_call in (msg.get("tool_calls") or [])
                    if isinstance(tool_call, dict)
                ]
                converted_messages.append(
                    AssistantMessage(
                        content=text_items or None,
                        tool_calls=tool_calls_list or None,
                    )
                )

            elif msg_type == "tool":
                tool_content = msg.get("content", "")
                if isinstance(tool_content, list):
                    # Keep the text of content blocks instead of the list repr;
                    # fall back to str() when there is no text block at all.
                    text_parts = self._extract_text_parts(tool_content)
                    tool_text = (
                        "\n".join(text_parts) if text_parts else str(tool_content)
                    )
                else:
                    tool_text = str(tool_content)
                converted_messages.append(
                    ToolMessage(
                        tool_call_id=msg.get("tool_call_id") or "",
                        content=[MessageContentText(type="text", text=tool_text)],
                    )
                )

            # Any other message type is intentionally skipped.

        logger.info(
            "Langchain server message conversion produced %d messages",
            len(converted_messages),
        )
        return converted_messages

    async def ainvoke(
        self,
        input: List[AgentMessage],
        session_id: Annotated[Union[str, None], "Thread ID used by LangChain API"],
    ) -> AgentInvocationResponse:
        """Invoke the remote graph and return its trajectory as quraite messages.

        Args:
            input: Conversation whose last message must be a user message with
                text content.
            session_id: Optional thread ID. When truthy, it is passed as
                ``config["configurable"]["thread_id"]``.

        Returns:
            AgentInvocationResponse with the converted trajectory. If the
            conversion raises ValueError, an empty AgentInvocationResponse is
            returned instead.

        Raises:
            ValueError: If ``input`` has no usable user message.
            RuntimeError: If streaming from the remote graph fails.
        """
        agent_input = self._prepare_input(input)
        config = {"configurable": {"thread_id": session_id}} if session_id else {}
        agent_messages: List[dict] = []

        try:
            logger.info("Langchain server ainvoke called (session_id=%s)", session_id)
            async for event in self.remote_graph.astream(agent_input, config=config):
                for result in event.values():
                    # Per-node updates are not guaranteed to be dicts (e.g. a
                    # node that returns nothing); skip those that can't carry
                    # messages instead of aborting the whole invocation.
                    if not isinstance(result, dict):
                        continue
                    chunk_messages = result.get("messages")
                    if not chunk_messages:
                        continue
                    # A single message dict must not be spread into its keys.
                    if isinstance(chunk_messages, dict):
                        chunk_messages = [chunk_messages]
                    logger.debug(
                        "Langchain server received %d messages from stream chunk",
                        len(chunk_messages),
                    )
                    agent_messages.extend(chunk_messages)

        except Exception as e:
            logger.exception("Error invoking Langchain remote graph")
            raise RuntimeError(f"Error invoking LangChain agent: {e}") from e

        try:
            # Convert to List[AgentMessage]
            agent_trajectory = self._convert_langchain_messages_to_quraite_messages(
                agent_messages
            )
            logger.info(
                "Langchain server ainvoke produced %d trajectory messages",
                len(agent_trajectory),
            )

            return AgentInvocationResponse(
                agent_trajectory=agent_trajectory,
            )

        except ValueError:
            logger.exception("Langchain server conversion to AgentMessage failed")
            return AgentInvocationResponse()
|