datarobot_genai-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/langgraph/agent.py
@@ -0,0 +1,341 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+from typing import cast
+
+from ag_ui.core import Event
+from ag_ui.core import EventType
+from ag_ui.core import TextMessageContentEvent
+from ag_ui.core import TextMessageEndEvent
+from ag_ui.core import TextMessageStartEvent
+from ag_ui.core import ToolCallArgsEvent
+from ag_ui.core import ToolCallEndEvent
+from ag_ui.core import ToolCallResultEvent
+from ag_ui.core import ToolCallStartEvent
+from langchain.tools import BaseTool
+from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import ToolMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.graph import MessagesState
+from langgraph.graph import StateGraph
+from langgraph.types import Command
+from openai.types.chat import CompletionCreateParams
+from ragas import MultiTurnSample
+from ragas.integrations.langgraph import convert_to_ragas_messages
+
+from datarobot_genai.core.agents.base import BaseAgent
+from datarobot_genai.core.agents.base import InvokeReturn
+from datarobot_genai.core.agents.base import UsageMetrics
+from datarobot_genai.core.agents.base import extract_user_prompt_content
+from datarobot_genai.core.agents.base import is_streaming
+from datarobot_genai.langgraph.mcp import mcp_tools_context
+
+logger = logging.getLogger(__name__)
+
+
+class LangGraphAgent(BaseAgent[BaseTool], abc.ABC):
+    @property
+    @abc.abstractmethod
+    def workflow(self) -> StateGraph[MessagesState]:
+        raise NotImplementedError("Not implemented")
+
+    @property
+    @abc.abstractmethod
+    def prompt_template(self) -> ChatPromptTemplate:
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def langgraph_config(self) -> dict[str, Any]:
+        return {
+            "recursion_limit": 150,  # Maximum number of steps to take in the graph
+        }
+
+    def convert_input_message(self, completion_create_params: CompletionCreateParams) -> Command:
+        user_prompt = extract_user_prompt_content(completion_create_params)
+        command = Command(  # type: ignore[var-annotated]
+            update={
+                "messages": self.prompt_template.invoke(user_prompt).to_messages(),
+            },
+        )
+        return command
+
+    async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+        """Run the agent with the provided completion parameters.
+
+        Parameters
+        ----------
+        completion_create_params:
+            The completion request parameters, including input topic and settings.
+
+        Returns
+        -------
+        For streaming requests, returns a generator yielding tuples of (response_text,
+        pipeline_interactions, usage_metrics).
+        For non-streaming requests, returns a single tuple of (response_text,
+        pipeline_interactions, usage_metrics).
+        """
+        # For streaming, we need to keep the context alive until the generator is exhausted
+        if completion_create_params.get("stream"):
+
+            async def wrapped_generator() -> AsyncGenerator[
+                tuple[str, Any | None, UsageMetrics], None
+            ]:
+                try:
+                    async with mcp_tools_context(
+                        authorization_context=self._authorization_context,
+                        forwarded_headers=self.forwarded_headers,
+                    ) as mcp_tools:
+                        self.set_mcp_tools(mcp_tools)
+                        result = await self._invoke(completion_create_params)
+
+                        # Yield all items from the result generator.
+                        # The context will be closed when this generator is exhausted.
+                        # Cast to async generator since stream=True means it is a generator.
+                        result_generator = cast(
+                            AsyncGenerator[tuple[str, Any | None, UsageMetrics], None], result
+                        )
+                        async for item in result_generator:
+                            yield item
+                except RuntimeError as e:
+                    error_message = str(e).lower()
+                    if "different task" in error_message and "cancel scope" in error_message:
+                        # Due to anyio task group constraints when consuming async generators
+                        # across task boundaries, we cannot always clean up properly.
+                        # The underlying HTTP client/connection pool should handle resource
+                        # cleanup via timeouts and connection pooling, but this may lead to
+                        # delayed resource release.
+                        logger.debug(
+                            "MCP context cleanup attempted in different task. "
+                            "This is a limitation when consuming async generators "
+                            "across task boundaries."
+                        )
+                    else:
+                        # Re-raise if it's a different RuntimeError
+                        raise
+
+            return wrapped_generator()
+        else:
+            # For non-streaming, use async with directly
+            async with mcp_tools_context(
+                authorization_context=self._authorization_context,
+                forwarded_headers=self.forwarded_headers,
+            ) as mcp_tools:
+                self.set_mcp_tools(mcp_tools)
+                result = await self._invoke(completion_create_params)
+
+            return result
+
+    async def _invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+        input_command = self.convert_input_message(completion_create_params)
+        logger.info(
+            f"Running a langgraph agent with a command: {input_command}",
+        )
+
+        # Create and invoke the LangGraph agentic workflow with the inputs
+        langgraph_execution_graph = self.workflow.compile()
+
+        graph_stream = langgraph_execution_graph.astream(
+            input=input_command,
+            config=self.langgraph_config,
+            debug=self.verbose,
+            # Stream updates and messages from all the nodes
+            stream_mode=["updates", "messages"],
+            subgraphs=True,
+        )
+
+        usage_metrics: UsageMetrics = {
+            "completion_tokens": 0,
+            "prompt_tokens": 0,
+            "total_tokens": 0,
+        }
+
+        # The following code demonstrates both a synchronous and a streaming response.
+        # You can choose one or the other based on your use case; they function the same.
+        # The main difference is returning a generator for streaming or a final response for sync.
+        if is_streaming(completion_create_params):
+            # Streaming response: yield each message as it is generated
+            return self._stream_generator(graph_stream, usage_metrics)
+        else:
+            # Synchronous response: collect all events and return the final message
+            events: list[dict[str, Any]] = [
+                event  # type: ignore[misc]
+                async for _, mode, event in graph_stream
+                if mode == "updates"
+            ]
+
+            # Accumulate the usage metrics from the updates
+            for update in events:
+                current_node = next(iter(update))
+                node_data = update[current_node]
+                current_usage = node_data.get("usage", {}) if node_data is not None else {}
+                if current_usage:
+                    usage_metrics["total_tokens"] += current_usage.get("total_tokens", 0)
+                    usage_metrics["prompt_tokens"] += current_usage.get("prompt_tokens", 0)
+                    usage_metrics["completion_tokens"] += current_usage.get("completion_tokens", 0)
+
+            pipeline_interactions = self.create_pipeline_interactions_from_events(events)
+
+            # Extract the final event from the graph stream as the synchronous response
+            last_event = events[-1]
+            node_name = next(iter(last_event))
+            node_data = last_event[node_name]
+            response_text = (
+                str(node_data["messages"][-1].content)
+                if node_data is not None and "messages" in node_data
+                else ""
+            )
+
+            return response_text, pipeline_interactions, usage_metrics
+
+    async def _stream_generator(
+        self, graph_stream: AsyncGenerator[tuple[Any, str, Any], None], usage_metrics: UsageMetrics
+    ) -> AsyncGenerator[tuple[str | Event, MultiTurnSample | None, UsageMetrics], None]:
+        # Iterate over the graph stream. For message events, yield the content.
+        # For update events, accumulate the usage metrics.
+        events = []
+        current_message_id = None
+        tool_call_id = ""
+        async for _, mode, event in graph_stream:
+            if mode == "messages":
+                message_event: tuple[AIMessageChunk | ToolMessage, dict[str, Any]] = event  # type: ignore[assignment]
+                message = message_event[0]
+                if isinstance(message, ToolMessage):
+                    yield (
+                        ToolCallEndEvent(
+                            type=EventType.TOOL_CALL_END, tool_call_id=message.tool_call_id
+                        ),
+                        None,
+                        usage_metrics,
+                    )
+                    yield (
+                        ToolCallResultEvent(
+                            type=EventType.TOOL_CALL_RESULT,
+                            message_id=message.id,
+                            tool_call_id=message.tool_call_id,
+                            content=message.content,
+                            role="tool",
+                        ),
+                        None,
+                        usage_metrics,
+                    )
+                    tool_call_id = ""
+                elif isinstance(message, AIMessageChunk):
+                    if message.tool_call_chunks:
+                        # This is a tool call message
+                        for tool_call_chunk in message.tool_call_chunks:
+                            if name := tool_call_chunk.get("name"):
+                                # It's a tool call start message
+                                tool_call_id = tool_call_chunk["id"]
+                                yield (
+                                    ToolCallStartEvent(
+                                        type=EventType.TOOL_CALL_START,
+                                        tool_call_id=tool_call_id,
+                                        tool_call_name=name,
+                                        parent_message_id=message.id,
+                                    ),
+                                    None,
+                                    usage_metrics,
+                                )
+                            elif args := tool_call_chunk.get("args"):
+                                # It's a tool call args message
+                                yield (
+                                    ToolCallArgsEvent(
+                                        type=EventType.TOOL_CALL_ARGS,
+                                        # The id is empty when the chunk is not a start message,
+                                        # so we use the tool call id from the previous start message
+                                        tool_call_id=tool_call_id,
+                                        delta=args,
+                                    ),
+                                    None,
+                                    usage_metrics,
+                                )
+                    elif message.content:
+                        # It's a text message.
+                        # Handle the start and end of the text message.
+                        if message.id != current_message_id:
+                            if current_message_id:
+                                yield (
+                                    TextMessageEndEvent(
+                                        type=EventType.TEXT_MESSAGE_END,
+                                        message_id=current_message_id,
+                                    ),
+                                    None,
+                                    usage_metrics,
+                                )
+                            current_message_id = message.id
+                            yield (
+                                TextMessageStartEvent(
+                                    type=EventType.TEXT_MESSAGE_START,
+                                    message_id=message.id,
+                                    role="assistant",
+                                ),
+                                None,
+                                usage_metrics,
+                            )
+                        yield (
+                            TextMessageContentEvent(
+                                type=EventType.TEXT_MESSAGE_CONTENT,
+                                message_id=message.id,
+                                delta=message.content,
+                            ),
+                            None,
+                            usage_metrics,
+                        )
+                else:
+                    raise ValueError(f"Invalid message event: {message_event}")
+            elif mode == "updates":
+                update_event: dict[str, Any] = event  # type: ignore[assignment]
+                events.append(update_event)
+                current_node = next(iter(update_event))
+                node_data = update_event[current_node]
+                current_usage = node_data.get("usage", {}) if node_data is not None else {}
+                if current_usage:
+                    usage_metrics["total_tokens"] += current_usage.get("total_tokens", 0)
+                    usage_metrics["prompt_tokens"] += current_usage.get("prompt_tokens", 0)
+                    usage_metrics["completion_tokens"] += current_usage.get("completion_tokens", 0)
+                if current_message_id:
+                    yield (
+                        TextMessageEndEvent(
+                            type=EventType.TEXT_MESSAGE_END,
+                            message_id=current_message_id,
+                        ),
+                        None,
+                        usage_metrics,
+                    )
+                    current_message_id = None
+
+        # Create a list of events from the event listener
+        pipeline_interactions = self.create_pipeline_interactions_from_events(events)
+
+        # Yield the final response indicating completion
+        yield "", pipeline_interactions, usage_metrics
+
+    @classmethod
+    def create_pipeline_interactions_from_events(
+        cls,
+        events: list[dict[str, Any]] | None,
+    ) -> MultiTurnSample | None:
+        """Convert a list of LangGraph events into a Ragas MultiTurnSample."""
+        if not events:
+            return None
+        messages = []
+        for e in events:
+            for _, v in e.items():
+                if v is not None:
+                    messages.extend(v.get("messages", []))
+        messages = [m for m in messages if not isinstance(m, ToolMessage)]
+        ragas_trace = convert_to_ragas_messages(messages)
+        return MultiTurnSample(user_input=ragas_trace)
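For orientation, here is an editor-added minimal sketch of what a concrete `LangGraphAgent` subclass might look like. It is not part of the package: the node wiring, the choice of langchain-openai's `ChatOpenAI`, and the `{user_input}` template variable are illustrative assumptions, and the template's input variables must line up with whatever `extract_user_prompt_content` returns, which this diff does not show.

    from typing import Any

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI  # assumed model provider
    from langgraph.graph import END, START, MessagesState, StateGraph

    from datarobot_genai.langgraph.agent import LangGraphAgent


    class SingleNodeAgent(LangGraphAgent):
        @property
        def workflow(self) -> StateGraph:
            graph = StateGraph(MessagesState)

            def llm_node(state: MessagesState) -> dict[str, Any]:
                # Call the model on the accumulated messages; LangGraph merges
                # the returned message into the "messages" state key.
                response = ChatOpenAI(model="gpt-4o-mini").invoke(state["messages"])
                return {"messages": [response]}

            graph.add_node("llm", llm_node)
            graph.add_edge(START, "llm")
            graph.add_edge("llm", END)
            return graph

        @property
        def prompt_template(self) -> ChatPromptTemplate:
            # The variable name "user_input" is a guess; it must match the
            # dict produced by extract_user_prompt_content.
            return ChatPromptTemplate.from_messages(
                [("system", "You are a helpful assistant."), ("user", "{user_input}")]
            )

The two abstract properties are the only required overrides; `invoke` then compiles the graph, wires in MCP tools, and handles both streaming and synchronous responses as shown in the diff above.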
datarobot_genai/langgraph/mcp.py
@@ -0,0 +1,73 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from typing import Any
+
+from langchain.tools import BaseTool
+from langchain_mcp_adapters.sessions import SSEConnection
+from langchain_mcp_adapters.sessions import StreamableHttpConnection
+from langchain_mcp_adapters.sessions import create_session
+from langchain_mcp_adapters.tools import load_mcp_tools
+
+from datarobot_genai.core.mcp.common import MCPConfig
+
+
+@asynccontextmanager
+async def mcp_tools_context(
+    authorization_context: dict[str, Any] | None = None,
+    forwarded_headers: dict[str, str] | None = None,
+) -> AsyncGenerator[list[BaseTool], None]:
+    """Yield a list of LangChain BaseTool instances loaded via MCP.
+
+    If no server is configured, yields an empty list without raising.
+
+    Parameters
+    ----------
+    authorization_context : dict[str, Any] | None
+        Authorization context to use for MCP connections
+    forwarded_headers : dict[str, str] | None
+        Forwarded headers, e.g. x-datarobot-api-key, to use for MCP authentication
+    """
+    mcp_config = MCPConfig(
+        authorization_context=authorization_context,
+        forwarded_headers=forwarded_headers,
+    )
+    server_config = mcp_config.server_config
+
+    if not server_config:
+        print("No MCP server configured, using empty tools list", flush=True)
+        yield []
+        return
+
+    url = server_config["url"]
+    print(f"Connecting to MCP server: {url}", flush=True)
+
+    # Pop transport from server_config to avoid passing it twice;
+    # use .pop() with a default so this never errors
+    transport = server_config.pop("transport", "streamable-http")
+
+    if transport in ["streamable-http", "streamable_http"]:
+        connection = StreamableHttpConnection(transport="streamable_http", **server_config)
+    elif transport == "sse":
+        connection = SSEConnection(transport="sse", **server_config)
+    else:
+        raise RuntimeError("Unsupported MCP transport specified.")
+
+    async with create_session(connection=connection) as session:
+        # Use the connection to load available MCP tools
+        tools = await load_mcp_tools(session=session)
+        print(f"Successfully loaded {len(tools)} MCP tools", flush=True)
+        yield tools
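An editor-added minimal sketch of using `mcp_tools_context` on its own, assuming the environment provides the MCP server settings that `MCPConfig` reads; the header value is a placeholder (the `x-datarobot-api-key` header name comes from the docstring above):

    import asyncio

    from datarobot_genai.langgraph.mcp import mcp_tools_context


    async def list_tools() -> None:
        # Yields [] when MCPConfig finds no server configured
        async with mcp_tools_context(
            forwarded_headers={"x-datarobot-api-key": "<token>"},
        ) as tools:
            for tool in tools:
                print(tool.name, tool.description)


    asyncio.run(list_tools())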
datarobot_genai/llama_index/__init__.py
@@ -0,0 +1,16 @@
+"""LlamaIndex utilities and helpers."""
+
+from datarobot_genai.core.mcp.common import MCPConfig
+
+from .agent import DataRobotLiteLLM
+from .agent import create_pipeline_interactions_from_events
+from .base import LlamaIndexAgent
+from .mcp import load_mcp_tools
+
+__all__ = [
+    "DataRobotLiteLLM",
+    "create_pipeline_interactions_from_events",
+    "LlamaIndexAgent",
+    "load_mcp_tools",
+    "MCPConfig",
+]
datarobot_genai/llama_index/agent.py
@@ -0,0 +1,50 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import cast
+
+from llama_index.core.base.llms.types import LLMMetadata
+from llama_index.core.workflow import Event
+from llama_index.llms.litellm import LiteLLM
+from ragas import MultiTurnSample
+from ragas.integrations.llama_index import convert_to_ragas_messages
+from ragas.messages import AIMessage
+from ragas.messages import HumanMessage
+from ragas.messages import ToolMessage
+
+
+class DataRobotLiteLLM(LiteLLM):
+    """LiteLLM wrapper providing chat/function capability metadata for LlamaIndex."""
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Return LLM metadata."""
+        return LLMMetadata(
+            context_window=128000,
+            num_output=self.max_tokens or -1,
+            is_chat_model=True,
+            is_function_calling_model=True,
+            model_name=self.model,
+        )
+
+
+def create_pipeline_interactions_from_events(
+    events: list[Event] | None,
+) -> MultiTurnSample | None:
+    if not events:
+        return None
+    # convert_to_ragas_messages expects a list[Event]
+    ragas_trace = convert_to_ragas_messages(list(events))
+    ragas_messages = cast(list[HumanMessage | AIMessage | ToolMessage], ragas_trace)
+    return MultiTurnSample(user_input=ragas_messages)
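An editor-added minimal sketch of the metadata override in use; the model string is a placeholder and constructor arguments are assumed to pass through to the LiteLLM base class:

    from datarobot_genai.llama_index import DataRobotLiteLLM

    llm = DataRobotLiteLLM(model="<litellm-model-string>")

    # The override advertises chat and function-calling support, so LlamaIndex
    # agent workflows will route tool calls through this LLM.
    meta = llm.metadata
    assert meta.is_chat_model and meta.is_function_calling_model
    print(meta.context_window)  # 128000, hard-coded in the wrapper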