datarobot-genai 0.2.39__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. datarobot_genai/core/agents/__init__.py +1 -1
  2. datarobot_genai/core/agents/base.py +5 -2
  3. datarobot_genai/core/chat/responses.py +6 -1
  4. datarobot_genai/core/utils/auth.py +188 -31
  5. datarobot_genai/crewai/__init__.py +1 -4
  6. datarobot_genai/crewai/agent.py +150 -17
  7. datarobot_genai/crewai/events.py +11 -4
  8. datarobot_genai/drmcp/__init__.py +4 -2
  9. datarobot_genai/drmcp/core/config.py +21 -1
  10. datarobot_genai/drmcp/core/mcp_instance.py +5 -49
  11. datarobot_genai/drmcp/core/routes.py +108 -13
  12. datarobot_genai/drmcp/core/tool_config.py +16 -0
  13. datarobot_genai/drmcp/core/utils.py +110 -0
  14. datarobot_genai/drmcp/test_utils/tool_base_ete.py +41 -26
  15. datarobot_genai/drmcp/tools/clients/gdrive.py +2 -0
  16. datarobot_genai/drmcp/tools/clients/microsoft_graph.py +96 -0
  17. datarobot_genai/drmcp/tools/clients/perplexity.py +173 -0
  18. datarobot_genai/drmcp/tools/clients/tavily.py +199 -0
  19. datarobot_genai/drmcp/tools/confluence/tools.py +0 -5
  20. datarobot_genai/drmcp/tools/gdrive/tools.py +12 -59
  21. datarobot_genai/drmcp/tools/jira/tools.py +4 -8
  22. datarobot_genai/drmcp/tools/microsoft_graph/tools.py +135 -19
  23. datarobot_genai/drmcp/tools/perplexity/__init__.py +0 -0
  24. datarobot_genai/drmcp/tools/perplexity/tools.py +117 -0
  25. datarobot_genai/drmcp/tools/predictive/data.py +1 -9
  26. datarobot_genai/drmcp/tools/predictive/deployment.py +0 -8
  27. datarobot_genai/drmcp/tools/predictive/deployment_info.py +0 -19
  28. datarobot_genai/drmcp/tools/predictive/model.py +0 -21
  29. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +3 -0
  30. datarobot_genai/drmcp/tools/predictive/project.py +3 -19
  31. datarobot_genai/drmcp/tools/predictive/training.py +1 -19
  32. datarobot_genai/drmcp/tools/tavily/__init__.py +13 -0
  33. datarobot_genai/drmcp/tools/tavily/tools.py +141 -0
  34. datarobot_genai/langgraph/agent.py +10 -2
  35. datarobot_genai/llama_index/__init__.py +1 -1
  36. datarobot_genai/llama_index/agent.py +284 -5
  37. datarobot_genai/nat/agent.py +17 -6
  38. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/METADATA +3 -1
  39. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/RECORD +43 -40
  40. datarobot_genai/crewai/base.py +0 -159
  41. datarobot_genai/drmcp/core/tool_filter.py +0 -117
  42. datarobot_genai/llama_index/base.py +0 -299
  43. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/WHEEL +0 -0
  44. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/entry_points.txt +0 -0
  45. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/AUTHORS +0 -0
  46. {datarobot_genai-0.2.39.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,141 @@
1
+ # Copyright 2025 DataRobot, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tavily MCP tools for web search."""
16
+
17
+ import logging
18
+ from typing import Annotated
19
+ from typing import Literal
20
+
21
+ from fastmcp.tools.tool import ToolResult
22
+
23
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
24
+ from datarobot_genai.drmcp.tools.clients.tavily import CHUNKS_PER_SOURCE_DEFAULT
25
+ from datarobot_genai.drmcp.tools.clients.tavily import MAX_CHUNKS_PER_SOURCE
26
+ from datarobot_genai.drmcp.tools.clients.tavily import MAX_RESULTS
27
+ from datarobot_genai.drmcp.tools.clients.tavily import MAX_RESULTS_DEFAULT
28
+ from datarobot_genai.drmcp.tools.clients.tavily import TavilyClient
29
+ from datarobot_genai.drmcp.tools.clients.tavily import TavilyImage
30
+ from datarobot_genai.drmcp.tools.clients.tavily import TavilySearchResult
31
+ from datarobot_genai.drmcp.tools.clients.tavily import get_tavily_access_token
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dr_mcp_tool(tags={"tavily", "search", "web", "websearch"})
37
+ async def tavily_search(
38
+ *,
39
+ query: Annotated[str, "The search query to execute."],
40
+ topic: Annotated[
41
+ Literal["general", "news", "finance"],
42
+ "The category of search. Use 'general' for broad web search, "
43
+ "'news' for recent news articles, or 'finance' for financial information.",
44
+ ] = "general",
45
+ search_depth: Annotated[
46
+ Literal["basic", "advanced"],
47
+ "The depth of search. 'basic' is faster and cheaper, "
48
+ "'advanced' provides more comprehensive results.",
49
+ ] = "basic",
50
+ max_results: Annotated[
51
+ int,
52
+ f"Maximum number of search results to return (1-{MAX_RESULTS}).",
53
+ ] = MAX_RESULTS_DEFAULT,
54
+ time_range: Annotated[
55
+ Literal["day", "week", "month", "year"] | None,
56
+ "Filter results by time range. Use 'day' for last 24 hours, "
57
+ "'week' for last 7 days, 'month' for last 30 days, or 'year' for last year.",
58
+ ] = None,
59
+ include_images: Annotated[
60
+ bool,
61
+ "Whether to include related images in the search results.",
62
+ ] = False,
63
+ include_image_descriptions: Annotated[
64
+ bool,
65
+ "Whether to include AI-generated descriptions for images. "
66
+ "Only applicable when include_images is True.",
67
+ ] = False,
68
+ chunks_per_source: Annotated[
69
+ int,
70
+ f"Maximum number of content snippets to return per source URL (1-{MAX_CHUNKS_PER_SOURCE}).",
71
+ ] = CHUNKS_PER_SOURCE_DEFAULT,
72
+ include_answer: Annotated[
73
+ bool,
74
+ "Whether to include an AI-generated answer summarizing the search results.",
75
+ ] = False,
76
+ ) -> ToolResult:
77
+ """
78
+ Perform a real-time web search using Tavily API.
79
+
80
+ Tavily is optimized for AI agents and provides clean, relevant search results
81
+ suitable for LLM consumption. Use this tool to search the web for current
82
+ information, news, financial data, or general knowledge.
83
+
84
+ Usage:
85
+ - Basic search: tavily_search(query="Python best practices 2026")
86
+ - News search: tavily_search(query="AI regulations", topic="news", time_range="week")
87
+ - Financial search: tavily_search(query="AAPL stock analysis", topic="finance")
88
+ - Comprehensive search: tavily_search(
89
+ query="climate change solutions",
90
+ search_depth="advanced",
91
+ max_results=10,
92
+ include_answer=True
93
+ )
94
+
95
+ Note:
96
+ - Advanced search depth consumes more API credits but provides better results
97
+ - Time range filtering is useful for finding recent information
98
+ """
99
+ api_key = await get_tavily_access_token()
100
+
101
+ async with TavilyClient(api_key) as client:
102
+ response = await client.search(
103
+ query=query,
104
+ topic=topic,
105
+ search_depth=search_depth,
106
+ max_results=max_results,
107
+ time_range=time_range,
108
+ include_images=include_images,
109
+ include_image_descriptions=include_image_descriptions,
110
+ chunks_per_source=chunks_per_source,
111
+ include_answer=include_answer,
112
+ )
113
+
114
+ results = [TavilySearchResult.from_tavily_sdk(r) for r in response.get("results", [])]
115
+
116
+ images: list[TavilyImage] | None = None
117
+ if include_images and response.get("images"):
118
+ images = [TavilyImage.from_tavily_sdk(img) for img in response.get("images", [])]
119
+
120
+ result_count = len(results)
121
+ answer = response.get("answer")
122
+ response_time = response.get("response_time", 0.0)
123
+
124
+ structured_content: dict = {
125
+ "query": response.get("query", query),
126
+ "results": [r.as_flat_dict() for r in results],
127
+ "resultCount": result_count,
128
+ "responseTime": response_time,
129
+ }
130
+
131
+ if answer:
132
+ structured_content["answer"] = answer
133
+
134
+ if images:
135
+ structured_content["images"] = [
136
+ {"url": img.url, "description": img.description} for img in images
137
+ ]
138
+
139
+ return ToolResult(
140
+ structured_content=structured_content,
141
+ )
@@ -11,9 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ from __future__ import annotations
15
+
14
16
  import abc
15
17
  import logging
16
18
  from collections.abc import AsyncGenerator
19
+ from typing import TYPE_CHECKING
17
20
  from typing import Any
18
21
  from typing import cast
19
22
 
@@ -34,8 +37,6 @@ from langgraph.graph import MessagesState
34
37
  from langgraph.graph import StateGraph
35
38
  from langgraph.types import Command
36
39
  from openai.types.chat import CompletionCreateParams
37
- from ragas import MultiTurnSample
38
- from ragas.integrations.langgraph import convert_to_ragas_messages
39
40
 
40
41
  from datarobot_genai.core.agents.base import BaseAgent
41
42
  from datarobot_genai.core.agents.base import InvokeReturn
@@ -44,6 +45,9 @@ from datarobot_genai.core.agents.base import extract_user_prompt_content
44
45
  from datarobot_genai.core.agents.base import is_streaming
45
46
  from datarobot_genai.langgraph.mcp import mcp_tools_context
46
47
 
48
+ if TYPE_CHECKING:
49
+ from ragas import MultiTurnSample
50
+
47
51
  logger = logging.getLogger(__name__)
48
52
 
49
53
 
@@ -337,5 +341,9 @@ class LangGraphAgent(BaseAgent[BaseTool], abc.ABC):
337
341
  if v is not None:
338
342
  messages.extend(v.get("messages", []))
339
343
  messages = [m for m in messages if not isinstance(m, ToolMessage)]
344
+ # Lazy import to reduce memory overhead when ragas is not used
345
+ from ragas import MultiTurnSample
346
+ from ragas.integrations.langgraph import convert_to_ragas_messages
347
+
340
348
  ragas_trace = convert_to_ragas_messages(messages)
341
349
  return MultiTurnSample(user_input=ragas_trace)
@@ -3,8 +3,8 @@
3
3
  from datarobot_genai.core.mcp.common import MCPConfig
4
4
 
5
5
  from .agent import DataRobotLiteLLM
6
+ from .agent import LlamaIndexAgent
6
7
  from .agent import create_pipeline_interactions_from_events
7
- from .base import LlamaIndexAgent
8
8
  from .mcp import load_mcp_tools
9
9
 
10
10
  __all__ = [
@@ -11,17 +11,32 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ from __future__ import annotations
14
15
 
16
+ import abc
17
+ import inspect
18
+ from collections.abc import AsyncGenerator
19
+ from typing import TYPE_CHECKING
20
+ from typing import Any
15
21
  from typing import cast
16
22
 
17
23
  from llama_index.core.base.llms.types import LLMMetadata
24
+ from llama_index.core.tools import BaseTool
18
25
  from llama_index.core.workflow import Event
19
26
  from llama_index.llms.litellm import LiteLLM
20
- from ragas import MultiTurnSample
21
- from ragas.integrations.llama_index import convert_to_ragas_messages
22
- from ragas.messages import AIMessage
23
- from ragas.messages import HumanMessage
24
- from ragas.messages import ToolMessage
27
+ from openai.types.chat import CompletionCreateParams
28
+
29
+ from datarobot_genai.core.agents.base import BaseAgent
30
+ from datarobot_genai.core.agents.base import InvokeReturn
31
+ from datarobot_genai.core.agents.base import UsageMetrics
32
+ from datarobot_genai.core.agents.base import default_usage_metrics
33
+ from datarobot_genai.core.agents.base import extract_user_prompt_content
34
+ from datarobot_genai.core.agents.base import is_streaming
35
+
36
+ from .mcp import load_mcp_tools
37
+
38
+ if TYPE_CHECKING:
39
+ from ragas import MultiTurnSample
25
40
 
26
41
 
27
42
  class DataRobotLiteLLM(LiteLLM):
@@ -44,7 +59,271 @@ def create_pipeline_interactions_from_events(
44
59
  ) -> MultiTurnSample | None:
45
60
  if not events:
46
61
  return None
62
+ # Lazy import to reduce memory overhead when ragas is not used
63
+ from ragas import MultiTurnSample
64
+ from ragas.integrations.llama_index import convert_to_ragas_messages
65
+ from ragas.messages import AIMessage
66
+ from ragas.messages import HumanMessage
67
+ from ragas.messages import ToolMessage
68
+
47
69
  # convert_to_ragas_messages expects a list[Event]
48
70
  ragas_trace = convert_to_ragas_messages(list(events))
49
71
  ragas_messages = cast(list[HumanMessage | AIMessage | ToolMessage], ragas_trace)
50
72
  return MultiTurnSample(user_input=ragas_messages)
73
+
74
+
75
+ class LlamaIndexAgent(BaseAgent[BaseTool], abc.ABC):
76
+ """Abstract base agent for LlamaIndex workflows."""
77
+
78
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
79
+ super().__init__(*args, **kwargs)
80
+ self._mcp_tools: list[Any] = []
81
+
82
+ def set_mcp_tools(self, tools: list[Any]) -> None:
83
+ """Set MCP tools for this agent."""
84
+ self._mcp_tools = tools
85
+
86
+ @property
87
+ def mcp_tools(self) -> list[Any]:
88
+ """Return the list of MCP tools available to this agent.
89
+
90
+ Subclasses can use this to wire tools into LlamaIndex agents during
91
+ workflow construction inside ``build_workflow``.
92
+ """
93
+ return self._mcp_tools
94
+
95
+ @abc.abstractmethod
96
+ def build_workflow(self) -> Any:
97
+ """Return an AgentWorkflow instance ready to run."""
98
+ raise NotImplementedError
99
+
100
+ @abc.abstractmethod
101
+ def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
102
+ """Extract final response text from workflow state and/or events."""
103
+ raise NotImplementedError
104
+
105
+ def make_input_message(self, completion_create_params: CompletionCreateParams) -> str:
106
+ """Create an input string for the workflow from the user prompt."""
107
+ user_prompt_content = extract_user_prompt_content(completion_create_params)
108
+ return str(user_prompt_content)
109
+
110
+ async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
111
+ """Run the LlamaIndex workflow with the provided completion parameters."""
112
+ input_message = self.make_input_message(completion_create_params)
113
+
114
+ # Load MCP tools (if configured) asynchronously before building workflow
115
+ mcp_tools = await load_mcp_tools(
116
+ authorization_context=self._authorization_context,
117
+ forwarded_headers=self.forwarded_headers,
118
+ )
119
+ self.set_mcp_tools(mcp_tools)
120
+
121
+ # Preserve prior template startup print for CLI parity
122
+ try:
123
+ print(
124
+ "Running agent with user prompt:",
125
+ extract_user_prompt_content(completion_create_params),
126
+ flush=True,
127
+ )
128
+ except Exception:
129
+ # Printing is best-effort; proceed regardless
130
+ pass
131
+
132
+ workflow = self.build_workflow()
133
+ handler = workflow.run(user_msg=input_message)
134
+
135
+ usage_metrics: UsageMetrics = default_usage_metrics()
136
+
137
+ # Streaming parity with LangGraph: yield incremental deltas during event processing
138
+ if is_streaming(completion_create_params):
139
+
140
+ async def _gen() -> AsyncGenerator[tuple[str, MultiTurnSample | None, UsageMetrics]]:
141
+ events: list[Any] = []
142
+ current_agent_name: str | None = None
143
+ async for event in handler.stream_events():
144
+ events.append(event)
145
+ # Best-effort extraction of incremental text from LlamaIndex events
146
+ delta: str | None = None
147
+ # Agent switch banner if available on event
148
+ try:
149
+ if hasattr(event, "current_agent_name"):
150
+ new_agent = getattr(event, "current_agent_name")
151
+ if (
152
+ isinstance(new_agent, str)
153
+ and new_agent
154
+ and new_agent != current_agent_name
155
+ ):
156
+ current_agent_name = new_agent
157
+ # Print banner for agent switch (do not emit as streamed content)
158
+ print("\n" + "=" * 50, flush=True)
159
+ print(f"🤖 Agent: {current_agent_name}", flush=True)
160
+ print("=" * 50 + "\n", flush=True)
161
+ except Exception:
162
+ pass
163
+
164
+ try:
165
+ if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
166
+ delta = getattr(event, "delta")
167
+ # Some event types may carry incremental text under "text" or similar
168
+ elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
169
+ delta = getattr(event, "text")
170
+ except Exception:
171
+ # Ignore malformed events and continue
172
+ delta = None
173
+
174
+ if delta:
175
+ # Yield token/content delta with current (accumulated) usage metrics
176
+ yield delta, None, usage_metrics
177
+
178
+ # Best-effort debug/event messages printed to CLI (do not stream as content)
179
+ try:
180
+ event_type = type(event).__name__
181
+ if event_type == "AgentInput" and hasattr(event, "input"):
182
+ print("📥 Input:", getattr(event, "input"), flush=True)
183
+ elif event_type == "AgentOutput":
184
+ # Output content
185
+ resp = getattr(event, "response", None)
186
+ if (
187
+ resp is not None
188
+ and hasattr(resp, "content")
189
+ and getattr(resp, "content")
190
+ ):
191
+ print("📤 Output:", getattr(resp, "content"), flush=True)
192
+ # Planned tool calls
193
+ tcalls = getattr(event, "tool_calls", None)
194
+ if isinstance(tcalls, list) and tcalls:
195
+ names = []
196
+ for c in tcalls:
197
+ try:
198
+ nm = getattr(c, "tool_name", None) or (
199
+ c.get("tool_name") if isinstance(c, dict) else None
200
+ )
201
+ if nm:
202
+ names.append(str(nm))
203
+ except Exception:
204
+ pass
205
+ if names:
206
+ print("🛠️ Planning to use tools:", names, flush=True)
207
+ elif event_type == "ToolCallResult":
208
+ tname = getattr(event, "tool_name", None)
209
+ tkwargs = getattr(event, "tool_kwargs", None)
210
+ tout = getattr(event, "tool_output", None)
211
+ print(f"🔧 Tool Result ({tname}):", flush=True)
212
+ print(f" Arguments: {tkwargs}", flush=True)
213
+ print(f" Output: {tout}", flush=True)
214
+ elif event_type == "ToolCall":
215
+ tname = getattr(event, "tool_name", None)
216
+ tkwargs = getattr(event, "tool_kwargs", None)
217
+ print(f"🔨 Calling Tool: {tname}", flush=True)
218
+ print(f" With arguments: {tkwargs}", flush=True)
219
+ except Exception:
220
+ # Ignore best-effort debug rendering errors
221
+ pass
222
+
223
+ # After streaming completes, build final interactions and finish chunk
224
+ # Extract state from workflow context (supports sync/async get or attribute)
225
+ state = None
226
+ ctx = getattr(handler, "ctx", None)
227
+ try:
228
+ if ctx is not None:
229
+ get = getattr(ctx, "get", None)
230
+ if callable(get):
231
+ result = get("state")
232
+ state = await result if inspect.isawaitable(result) else result
233
+ elif hasattr(ctx, "state"):
234
+ state = getattr(ctx, "state")
235
+ except (AttributeError, TypeError):
236
+ state = None
237
+
238
+ # Run subclass-defined response extraction (not streamed) for completeness
239
+ _ = self.extract_response_text(state, events)
240
+
241
+ pipeline_interactions = create_pipeline_interactions_from_events(events)
242
+ # Final empty chunk indicates end of stream, carrying interactions and usage
243
+ yield "", pipeline_interactions, usage_metrics
244
+
245
+ return _gen()
246
+
247
+ # Non-streaming path: run to completion, emit debug prints, then return final response
248
+ events: list[Any] = []
249
+ current_agent_name: str | None = None
250
+ async for event in handler.stream_events():
251
+ events.append(event)
252
+
253
+ # Replicate prior template CLI prints for non-streaming mode
254
+ try:
255
+ if hasattr(event, "current_agent_name"):
256
+ new_agent = getattr(event, "current_agent_name")
257
+ if isinstance(new_agent, str) and new_agent and new_agent != current_agent_name:
258
+ current_agent_name = new_agent
259
+ print(f"\n{'=' * 50}", flush=True)
260
+ print(f"🤖 Agent: {current_agent_name}", flush=True)
261
+ print(f"{'=' * 50}\n", flush=True)
262
+ except Exception:
263
+ pass
264
+
265
+ try:
266
+ if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
267
+ print(getattr(event, "delta"), end="", flush=True)
268
+ elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
269
+ print(getattr(event, "text"), end="", flush=True)
270
+ else:
271
+ event_type = type(event).__name__
272
+ if event_type == "AgentInput" and hasattr(event, "input"):
273
+ print("📥 Input:", getattr(event, "input"), flush=True)
274
+ elif event_type == "AgentOutput":
275
+ resp = getattr(event, "response", None)
276
+ if (
277
+ resp is not None
278
+ and hasattr(resp, "content")
279
+ and getattr(resp, "content")
280
+ ):
281
+ print("📤 Output:", getattr(resp, "content"), flush=True)
282
+ tcalls = getattr(event, "tool_calls", None)
283
+ if isinstance(tcalls, list) and tcalls:
284
+ names: list[str] = []
285
+ for c in tcalls:
286
+ try:
287
+ nm = getattr(c, "tool_name", None) or (
288
+ c.get("tool_name") if isinstance(c, dict) else None
289
+ )
290
+ if nm:
291
+ names.append(str(nm))
292
+ except Exception:
293
+ pass
294
+ if names:
295
+ print("🛠️ Planning to use tools:", names, flush=True)
296
+ elif event_type == "ToolCallResult":
297
+ tname = getattr(event, "tool_name", None)
298
+ tkwargs = getattr(event, "tool_kwargs", None)
299
+ tout = getattr(event, "tool_output", None)
300
+ print(f"🔧 Tool Result ({tname}):", flush=True)
301
+ print(f" Arguments: {tkwargs}", flush=True)
302
+ print(f" Output: {tout}", flush=True)
303
+ elif event_type == "ToolCall":
304
+ tname = getattr(event, "tool_name", None)
305
+ tkwargs = getattr(event, "tool_kwargs", None)
306
+ print(f"🔨 Calling Tool: {tname}", flush=True)
307
+ print(f" With arguments: {tkwargs}", flush=True)
308
+ except Exception:
309
+ # Best-effort debug printing; continue on errors
310
+ pass
311
+
312
+ # Extract state from workflow context (supports sync/async get or attribute)
313
+ state = None
314
+ ctx = getattr(handler, "ctx", None)
315
+ try:
316
+ if ctx is not None:
317
+ get = getattr(ctx, "get", None)
318
+ if callable(get):
319
+ result = get("state")
320
+ state = await result if inspect.isawaitable(result) else result
321
+ elif hasattr(ctx, "state"):
322
+ state = getattr(ctx, "state")
323
+ except (AttributeError, TypeError):
324
+ state = None
325
+ response_text = self.extract_response_text(state, events)
326
+
327
+ pipeline_interactions = create_pipeline_interactions_from_events(events)
328
+
329
+ return response_text, pipeline_interactions, usage_metrics
@@ -11,9 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ from __future__ import annotations
15
+
14
16
  import asyncio
15
17
  import logging
16
18
  from collections.abc import AsyncGenerator
19
+ from typing import TYPE_CHECKING
17
20
  from typing import Any
18
21
 
19
22
  from nat.builder.context import Context
@@ -23,10 +26,6 @@ from nat.data_models.intermediate_step import IntermediateStep
23
26
  from nat.data_models.intermediate_step import IntermediateStepType
24
27
  from nat.utils.type_utils import StrPath
25
28
  from openai.types.chat import CompletionCreateParams
26
- from ragas import MultiTurnSample
27
- from ragas.messages import AIMessage
28
- from ragas.messages import HumanMessage
29
- from ragas.messages import ToolMessage
30
29
 
31
30
  from datarobot_genai.core.agents.base import BaseAgent
32
31
  from datarobot_genai.core.agents.base import InvokeReturn
@@ -36,13 +35,22 @@ from datarobot_genai.core.agents.base import is_streaming
36
35
  from datarobot_genai.core.mcp.common import MCPConfig
37
36
  from datarobot_genai.nat.helpers import load_workflow
38
37
 
38
+ if TYPE_CHECKING:
39
+ from ragas import MultiTurnSample
40
+ from ragas.messages import AIMessage
41
+ from ragas.messages import HumanMessage
42
+
39
43
  logger = logging.getLogger(__name__)
40
44
 
41
45
 
42
46
  def convert_to_ragas_messages(
43
47
  steps: list[IntermediateStep],
44
- ) -> list[HumanMessage | AIMessage | ToolMessage]:
45
- def _to_ragas(step: IntermediateStep) -> HumanMessage | AIMessage | ToolMessage:
48
+ ) -> list[HumanMessage | AIMessage]:
49
+ # Lazy import to reduce memory overhead when ragas is not used
50
+ from ragas.messages import AIMessage
51
+ from ragas.messages import HumanMessage
52
+
53
+ def _to_ragas(step: IntermediateStep) -> HumanMessage | AIMessage:
46
54
  if step.event_type == IntermediateStepType.LLM_START:
47
55
  return HumanMessage(content=_parse(step.data.input))
48
56
  elif step.event_type == IntermediateStepType.LLM_END:
@@ -78,6 +86,9 @@ def create_pipeline_interactions_from_steps(
78
86
  ) -> MultiTurnSample | None:
79
87
  if not steps:
80
88
  return None
89
+ # Lazy import to reduce memory overhead when ragas is not used
90
+ from ragas import MultiTurnSample
91
+
81
92
  ragas_trace = convert_to_ragas_messages(steps)
82
93
  return MultiTurnSample(user_input=ragas_trace)
83
94
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datarobot-genai
3
- Version: 0.2.39
3
+ Version: 0.3.1
4
4
  Summary: Generic helpers for GenAI
5
5
  Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
6
6
  Author: DataRobot, Inc.
@@ -43,9 +43,11 @@ Requires-Dist: opentelemetry-api<2.0.0,>=1.22.0; extra == 'drmcp'
43
43
  Requires-Dist: opentelemetry-exporter-otlp-proto-http<2.0.0,>=1.22.0; extra == 'drmcp'
44
44
  Requires-Dist: opentelemetry-exporter-otlp<2.0.0,>=1.22.0; extra == 'drmcp'
45
45
  Requires-Dist: opentelemetry-sdk<2.0.0,>=1.22.0; extra == 'drmcp'
46
+ Requires-Dist: perplexityai<1.0,>=0.27; extra == 'drmcp'
46
47
  Requires-Dist: pydantic-settings<3.0.0,>=2.1.0; extra == 'drmcp'
47
48
  Requires-Dist: pydantic<3.0.0,>=2.6.1; extra == 'drmcp'
48
49
  Requires-Dist: python-dotenv<2.0.0,>=1.1.0; extra == 'drmcp'
50
+ Requires-Dist: tavily-python<1.0.0,>=0.7.20; extra == 'drmcp'
49
51
  Provides-Extra: langgraph
50
52
  Requires-Dist: langchain-mcp-adapters<0.2.0,>=0.1.12; extra == 'langgraph'
51
53
  Requires-Dist: langgraph-prebuilt<0.7.0,>=0.2.3; extra == 'langgraph'