datarobot-genai 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0

datarobot_genai/llama_index/base.py
@@ -0,0 +1,299 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Base class for LlamaIndex-based agents.
+
+Provides a standard ``invoke`` that runs an AgentWorkflow, collects events,
+and converts them into pipeline interactions. Subclasses provide the workflow
+and response extraction logic.
+"""
+
+from __future__ import annotations
+
+import abc
+import inspect
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from llama_index.core.tools import BaseTool
+from openai.types.chat import CompletionCreateParams
+from ragas import MultiTurnSample
+
+from datarobot_genai.core.agents.base import BaseAgent
+from datarobot_genai.core.agents.base import InvokeReturn
+from datarobot_genai.core.agents.base import UsageMetrics
+from datarobot_genai.core.agents.base import default_usage_metrics
+from datarobot_genai.core.agents.base import extract_user_prompt_content
+from datarobot_genai.core.agents.base import is_streaming
+
+from .agent import create_pipeline_interactions_from_events
+from .mcp import load_mcp_tools
+
+
+class LlamaIndexAgent(BaseAgent[BaseTool], abc.ABC):
+    """Abstract base agent for LlamaIndex workflows."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._mcp_tools: list[Any] = []
+
+    def set_mcp_tools(self, tools: list[Any]) -> None:
+        """Set MCP tools for this agent."""
+        self._mcp_tools = tools
+
+    @property
+    def mcp_tools(self) -> list[Any]:
+        """Return the list of MCP tools available to this agent.
+
+        Subclasses can use this to wire tools into LlamaIndex agents during
+        workflow construction inside ``build_workflow``.
+        """
+        return self._mcp_tools
+
+    @abc.abstractmethod
+    def build_workflow(self) -> Any:
+        """Return an AgentWorkflow instance ready to run."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
+        """Extract final response text from workflow state and/or events."""
+        raise NotImplementedError
+
+    def make_input_message(self, completion_create_params: CompletionCreateParams) -> str:
+        """Create an input string for the workflow from the user prompt."""
+        user_prompt_content = extract_user_prompt_content(completion_create_params)
+        return str(user_prompt_content)
+
+    async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+        """Run the LlamaIndex workflow with the provided completion parameters."""
+        input_message = self.make_input_message(completion_create_params)
+
+        # Load MCP tools (if configured) asynchronously before building workflow
+        mcp_tools = await load_mcp_tools(
+            authorization_context=self._authorization_context,
+            forwarded_headers=self.forwarded_headers,
+        )
+        self.set_mcp_tools(mcp_tools)
+
+        # Preserve prior template startup print for CLI parity
+        try:
+            print(
+                "Running agent with user prompt:",
+                extract_user_prompt_content(completion_create_params),
+                flush=True,
+            )
+        except Exception:
+            # Printing is best-effort; proceed regardless
+            pass
+
+        workflow = self.build_workflow()
+        handler = workflow.run(user_msg=input_message)
+
+        usage_metrics: UsageMetrics = default_usage_metrics()
+
+        # Streaming parity with LangGraph: yield incremental deltas during event processing
+        if is_streaming(completion_create_params):
+
+            async def _gen() -> AsyncGenerator[tuple[str, MultiTurnSample | None, UsageMetrics]]:
+                events: list[Any] = []
+                current_agent_name: str | None = None
+                async for event in handler.stream_events():
+                    events.append(event)
+                    # Best-effort extraction of incremental text from LlamaIndex events
+                    delta: str | None = None
+                    # Agent switch banner if available on event
+                    try:
+                        if hasattr(event, "current_agent_name"):
+                            new_agent = getattr(event, "current_agent_name")
+                            if (
+                                isinstance(new_agent, str)
+                                and new_agent
+                                and new_agent != current_agent_name
+                            ):
+                                current_agent_name = new_agent
+                                # Print banner for agent switch (do not emit as streamed content)
+                                print("\n" + "=" * 50, flush=True)
+                                print(f"🤖 Agent: {current_agent_name}", flush=True)
+                                print("=" * 50 + "\n", flush=True)
+                    except Exception:
+                        pass
+
+                    try:
+                        if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
+                            delta = getattr(event, "delta")
+                        # Some event types may carry incremental text under "text" or similar
+                        elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
+                            delta = getattr(event, "text")
+                    except Exception:
+                        # Ignore malformed events and continue
+                        delta = None
+
+                    if delta:
+                        # Yield token/content delta with current (accumulated) usage metrics
+                        yield delta, None, usage_metrics
+
+                    # Best-effort debug/event messages printed to CLI (do not stream as content)
+                    try:
+                        event_type = type(event).__name__
+                        if event_type == "AgentInput" and hasattr(event, "input"):
+                            print("📥 Input:", getattr(event, "input"), flush=True)
+                        elif event_type == "AgentOutput":
+                            # Output content
+                            resp = getattr(event, "response", None)
+                            if (
+                                resp is not None
+                                and hasattr(resp, "content")
+                                and getattr(resp, "content")
+                            ):
+                                print("📤 Output:", getattr(resp, "content"), flush=True)
+                            # Planned tool calls
+                            tcalls = getattr(event, "tool_calls", None)
+                            if isinstance(tcalls, list) and tcalls:
+                                names = []
+                                for c in tcalls:
+                                    try:
+                                        nm = getattr(c, "tool_name", None) or (
+                                            c.get("tool_name") if isinstance(c, dict) else None
+                                        )
+                                        if nm:
+                                            names.append(str(nm))
+                                    except Exception:
+                                        pass
+                                if names:
+                                    print("🛠️ Planning to use tools:", names, flush=True)
+                        elif event_type == "ToolCallResult":
+                            tname = getattr(event, "tool_name", None)
+                            tkwargs = getattr(event, "tool_kwargs", None)
+                            tout = getattr(event, "tool_output", None)
+                            print(f"🔧 Tool Result ({tname}):", flush=True)
+                            print(f"  Arguments: {tkwargs}", flush=True)
+                            print(f"  Output: {tout}", flush=True)
+                        elif event_type == "ToolCall":
+                            tname = getattr(event, "tool_name", None)
+                            tkwargs = getattr(event, "tool_kwargs", None)
+                            print(f"🔨 Calling Tool: {tname}", flush=True)
+                            print(f"  With arguments: {tkwargs}", flush=True)
+                    except Exception:
+                        # Ignore best-effort debug rendering errors
+                        pass
+
+                # After streaming completes, build final interactions and finish chunk
+                # Extract state from workflow context (supports sync/async get or attribute)
+                state = None
+                ctx = getattr(handler, "ctx", None)
+                try:
+                    if ctx is not None:
+                        get = getattr(ctx, "get", None)
+                        if callable(get):
+                            result = get("state")
+                            state = await result if inspect.isawaitable(result) else result
+                        elif hasattr(ctx, "state"):
+                            state = getattr(ctx, "state")
+                except (AttributeError, TypeError):
+                    state = None
+
+                # Run subclass-defined response extraction (not streamed) for completeness
+                _ = self.extract_response_text(state, events)
+
+                pipeline_interactions = create_pipeline_interactions_from_events(events)
+                # Final empty chunk indicates end of stream, carrying interactions and usage
+                yield "", pipeline_interactions, usage_metrics
+
+            return _gen()
+
+        # Non-streaming path: run to completion, emit debug prints, then return final response
+        events: list[Any] = []
+        current_agent_name: str | None = None
+        async for event in handler.stream_events():
+            events.append(event)
+
+            # Replicate prior template CLI prints for non-streaming mode
+            try:
+                if hasattr(event, "current_agent_name"):
+                    new_agent = getattr(event, "current_agent_name")
+                    if isinstance(new_agent, str) and new_agent and new_agent != current_agent_name:
+                        current_agent_name = new_agent
+                        print(f"\n{'=' * 50}", flush=True)
+                        print(f"🤖 Agent: {current_agent_name}", flush=True)
+                        print(f"{'=' * 50}\n", flush=True)
+            except Exception:
+                pass
+
+            try:
+                if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
+                    print(getattr(event, "delta"), end="", flush=True)
+                elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
+                    print(getattr(event, "text"), end="", flush=True)
+                else:
+                    event_type = type(event).__name__
+                    if event_type == "AgentInput" and hasattr(event, "input"):
+                        print("📥 Input:", getattr(event, "input"), flush=True)
+                    elif event_type == "AgentOutput":
+                        resp = getattr(event, "response", None)
+                        if (
+                            resp is not None
+                            and hasattr(resp, "content")
+                            and getattr(resp, "content")
+                        ):
+                            print("📤 Output:", getattr(resp, "content"), flush=True)
+                        tcalls = getattr(event, "tool_calls", None)
+                        if isinstance(tcalls, list) and tcalls:
+                            names: list[str] = []
+                            for c in tcalls:
+                                try:
+                                    nm = getattr(c, "tool_name", None) or (
+                                        c.get("tool_name") if isinstance(c, dict) else None
+                                    )
+                                    if nm:
+                                        names.append(str(nm))
+                                except Exception:
+                                    pass
+                            if names:
+                                print("🛠️ Planning to use tools:", names, flush=True)
+                    elif event_type == "ToolCallResult":
+                        tname = getattr(event, "tool_name", None)
+                        tkwargs = getattr(event, "tool_kwargs", None)
+                        tout = getattr(event, "tool_output", None)
+                        print(f"🔧 Tool Result ({tname}):", flush=True)
+                        print(f"  Arguments: {tkwargs}", flush=True)
+                        print(f"  Output: {tout}", flush=True)
+                    elif event_type == "ToolCall":
+                        tname = getattr(event, "tool_name", None)
+                        tkwargs = getattr(event, "tool_kwargs", None)
+                        print(f"🔨 Calling Tool: {tname}", flush=True)
+                        print(f"  With arguments: {tkwargs}", flush=True)
+            except Exception:
+                # Best-effort debug printing; continue on errors
+                pass
+
+        # Extract state from workflow context (supports sync/async get or attribute)
+        state = None
+        ctx = getattr(handler, "ctx", None)
+        try:
+            if ctx is not None:
+                get = getattr(ctx, "get", None)
+                if callable(get):
+                    result = get("state")
+                    state = await result if inspect.isawaitable(result) else result
+                elif hasattr(ctx, "state"):
+                    state = getattr(ctx, "state")
+        except (AttributeError, TypeError):
+            state = None
+        response_text = self.extract_response_text(state, events)
+
+        pipeline_interactions = create_pipeline_interactions_from_events(events)
+
+        return response_text, pipeline_interactions, usage_metrics
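
For orientation, a minimal sketch of a concrete subclass follows. The EchoAgent name, the OpenAI model choice, and the FunctionAgent/AgentWorkflow wiring are illustrative assumptions (they track recent llama-index releases), not part of the packaged code:

# Hypothetical subclass sketch; model name and agent wiring are assumptions.
from typing import Any

from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent
from llama_index.llms.openai import OpenAI

from datarobot_genai.llama_index.base import LlamaIndexAgent


class EchoAgent(LlamaIndexAgent):
    def build_workflow(self) -> Any:
        # invoke() loads MCP tools before calling this, so self.mcp_tools is populated.
        agent = FunctionAgent(
            name="echo",
            description="Answers user questions directly.",
            llm=OpenAI(model="gpt-4o-mini"),  # assumed provider/model
            tools=self.mcp_tools,
            system_prompt="Answer the user's question.",
        )
        return AgentWorkflow(agents=[agent], root_agent="echo")

    def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
        # Join the textual deltas observed in the event stream as a fallback.
        return "".join(
            e.delta for e in events if isinstance(getattr(e, "delta", None), str)
        )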

datarobot_genai/llama_index/mcp.py
@@ -0,0 +1,79 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+MCP integration for LlamaIndex using llama-index-tools-mcp.
+
+This module provides MCP server connection management for LlamaIndex agents.
+Unlike CrewAI, which uses a context manager, LlamaIndex uses async calls to
+fetch tools from MCP servers.
+"""
+
+from typing import Any
+
+from llama_index.tools.mcp import BasicMCPClient
+from llama_index.tools.mcp import aget_tools_from_mcp_url
+
+from datarobot_genai.core.mcp.common import MCPConfig
+
+
+async def load_mcp_tools(
+    authorization_context: dict[str, Any] | None = None,
+    forwarded_headers: dict[str, str] | None = None,
+) -> list[Any]:
+    """
+    Asynchronously load MCP tools for LlamaIndex.
+
+    Args:
+        authorization_context: Optional authorization context for MCP connections
+        forwarded_headers: Optional forwarded headers, e.g. x-datarobot-api-key for MCP auth
+
+    Returns
+    -------
+        List of MCP tools, or an empty list if no MCP configuration is present.
+    """
+    config = MCPConfig(
+        authorization_context=authorization_context,
+        forwarded_headers=forwarded_headers,
+    )
+    server_params = config.server_config
+
+    if not server_params:
+        print("No MCP server configured, using empty tools list", flush=True)
+        return []
+
+    url = server_params["url"]
+    headers = server_params.get("headers", {})
+
+    try:
+        print(f"Connecting to MCP server: {url}", flush=True)
+        # Create BasicMCPClient with headers to pass authentication
+        client = BasicMCPClient(command_or_url=url, headers=headers)
+        tools = await aget_tools_from_mcp_url(
+            command_or_url=url,
+            client=client,
+        )
+        # Ensure list
+        tools_list = list(tools) if tools is not None else []
+        print(
+            f"Successfully connected to MCP server, got {len(tools_list)} tools",
+            flush=True,
+        )
+        return tools_list
+    except Exception as e:
+        print(
+            f"Warning: Failed to connect to MCP server {url}: {e}",
+            flush=True,
+        )
+        return []
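
Because load_mcp_tools degrades to an empty list when no server is configured or the connection fails, callers can pass its result straight into agent construction. A usage sketch (the header value is a placeholder; the x-datarobot-api-key name comes from the docstring above):

# Hypothetical usage sketch for load_mcp_tools; the key value is a placeholder.
import asyncio

from datarobot_genai.llama_index.mcp import load_mcp_tools


async def main() -> None:
    tools = await load_mcp_tools(
        forwarded_headers={"x-datarobot-api-key": "<your-api-key>"},
    )
    print(f"Loaded {len(tools)} MCP tools")


asyncio.run(main())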

datarobot_genai/nat/__init__.py
File without changes
datarobot_genai/nat/agent.py
@@ -0,0 +1,275 @@
+# Copyright 2025 DataRobot, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from nat.builder.context import Context
+from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatResponse
+from nat.data_models.intermediate_step import IntermediateStep
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.utils.type_utils import StrPath
+from openai.types.chat import CompletionCreateParams
+from ragas import MultiTurnSample
+from ragas.messages import AIMessage
+from ragas.messages import HumanMessage
+from ragas.messages import ToolMessage
+
+from datarobot_genai.core.agents.base import BaseAgent
+from datarobot_genai.core.agents.base import InvokeReturn
+from datarobot_genai.core.agents.base import UsageMetrics
+from datarobot_genai.core.agents.base import extract_user_prompt_content
+from datarobot_genai.core.agents.base import is_streaming
+from datarobot_genai.core.mcp.common import MCPConfig
+from datarobot_genai.nat.helpers import load_workflow
+
+logger = logging.getLogger(__name__)
+
+
+def convert_to_ragas_messages(
+    steps: list[IntermediateStep],
+) -> list[HumanMessage | AIMessage | ToolMessage]:
+    def _to_ragas(step: IntermediateStep) -> HumanMessage | AIMessage | ToolMessage:
+        if step.event_type == IntermediateStepType.LLM_START:
+            return HumanMessage(content=_parse(step.data.input))
+        elif step.event_type == IntermediateStepType.LLM_END:
+            return AIMessage(content=_parse(step.data.output))
+        else:
+            raise ValueError(f"Unknown event type {step.event_type}")
+
+    def _include_step(step: IntermediateStep) -> bool:
+        return step.event_type in {
+            IntermediateStepType.LLM_END,
+            IntermediateStepType.LLM_START,
+        }
+
+    def _parse(messages: Any) -> str:
+        if isinstance(messages, list):
+            last_message = messages[-1]
+        else:
+            last_message = messages
+
+        if isinstance(last_message, dict):
+            content = last_message.get("content") or last_message
+        elif hasattr(last_message, "content"):
+            content = getattr(last_message, "content") or last_message
+        else:
+            content = last_message
+        return str(content)
+
+    return [_to_ragas(step) for step in steps if _include_step(step)]
+
+
+def create_pipeline_interactions_from_steps(
+    steps: list[IntermediateStep],
+) -> MultiTurnSample | None:
+    if not steps:
+        return None
+    ragas_trace = convert_to_ragas_messages(steps)
+    return MultiTurnSample(user_input=ragas_trace)
+
+
+def pull_intermediate_structured() -> asyncio.Future[list[IntermediateStep]]:
+    """
+    Subscribe to the runner's event stream using callbacks.
+    Intermediate steps are collected and, when complete, the future is set
+    with the list of dumped intermediate steps.
+    """
+    future: asyncio.Future[list[IntermediateStep]] = asyncio.Future()
+    intermediate_steps = []  # We'll store the dumped steps here.
+    context = Context.get()
+
+    def on_next_cb(item: IntermediateStep) -> None:
+        # Append each new intermediate step to the list.
+        intermediate_steps.append(item)
+
+    def on_error_cb(exc: Exception) -> None:
+        logger.error("Hit on_error: %s", exc)
+        if not future.done():
+            future.set_exception(exc)
+
+    def on_complete_cb() -> None:
+        logger.debug("Completed reading intermediate steps")
+        if not future.done():
+            future.set_result(intermediate_steps)
+
+    # Subscribe with our callbacks.
+    context.intermediate_step_manager.subscribe(
+        on_next=on_next_cb, on_error=on_error_cb, on_complete=on_complete_cb
+    )
+
+    return future
+
+
+class NatAgent(BaseAgent[None]):
+    def __init__(
+        self,
+        *,
+        workflow_path: StrPath,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        model: str | None = None,
+        verbose: bool | str | None = True,
+        timeout: int | None = 90,
+        authorization_context: dict[str, Any] | None = None,
+        forwarded_headers: dict[str, str] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            api_key=api_key,
+            api_base=api_base,
+            model=model,
+            verbose=verbose,
+            timeout=timeout,
+            authorization_context=authorization_context,
+            forwarded_headers=forwarded_headers,
+            **kwargs,
+        )
+        self.workflow_path = workflow_path
+
+    def make_chat_request(self, completion_create_params: CompletionCreateParams) -> ChatRequest:
+        user_prompt_content = str(extract_user_prompt_content(completion_create_params))
+        return ChatRequest.from_string(user_prompt_content)
+
+    async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+        """Run the agent with the provided completion parameters.
+
+        [THIS METHOD IS REQUIRED FOR THE AGENT TO WORK WITH DRUM SERVER]
+
+        Args:
+            completion_create_params: The completion request parameters including input topic
+                and settings.
+
+        Returns
+        -------
+        For streaming requests, returns a generator yielding tuples of (response_text,
+        pipeline_interactions, usage_metrics).
+        For non-streaming requests, returns a single tuple of (response_text,
+        pipeline_interactions, usage_metrics).
+
+        """
+        # Retrieve the starting chat request from the CompletionCreateParams
+        chat_request = self.make_chat_request(completion_create_params)
+
+        # Print commands may need flush=True to ensure they are displayed in real-time.
+        print("Running agent with user prompt:", chat_request.messages[0].content, flush=True)
+
+        mcp_config = MCPConfig(
+            authorization_context=self.authorization_context,
+            forwarded_headers=self.forwarded_headers,
+        )
+        server_config = mcp_config.server_config
+        headers = server_config["headers"] if server_config else None
+
+        if is_streaming(completion_create_params):
+
+            async def stream_generator() -> AsyncGenerator[
+                tuple[str, MultiTurnSample | None, UsageMetrics], None
+            ]:
+                default_usage_metrics: UsageMetrics = {
+                    "completion_tokens": 0,
+                    "prompt_tokens": 0,
+                    "total_tokens": 0,
+                }
+                async with load_workflow(self.workflow_path, headers=headers) as workflow:
+                    async with workflow.run(chat_request) as runner:
+                        intermediate_future = pull_intermediate_structured()
+                        async for result in runner.result_stream():
+                            if isinstance(result, ChatResponse):
+                                result_text = result.choices[0].message.content
+                            else:
+                                result_text = str(result)
+
+                            yield (
+                                result_text,
+                                None,
+                                default_usage_metrics,
+                            )
+
+                        steps = await intermediate_future
+                        llm_end_steps = [
+                            step
+                            for step in steps
+                            if step.event_type == IntermediateStepType.LLM_END
+                        ]
+                        usage_metrics: UsageMetrics = {
+                            "completion_tokens": 0,
+                            "prompt_tokens": 0,
+                            "total_tokens": 0,
+                        }
+                        for step in llm_end_steps:
+                            if step.usage_info:
+                                token_usage = step.usage_info.token_usage
+                                usage_metrics["total_tokens"] += token_usage.total_tokens
+                                usage_metrics["prompt_tokens"] += token_usage.prompt_tokens
+                                usage_metrics["completion_tokens"] += token_usage.completion_tokens
+
+                        pipeline_interactions = create_pipeline_interactions_from_steps(steps)
+                        yield "", pipeline_interactions, usage_metrics
+
+            return stream_generator()
+
+        # Create and invoke the NAT (Nemo Agent Toolkit) Agentic Workflow with the inputs
+        result, steps = await self.run_nat_workflow(self.workflow_path, chat_request, headers)
+
+        llm_end_steps = [step for step in steps if step.event_type == IntermediateStepType.LLM_END]
+        usage_metrics: UsageMetrics = {
+            "completion_tokens": 0,
+            "prompt_tokens": 0,
+            "total_tokens": 0,
+        }
+        for step in llm_end_steps:
+            if step.usage_info:
+                token_usage = step.usage_info.token_usage
+                usage_metrics["total_tokens"] += token_usage.total_tokens
+                usage_metrics["prompt_tokens"] += token_usage.prompt_tokens
+                usage_metrics["completion_tokens"] += token_usage.completion_tokens
+
+        if isinstance(result, ChatResponse):
+            result_text = result.choices[0].message.content
+        else:
+            result_text = str(result)
+        pipeline_interactions = create_pipeline_interactions_from_steps(steps)
+
+        return result_text, pipeline_interactions, usage_metrics
+
+    async def run_nat_workflow(
+        self, workflow_path: StrPath, chat_request: ChatRequest, headers: dict[str, str] | None
+    ) -> tuple[ChatResponse | str, list[IntermediateStep]]:
+        """Run the NAT workflow with the provided config file and chat request.
+
+        Args:
+            workflow_path: Path to the NAT workflow configuration file
+            chat_request: Chat request to process through the workflow
+            headers: Optional HTTP headers for the workflow's MCP connections
+
+        Returns
+        -------
+        ChatResponse | str: The result from the NAT workflow
+        list[IntermediateStep]: The list of intermediate steps
+        """
+        async with load_workflow(workflow_path, headers=headers) as workflow:
+            async with workflow.run(chat_request) as runner:
+                intermediate_future = pull_intermediate_structured()
+                runner_outputs = await runner.result()
+                steps = await intermediate_future
+
+        line = f"{'-' * 50}"
+        prefix = f"{line}\nWorkflow Result:\n"
+        suffix = f"\n{line}"
+
+        print(f"{prefix}{runner_outputs}{suffix}")
+
+        return runner_outputs, steps
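
End to end, NatAgent takes a workflow config path at construction and OpenAI-style completion params at invoke(). A hypothetical non-streaming driver; the path, model name, and prompt are placeholders, and the tuple unpacking follows the invoke() docstring above:

# Hypothetical driver for NatAgent; workflow path, model, and prompt are placeholders.
import asyncio

from datarobot_genai.nat.agent import NatAgent


async def main() -> None:
    agent = NatAgent(workflow_path="configs/workflow.yaml")
    params = {
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "Summarize the latest project status."}],
        "stream": False,
    }
    response_text, pipeline_interactions, usage_metrics = await agent.invoke(params)
    print(response_text)
    print(usage_metrics)


asyncio.run(main())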