datarobot-genai 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +364 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +515 -0
  56. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  57. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  58. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  59. datarobot_genai/drmcp/core/routes.py +439 -0
  60. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  61. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  62. datarobot_genai/drmcp/core/telemetry.py +424 -0
  63. datarobot_genai/drmcp/core/tool_config.py +111 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +117 -0
  65. datarobot_genai/drmcp/core/utils.py +138 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
  69. datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
  70. datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
  71. datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
  72. datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
  73. datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
  74. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
  75. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
  76. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
  77. datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
  78. datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
  79. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  80. datarobot_genai/drmcp/tools/__init__.py +14 -0
  81. datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
  82. datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
  83. datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
  84. datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
  85. datarobot_genai/drmcp/tools/clients/jira.py +334 -0
  86. datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
  87. datarobot_genai/drmcp/tools/clients/s3.py +28 -0
  88. datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
  89. datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
  90. datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
  91. datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
  92. datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
  93. datarobot_genai/drmcp/tools/jira/tools.py +243 -0
  94. datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
  95. datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
  96. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  97. datarobot_genai/drmcp/tools/predictive/data.py +133 -0
  98. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  99. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  100. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  101. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  102. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  103. datarobot_genai/drmcp/tools/predictive/project.py +90 -0
  104. datarobot_genai/drmcp/tools/predictive/training.py +661 -0
  105. datarobot_genai/langgraph/__init__.py +0 -0
  106. datarobot_genai/langgraph/agent.py +341 -0
  107. datarobot_genai/langgraph/mcp.py +73 -0
  108. datarobot_genai/llama_index/__init__.py +16 -0
  109. datarobot_genai/llama_index/agent.py +50 -0
  110. datarobot_genai/llama_index/base.py +299 -0
  111. datarobot_genai/llama_index/mcp.py +79 -0
  112. datarobot_genai/nat/__init__.py +0 -0
  113. datarobot_genai/nat/agent.py +275 -0
  114. datarobot_genai/nat/datarobot_auth_provider.py +110 -0
  115. datarobot_genai/nat/datarobot_llm_clients.py +318 -0
  116. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  117. datarobot_genai/nat/datarobot_mcp_client.py +266 -0
  118. datarobot_genai/nat/helpers.py +87 -0
  119. datarobot_genai/py.typed +0 -0
  120. datarobot_genai-0.2.31.dist-info/METADATA +145 -0
  121. datarobot_genai-0.2.31.dist-info/RECORD +125 -0
  122. datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
  123. datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
  124. datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
  125. datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
datarobot_genai/llama_index/base.py
@@ -0,0 +1,299 @@
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Base class for LlamaIndex-based agents.
+
+ Provides a standard ``invoke`` that runs an AgentWorkflow, collects events,
+ and converts them into pipeline interactions. Subclasses provide the workflow
+ and response extraction logic.
+ """
+
+ from __future__ import annotations
+
+ import abc
+ import inspect
+ from collections.abc import AsyncGenerator
+ from typing import Any
+
+ from llama_index.core.tools import BaseTool
+ from openai.types.chat import CompletionCreateParams
+ from ragas import MultiTurnSample
+
+ from datarobot_genai.core.agents.base import BaseAgent
+ from datarobot_genai.core.agents.base import InvokeReturn
+ from datarobot_genai.core.agents.base import UsageMetrics
+ from datarobot_genai.core.agents.base import default_usage_metrics
+ from datarobot_genai.core.agents.base import extract_user_prompt_content
+ from datarobot_genai.core.agents.base import is_streaming
+
+ from .agent import create_pipeline_interactions_from_events
+ from .mcp import load_mcp_tools
+
+
+ class LlamaIndexAgent(BaseAgent[BaseTool], abc.ABC):
+     """Abstract base agent for LlamaIndex workflows."""
+
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
+         super().__init__(*args, **kwargs)
+         self._mcp_tools: list[Any] = []
+
+     def set_mcp_tools(self, tools: list[Any]) -> None:
+         """Set MCP tools for this agent."""
+         self._mcp_tools = tools
+
+     @property
+     def mcp_tools(self) -> list[Any]:
+         """Return the list of MCP tools available to this agent.
+
+         Subclasses can use this to wire tools into LlamaIndex agents during
+         workflow construction inside ``build_workflow``.
+         """
+         return self._mcp_tools
+
+     @abc.abstractmethod
+     def build_workflow(self) -> Any:
+         """Return an AgentWorkflow instance ready to run."""
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
+         """Extract final response text from workflow state and/or events."""
+         raise NotImplementedError
+
+     def make_input_message(self, completion_create_params: CompletionCreateParams) -> str:
+         """Create an input string for the workflow from the user prompt."""
+         user_prompt_content = extract_user_prompt_content(completion_create_params)
+         return str(user_prompt_content)
+
+     async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+         """Run the LlamaIndex workflow with the provided completion parameters."""
+         input_message = self.make_input_message(completion_create_params)
+
+         # Load MCP tools (if configured) asynchronously before building workflow
+         mcp_tools = await load_mcp_tools(
+             authorization_context=self._authorization_context,
+             forwarded_headers=self.forwarded_headers,
+         )
+         self.set_mcp_tools(mcp_tools)
+
+         # Preserve prior template startup print for CLI parity
+         try:
+             print(
+                 "Running agent with user prompt:",
+                 extract_user_prompt_content(completion_create_params),
+                 flush=True,
+             )
+         except Exception:
+             # Printing is best-effort; proceed regardless
+             pass
+
+         workflow = self.build_workflow()
+         handler = workflow.run(user_msg=input_message)
+
+         usage_metrics: UsageMetrics = default_usage_metrics()
+
+         # Streaming parity with LangGraph: yield incremental deltas during event processing
+         if is_streaming(completion_create_params):
+
+             async def _gen() -> AsyncGenerator[tuple[str, MultiTurnSample | None, UsageMetrics], None]:
+                 events: list[Any] = []
+                 current_agent_name: str | None = None
+                 async for event in handler.stream_events():
+                     events.append(event)
+                     # Best-effort extraction of incremental text from LlamaIndex events
+                     delta: str | None = None
+                     # Agent switch banner if available on event
+                     try:
+                         if hasattr(event, "current_agent_name"):
+                             new_agent = getattr(event, "current_agent_name")
+                             if (
+                                 isinstance(new_agent, str)
+                                 and new_agent
+                                 and new_agent != current_agent_name
+                             ):
+                                 current_agent_name = new_agent
+                                 # Print banner for agent switch (do not emit as streamed content)
+                                 print("\n" + "=" * 50, flush=True)
+                                 print(f"🤖 Agent: {current_agent_name}", flush=True)
+                                 print("=" * 50 + "\n", flush=True)
+                     except Exception:
+                         pass
+
+                     try:
+                         if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
+                             delta = getattr(event, "delta")
+                         # Some event types may carry incremental text under "text" or similar
+                         elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
+                             delta = getattr(event, "text")
+                     except Exception:
+                         # Ignore malformed events and continue
+                         delta = None
+
+                     if delta:
+                         # Yield token/content delta with current (accumulated) usage metrics
+                         yield delta, None, usage_metrics
+
+                     # Best-effort debug/event messages printed to CLI (do not stream as content)
+                     try:
+                         event_type = type(event).__name__
+                         if event_type == "AgentInput" and hasattr(event, "input"):
+                             print("📥 Input:", getattr(event, "input"), flush=True)
+                         elif event_type == "AgentOutput":
+                             # Output content
+                             resp = getattr(event, "response", None)
+                             if (
+                                 resp is not None
+                                 and hasattr(resp, "content")
+                                 and getattr(resp, "content")
+                             ):
+                                 print("📤 Output:", getattr(resp, "content"), flush=True)
+                             # Planned tool calls
+                             tcalls = getattr(event, "tool_calls", None)
+                             if isinstance(tcalls, list) and tcalls:
+                                 names = []
+                                 for c in tcalls:
+                                     try:
+                                         nm = getattr(c, "tool_name", None) or (
+                                             c.get("tool_name") if isinstance(c, dict) else None
+                                         )
+                                         if nm:
+                                             names.append(str(nm))
+                                     except Exception:
+                                         pass
+                                 if names:
+                                     print("🛠️ Planning to use tools:", names, flush=True)
+                         elif event_type == "ToolCallResult":
+                             tname = getattr(event, "tool_name", None)
+                             tkwargs = getattr(event, "tool_kwargs", None)
+                             tout = getattr(event, "tool_output", None)
+                             print(f"🔧 Tool Result ({tname}):", flush=True)
+                             print(f" Arguments: {tkwargs}", flush=True)
+                             print(f" Output: {tout}", flush=True)
+                         elif event_type == "ToolCall":
+                             tname = getattr(event, "tool_name", None)
+                             tkwargs = getattr(event, "tool_kwargs", None)
+                             print(f"🔨 Calling Tool: {tname}", flush=True)
+                             print(f" With arguments: {tkwargs}", flush=True)
+                     except Exception:
+                         # Ignore best-effort debug rendering errors
+                         pass
+
+                 # After streaming completes, build final interactions and finish chunk
+                 # Extract state from workflow context (supports sync/async get or attribute)
+                 state = None
+                 ctx = getattr(handler, "ctx", None)
+                 try:
+                     if ctx is not None:
+                         get = getattr(ctx, "get", None)
+                         if callable(get):
+                             result = get("state")
+                             state = await result if inspect.isawaitable(result) else result
+                         elif hasattr(ctx, "state"):
+                             state = getattr(ctx, "state")
+                 except (AttributeError, TypeError):
+                     state = None
+
+                 # Run subclass-defined response extraction (not streamed) for completeness
+                 _ = self.extract_response_text(state, events)
+
+                 pipeline_interactions = create_pipeline_interactions_from_events(events)
+                 # Final empty chunk indicates end of stream, carrying interactions and usage
+                 yield "", pipeline_interactions, usage_metrics
+
+             return _gen()
+
+         # Non-streaming path: run to completion, emit debug prints, then return final response
+         events: list[Any] = []
+         current_agent_name: str | None = None
+         async for event in handler.stream_events():
+             events.append(event)
+
+             # Replicate prior template CLI prints for non-streaming mode
+             try:
+                 if hasattr(event, "current_agent_name"):
+                     new_agent = getattr(event, "current_agent_name")
+                     if isinstance(new_agent, str) and new_agent and new_agent != current_agent_name:
+                         current_agent_name = new_agent
+                         print(f"\n{'=' * 50}", flush=True)
+                         print(f"🤖 Agent: {current_agent_name}", flush=True)
+                         print(f"{'=' * 50}\n", flush=True)
+             except Exception:
+                 pass
+
+             try:
+                 if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
+                     print(getattr(event, "delta"), end="", flush=True)
+                 elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
+                     print(getattr(event, "text"), end="", flush=True)
+                 else:
+                     event_type = type(event).__name__
+                     if event_type == "AgentInput" and hasattr(event, "input"):
+                         print("📥 Input:", getattr(event, "input"), flush=True)
+                     elif event_type == "AgentOutput":
+                         resp = getattr(event, "response", None)
+                         if (
+                             resp is not None
+                             and hasattr(resp, "content")
+                             and getattr(resp, "content")
+                         ):
+                             print("📤 Output:", getattr(resp, "content"), flush=True)
+                         tcalls = getattr(event, "tool_calls", None)
+                         if isinstance(tcalls, list) and tcalls:
+                             names: list[str] = []
+                             for c in tcalls:
+                                 try:
+                                     nm = getattr(c, "tool_name", None) or (
+                                         c.get("tool_name") if isinstance(c, dict) else None
+                                     )
+                                     if nm:
+                                         names.append(str(nm))
+                                 except Exception:
+                                     pass
+                             if names:
+                                 print("🛠️ Planning to use tools:", names, flush=True)
+                     elif event_type == "ToolCallResult":
+                         tname = getattr(event, "tool_name", None)
+                         tkwargs = getattr(event, "tool_kwargs", None)
+                         tout = getattr(event, "tool_output", None)
+                         print(f"🔧 Tool Result ({tname}):", flush=True)
+                         print(f" Arguments: {tkwargs}", flush=True)
+                         print(f" Output: {tout}", flush=True)
+                     elif event_type == "ToolCall":
+                         tname = getattr(event, "tool_name", None)
+                         tkwargs = getattr(event, "tool_kwargs", None)
+                         print(f"🔨 Calling Tool: {tname}", flush=True)
+                         print(f" With arguments: {tkwargs}", flush=True)
+             except Exception:
+                 # Best-effort debug printing; continue on errors
+                 pass
+
+         # Extract state from workflow context (supports sync/async get or attribute)
+         state = None
+         ctx = getattr(handler, "ctx", None)
+         try:
+             if ctx is not None:
+                 get = getattr(ctx, "get", None)
+                 if callable(get):
+                     result = get("state")
+                     state = await result if inspect.isawaitable(result) else result
+                 elif hasattr(ctx, "state"):
+                     state = getattr(ctx, "state")
+         except (AttributeError, TypeError):
+             state = None
+         response_text = self.extract_response_text(state, events)
+
+         pipeline_interactions = create_pipeline_interactions_from_events(events)
+
+         return response_text, pipeline_interactions, usage_metrics
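
Usage note: a minimal concrete subclass might look like the sketch below. The FunctionAgent/AgentWorkflow wiring, the model choice, and the agent name are illustrative assumptions, not part of this package.

# A sketch of a LlamaIndexAgent subclass, assuming llama-index's
# AgentWorkflow / FunctionAgent APIs and an OpenAI-compatible LLM.
from typing import Any

from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.llms.openai import OpenAI

from datarobot_genai.llama_index.base import LlamaIndexAgent


class ResearchAgent(LlamaIndexAgent):
    def build_workflow(self) -> Any:
        # invoke() populates mcp_tools before build_workflow() is called
        agent = FunctionAgent(
            name="researcher",
            description="Answers questions, optionally via MCP tools.",
            llm=OpenAI(model="gpt-4o-mini"),  # hypothetical model choice
            tools=list(self.mcp_tools),
        )
        return AgentWorkflow(agents=[agent], root_agent="researcher")

    def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
        # Walk events backwards and return the last response content seen
        for event in reversed(events):
            resp = getattr(event, "response", None)
            if resp is not None and getattr(resp, "content", None):
                return str(resp.content)
        return ""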
datarobot_genai/llama_index/mcp.py
@@ -0,0 +1,79 @@
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ MCP integration for LlamaIndex using llama-index-tools-mcp.
+
+ This module provides MCP server connection management for LlamaIndex agents.
+ Unlike CrewAI, which uses a context manager, LlamaIndex uses async calls to
+ fetch tools from MCP servers.
+ """
+
+ from typing import Any
+
+ from llama_index.tools.mcp import BasicMCPClient
+ from llama_index.tools.mcp import aget_tools_from_mcp_url
+
+ from datarobot_genai.core.mcp.common import MCPConfig
+
+
+ async def load_mcp_tools(
+     authorization_context: dict[str, Any] | None = None,
+     forwarded_headers: dict[str, str] | None = None,
+ ) -> list[Any]:
+     """
+     Asynchronously load MCP tools for LlamaIndex.
+
+     Args:
+         authorization_context: Optional authorization context for MCP connections
+         forwarded_headers: Optional forwarded headers, e.g. x-datarobot-api-key for MCP auth
+
+     Returns
+     -------
+     List of MCP tools, or an empty list if no MCP configuration is present.
+     """
+     config = MCPConfig(
+         authorization_context=authorization_context,
+         forwarded_headers=forwarded_headers,
+     )
+     server_params = config.server_config
+
+     if not server_params:
+         print("No MCP server configured, using empty tools list", flush=True)
+         return []
+
+     url = server_params["url"]
+     headers = server_params.get("headers", {})
+
+     try:
+         print(f"Connecting to MCP server: {url}", flush=True)
+         # Create BasicMCPClient with headers to pass authentication
+         client = BasicMCPClient(command_or_url=url, headers=headers)
+         tools = await aget_tools_from_mcp_url(
+             command_or_url=url,
+             client=client,
+         )
+         # Ensure list
+         tools_list = list(tools) if tools is not None else []
+         print(
+             f"Successfully connected to MCP server, got {len(tools_list)} tools",
+             flush=True,
+         )
+         return tools_list
+     except Exception as e:
+         print(
+             f"Warning: Failed to connect to MCP server {url}: {e}",
+             flush=True,
+         )
+         return []
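
Usage note: the loader can also be driven directly. A minimal sketch, assuming the MCP server URL comes from MCPConfig's environment-driven configuration; the header value below is a placeholder.

# Minimal usage sketch; the forwarded header value is a placeholder, and the
# metadata access assumes llama-index FunctionTool objects.
import asyncio

from datarobot_genai.llama_index.mcp import load_mcp_tools


async def main() -> None:
    tools = await load_mcp_tools(
        authorization_context=None,
        forwarded_headers={"x-datarobot-api-key": "<your-token>"},
    )
    for tool in tools:
        print(tool.metadata.name)


asyncio.run(main())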
datarobot_genai/nat/__init__.py
File without changes
datarobot_genai/nat/agent.py
@@ -0,0 +1,275 @@
+ # Copyright 2025 DataRobot, Inc. and its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import asyncio
+ import logging
+ from collections.abc import AsyncGenerator
+ from typing import Any
+
+ from nat.builder.context import Context
+ from nat.data_models.api_server import ChatRequest
+ from nat.data_models.api_server import ChatResponse
+ from nat.data_models.intermediate_step import IntermediateStep
+ from nat.data_models.intermediate_step import IntermediateStepType
+ from nat.utils.type_utils import StrPath
+ from openai.types.chat import CompletionCreateParams
+ from ragas import MultiTurnSample
+ from ragas.messages import AIMessage
+ from ragas.messages import HumanMessage
+ from ragas.messages import ToolMessage
+
+ from datarobot_genai.core.agents.base import BaseAgent
+ from datarobot_genai.core.agents.base import InvokeReturn
+ from datarobot_genai.core.agents.base import UsageMetrics
+ from datarobot_genai.core.agents.base import extract_user_prompt_content
+ from datarobot_genai.core.agents.base import is_streaming
+ from datarobot_genai.core.mcp.common import MCPConfig
+ from datarobot_genai.nat.helpers import load_workflow
+
+ logger = logging.getLogger(__name__)
+
+
+ def convert_to_ragas_messages(
+     steps: list[IntermediateStep],
+ ) -> list[HumanMessage | AIMessage | ToolMessage]:
+     def _to_ragas(step: IntermediateStep) -> HumanMessage | AIMessage | ToolMessage:
+         if step.event_type == IntermediateStepType.LLM_START:
+             return HumanMessage(content=_parse(step.data.input))
+         elif step.event_type == IntermediateStepType.LLM_END:
+             return AIMessage(content=_parse(step.data.output))
+         else:
+             raise ValueError(f"Unknown event type {step.event_type}")
+
+     def _include_step(step: IntermediateStep) -> bool:
+         return step.event_type in {
+             IntermediateStepType.LLM_END,
+             IntermediateStepType.LLM_START,
+         }
+
+     def _parse(messages: Any) -> str:
+         if isinstance(messages, list):
+             last_message = messages[-1]
+         else:
+             last_message = messages
+
+         if isinstance(last_message, dict):
+             content = last_message.get("content") or last_message
+         elif hasattr(last_message, "content"):
+             content = getattr(last_message, "content") or last_message
+         else:
+             content = last_message
+         return str(content)
+
+     return [_to_ragas(step) for step in steps if _include_step(step)]
+
+
+ def create_pipeline_interactions_from_steps(
+     steps: list[IntermediateStep],
+ ) -> MultiTurnSample | None:
+     if not steps:
+         return None
+     ragas_trace = convert_to_ragas_messages(steps)
+     return MultiTurnSample(user_input=ragas_trace)
+
+
+ def pull_intermediate_structured() -> asyncio.Future[list[IntermediateStep]]:
+     """
+     Subscribe to the runner's event stream using callbacks.
+     Intermediate steps are collected and, when complete, the future is set
+     with the list of dumped intermediate steps.
+     """
+     future: asyncio.Future[list[IntermediateStep]] = asyncio.Future()
+     intermediate_steps = []  # We'll store the dumped steps here.
+     context = Context.get()
+
+     def on_next_cb(item: IntermediateStep) -> None:
+         # Append each new intermediate step to the list.
+         intermediate_steps.append(item)
+
+     def on_error_cb(exc: Exception) -> None:
+         logger.error("Hit on_error: %s", exc)
+         if not future.done():
+             future.set_exception(exc)
+
+     def on_complete_cb() -> None:
+         logger.debug("Completed reading intermediate steps")
+         if not future.done():
+             future.set_result(intermediate_steps)
+
+     # Subscribe with our callbacks.
+     context.intermediate_step_manager.subscribe(
+         on_next=on_next_cb, on_error=on_error_cb, on_complete=on_complete_cb
+     )
+
+     return future
+
+
+ class NatAgent(BaseAgent[None]):
+     def __init__(
+         self,
+         *,
+         workflow_path: StrPath,
+         api_key: str | None = None,
+         api_base: str | None = None,
+         model: str | None = None,
+         verbose: bool | str | None = True,
+         timeout: int | None = 90,
+         authorization_context: dict[str, Any] | None = None,
+         forwarded_headers: dict[str, str] | None = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             api_key=api_key,
+             api_base=api_base,
+             model=model,
+             verbose=verbose,
+             timeout=timeout,
+             authorization_context=authorization_context,
+             forwarded_headers=forwarded_headers,
+             **kwargs,
+         )
+         self.workflow_path = workflow_path
+
+     def make_chat_request(self, completion_create_params: CompletionCreateParams) -> ChatRequest:
+         user_prompt_content = str(extract_user_prompt_content(completion_create_params))
+         return ChatRequest.from_string(user_prompt_content)
+
+     async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
+         """Run the agent with the provided completion parameters.
+
+         [THIS METHOD IS REQUIRED FOR THE AGENT TO WORK WITH DRUM SERVER]
+
+         Args:
+             completion_create_params: The completion request parameters including input topic
+                 and settings.
+
+         Returns
+         -------
+         For streaming requests, returns a generator yielding tuples of (response_text,
+         pipeline_interactions, usage_metrics).
+         For non-streaming requests, returns a single tuple of (response_text,
+         pipeline_interactions, usage_metrics).
+
+         """
+         # Retrieve the starting chat request from the CompletionCreateParams
+         chat_request = self.make_chat_request(completion_create_params)
+
+         # Print calls may need flush=True to ensure they are displayed in real time.
+         print("Running agent with user prompt:", chat_request.messages[0].content, flush=True)
+
+         mcp_config = MCPConfig(
+             authorization_context=self.authorization_context,
+             forwarded_headers=self.forwarded_headers,
+         )
+         server_config = mcp_config.server_config
+         headers = server_config["headers"] if server_config else None
+
+         if is_streaming(completion_create_params):
+
+             async def stream_generator() -> AsyncGenerator[
+                 tuple[str, MultiTurnSample | None, UsageMetrics], None
+             ]:
+                 default_usage_metrics: UsageMetrics = {
+                     "completion_tokens": 0,
+                     "prompt_tokens": 0,
+                     "total_tokens": 0,
+                 }
+                 async with load_workflow(self.workflow_path, headers=headers) as workflow:
+                     async with workflow.run(chat_request) as runner:
+                         intermediate_future = pull_intermediate_structured()
+                         async for result in runner.result_stream():
+                             if isinstance(result, ChatResponse):
+                                 result_text = result.choices[0].message.content
+                             else:
+                                 result_text = str(result)
+
+                             yield (
+                                 result_text,
+                                 None,
+                                 default_usage_metrics,
+                             )
+
+                         steps = await intermediate_future
+                         llm_end_steps = [
+                             step
+                             for step in steps
+                             if step.event_type == IntermediateStepType.LLM_END
+                         ]
+                         usage_metrics: UsageMetrics = {
+                             "completion_tokens": 0,
+                             "prompt_tokens": 0,
+                             "total_tokens": 0,
+                         }
+                         for step in llm_end_steps:
+                             if step.usage_info:
+                                 token_usage = step.usage_info.token_usage
+                                 usage_metrics["total_tokens"] += token_usage.total_tokens
+                                 usage_metrics["prompt_tokens"] += token_usage.prompt_tokens
+                                 usage_metrics["completion_tokens"] += token_usage.completion_tokens
+
+                         pipeline_interactions = create_pipeline_interactions_from_steps(steps)
+                         yield "", pipeline_interactions, usage_metrics
+
+             return stream_generator()
+
+         # Create and invoke the NAT (NeMo Agent Toolkit) agentic workflow with the inputs
+         result, steps = await self.run_nat_workflow(self.workflow_path, chat_request, headers)
+
+         llm_end_steps = [step for step in steps if step.event_type == IntermediateStepType.LLM_END]
+         usage_metrics: UsageMetrics = {
+             "completion_tokens": 0,
+             "prompt_tokens": 0,
+             "total_tokens": 0,
+         }
+         for step in llm_end_steps:
+             if step.usage_info:
+                 token_usage = step.usage_info.token_usage
+                 usage_metrics["total_tokens"] += token_usage.total_tokens
+                 usage_metrics["prompt_tokens"] += token_usage.prompt_tokens
+                 usage_metrics["completion_tokens"] += token_usage.completion_tokens
+
+         if isinstance(result, ChatResponse):
+             result_text = result.choices[0].message.content
+         else:
+             result_text = str(result)
+         pipeline_interactions = create_pipeline_interactions_from_steps(steps)
+
+         return result_text, pipeline_interactions, usage_metrics
+
+     async def run_nat_workflow(
+         self, workflow_path: StrPath, chat_request: ChatRequest, headers: dict[str, str] | None
+     ) -> tuple[ChatResponse | str, list[IntermediateStep]]:
+         """Run the NAT workflow with the provided config file and chat request.
+
+         Args:
+             workflow_path: Path to the NAT workflow configuration file
+             chat_request: Chat request to process through the workflow
+
+         Returns
+         -------
+         ChatResponse | str: The result from the NAT workflow
+         list[IntermediateStep]: The list of intermediate steps
+         """
+         async with load_workflow(workflow_path, headers=headers) as workflow:
+             async with workflow.run(chat_request) as runner:
+                 intermediate_future = pull_intermediate_structured()
+                 runner_outputs = await runner.result()
+                 steps = await intermediate_future
+
+         line = f"{'-' * 50}"
+         prefix = f"{line}\nWorkflow Result:\n"
+         suffix = f"\n{line}"
+
+         print(f"{prefix}{runner_outputs}{suffix}")
+
+         return runner_outputs, steps
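
Usage note: a minimal non-streaming run of NatAgent might look like the sketch below; the workflow path, model id, and prompt are illustrative, and the request dict is shaped like OpenAI CompletionCreateParams.

# Minimal sketch of invoking NatAgent outside DRUM; "workflow.yaml" and the
# model id are hypothetical placeholders.
import asyncio

from datarobot_genai.nat.agent import NatAgent


async def main() -> None:
    agent = NatAgent(workflow_path="workflow.yaml")  # hypothetical NAT config
    params = {
        "model": "nat-agent",  # placeholder model id
        "messages": [{"role": "user", "content": "Summarize the latest run."}],
        "stream": False,
    }
    response_text, pipeline_interactions, usage_metrics = await agent.invoke(params)
    print(response_text)
    print(usage_metrics)


asyncio.run(main())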