datarobot-genai 0.2.37__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/core/agents/__init__.py +1 -1
- datarobot_genai/core/agents/base.py +5 -2
- datarobot_genai/core/chat/responses.py +6 -1
- datarobot_genai/core/utils/auth.py +188 -31
- datarobot_genai/crewai/__init__.py +1 -4
- datarobot_genai/crewai/agent.py +150 -17
- datarobot_genai/crewai/events.py +11 -4
- datarobot_genai/drmcp/__init__.py +4 -2
- datarobot_genai/drmcp/core/config.py +21 -1
- datarobot_genai/drmcp/core/mcp_instance.py +5 -49
- datarobot_genai/drmcp/core/routes.py +108 -13
- datarobot_genai/drmcp/core/tool_config.py +16 -0
- datarobot_genai/drmcp/core/utils.py +110 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +41 -26
- datarobot_genai/drmcp/tools/clients/gdrive.py +2 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +141 -0
- datarobot_genai/drmcp/tools/clients/perplexity.py +173 -0
- datarobot_genai/drmcp/tools/clients/tavily.py +199 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +43 -94
- datarobot_genai/drmcp/tools/gdrive/tools.py +44 -133
- datarobot_genai/drmcp/tools/jira/tools.py +19 -41
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +201 -32
- datarobot_genai/drmcp/tools/perplexity/__init__.py +0 -0
- datarobot_genai/drmcp/tools/perplexity/tools.py +117 -0
- datarobot_genai/drmcp/tools/predictive/data.py +1 -9
- datarobot_genai/drmcp/tools/predictive/deployment.py +0 -8
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +91 -117
- datarobot_genai/drmcp/tools/predictive/model.py +0 -21
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +3 -0
- datarobot_genai/drmcp/tools/predictive/project.py +3 -19
- datarobot_genai/drmcp/tools/predictive/training.py +1 -19
- datarobot_genai/drmcp/tools/tavily/__init__.py +13 -0
- datarobot_genai/drmcp/tools/tavily/tools.py +141 -0
- datarobot_genai/langgraph/agent.py +10 -2
- datarobot_genai/llama_index/__init__.py +1 -1
- datarobot_genai/llama_index/agent.py +284 -5
- datarobot_genai/nat/agent.py +17 -6
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/METADATA +3 -1
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/RECORD +43 -40
- datarobot_genai/crewai/base.py +0 -159
- datarobot_genai/drmcp/core/tool_filter.py +0 -117
- datarobot_genai/llama_index/base.py +0 -299
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.2.37.dist-info → datarobot_genai-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# Copyright 2025 DataRobot, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tavily MCP tools for web search."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Annotated
|
|
19
|
+
from typing import Literal
|
|
20
|
+
|
|
21
|
+
from fastmcp.tools.tool import ToolResult
|
|
22
|
+
|
|
23
|
+
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
|
|
24
|
+
from datarobot_genai.drmcp.tools.clients.tavily import CHUNKS_PER_SOURCE_DEFAULT
|
|
25
|
+
from datarobot_genai.drmcp.tools.clients.tavily import MAX_CHUNKS_PER_SOURCE
|
|
26
|
+
from datarobot_genai.drmcp.tools.clients.tavily import MAX_RESULTS
|
|
27
|
+
from datarobot_genai.drmcp.tools.clients.tavily import MAX_RESULTS_DEFAULT
|
|
28
|
+
from datarobot_genai.drmcp.tools.clients.tavily import TavilyClient
|
|
29
|
+
from datarobot_genai.drmcp.tools.clients.tavily import TavilyImage
|
|
30
|
+
from datarobot_genai.drmcp.tools.clients.tavily import TavilySearchResult
|
|
31
|
+
from datarobot_genai.drmcp.tools.clients.tavily import get_tavily_access_token
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dr_mcp_tool(tags={"tavily", "search", "web", "websearch"})
|
|
37
|
+
async def tavily_search(
|
|
38
|
+
*,
|
|
39
|
+
query: Annotated[str, "The search query to execute."],
|
|
40
|
+
topic: Annotated[
|
|
41
|
+
Literal["general", "news", "finance"],
|
|
42
|
+
"The category of search. Use 'general' for broad web search, "
|
|
43
|
+
"'news' for recent news articles, or 'finance' for financial information.",
|
|
44
|
+
] = "general",
|
|
45
|
+
search_depth: Annotated[
|
|
46
|
+
Literal["basic", "advanced"],
|
|
47
|
+
"The depth of search. 'basic' is faster and cheaper, "
|
|
48
|
+
"'advanced' provides more comprehensive results.",
|
|
49
|
+
] = "basic",
|
|
50
|
+
max_results: Annotated[
|
|
51
|
+
int,
|
|
52
|
+
f"Maximum number of search results to return (1-{MAX_RESULTS}).",
|
|
53
|
+
] = MAX_RESULTS_DEFAULT,
|
|
54
|
+
time_range: Annotated[
|
|
55
|
+
Literal["day", "week", "month", "year"] | None,
|
|
56
|
+
"Filter results by time range. Use 'day' for last 24 hours, "
|
|
57
|
+
"'week' for last 7 days, 'month' for last 30 days, or 'year' for last year.",
|
|
58
|
+
] = None,
|
|
59
|
+
include_images: Annotated[
|
|
60
|
+
bool,
|
|
61
|
+
"Whether to include related images in the search results.",
|
|
62
|
+
] = False,
|
|
63
|
+
include_image_descriptions: Annotated[
|
|
64
|
+
bool,
|
|
65
|
+
"Whether to include AI-generated descriptions for images. "
|
|
66
|
+
"Only applicable when include_images is True.",
|
|
67
|
+
] = False,
|
|
68
|
+
chunks_per_source: Annotated[
|
|
69
|
+
int,
|
|
70
|
+
f"Maximum number of content snippets to return per source URL (1-{MAX_CHUNKS_PER_SOURCE}).",
|
|
71
|
+
] = CHUNKS_PER_SOURCE_DEFAULT,
|
|
72
|
+
include_answer: Annotated[
|
|
73
|
+
bool,
|
|
74
|
+
"Whether to include an AI-generated answer summarizing the search results.",
|
|
75
|
+
] = False,
|
|
76
|
+
) -> ToolResult:
|
|
77
|
+
"""
|
|
78
|
+
Perform a real-time web search using Tavily API.
|
|
79
|
+
|
|
80
|
+
Tavily is optimized for AI agents and provides clean, relevant search results
|
|
81
|
+
suitable for LLM consumption. Use this tool to search the web for current
|
|
82
|
+
information, news, financial data, or general knowledge.
|
|
83
|
+
|
|
84
|
+
Usage:
|
|
85
|
+
- Basic search: tavily_search(query="Python best practices 2026")
|
|
86
|
+
- News search: tavily_search(query="AI regulations", topic="news", time_range="week")
|
|
87
|
+
- Financial search: tavily_search(query="AAPL stock analysis", topic="finance")
|
|
88
|
+
- Comprehensive search: tavily_search(
|
|
89
|
+
query="climate change solutions",
|
|
90
|
+
search_depth="advanced",
|
|
91
|
+
max_results=10,
|
|
92
|
+
include_answer=True
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
Note:
|
|
96
|
+
- Advanced search depth consumes more API credits but provides better results
|
|
97
|
+
- Time range filtering is useful for finding recent information
|
|
98
|
+
"""
|
|
99
|
+
api_key = await get_tavily_access_token()
|
|
100
|
+
|
|
101
|
+
async with TavilyClient(api_key) as client:
|
|
102
|
+
response = await client.search(
|
|
103
|
+
query=query,
|
|
104
|
+
topic=topic,
|
|
105
|
+
search_depth=search_depth,
|
|
106
|
+
max_results=max_results,
|
|
107
|
+
time_range=time_range,
|
|
108
|
+
include_images=include_images,
|
|
109
|
+
include_image_descriptions=include_image_descriptions,
|
|
110
|
+
chunks_per_source=chunks_per_source,
|
|
111
|
+
include_answer=include_answer,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
results = [TavilySearchResult.from_tavily_sdk(r) for r in response.get("results", [])]
|
|
115
|
+
|
|
116
|
+
images: list[TavilyImage] | None = None
|
|
117
|
+
if include_images and response.get("images"):
|
|
118
|
+
images = [TavilyImage.from_tavily_sdk(img) for img in response.get("images", [])]
|
|
119
|
+
|
|
120
|
+
result_count = len(results)
|
|
121
|
+
answer = response.get("answer")
|
|
122
|
+
response_time = response.get("response_time", 0.0)
|
|
123
|
+
|
|
124
|
+
structured_content: dict = {
|
|
125
|
+
"query": response.get("query", query),
|
|
126
|
+
"results": [r.as_flat_dict() for r in results],
|
|
127
|
+
"resultCount": result_count,
|
|
128
|
+
"responseTime": response_time,
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
if answer:
|
|
132
|
+
structured_content["answer"] = answer
|
|
133
|
+
|
|
134
|
+
if images:
|
|
135
|
+
structured_content["images"] = [
|
|
136
|
+
{"url": img.url, "description": img.description} for img in images
|
|
137
|
+
]
|
|
138
|
+
|
|
139
|
+
return ToolResult(
|
|
140
|
+
structured_content=structured_content,
|
|
141
|
+
)
|
|
@@ -11,9 +11,12 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
14
16
|
import abc
|
|
15
17
|
import logging
|
|
16
18
|
from collections.abc import AsyncGenerator
|
|
19
|
+
from typing import TYPE_CHECKING
|
|
17
20
|
from typing import Any
|
|
18
21
|
from typing import cast
|
|
19
22
|
|
|
@@ -34,8 +37,6 @@ from langgraph.graph import MessagesState
|
|
|
34
37
|
from langgraph.graph import StateGraph
|
|
35
38
|
from langgraph.types import Command
|
|
36
39
|
from openai.types.chat import CompletionCreateParams
|
|
37
|
-
from ragas import MultiTurnSample
|
|
38
|
-
from ragas.integrations.langgraph import convert_to_ragas_messages
|
|
39
40
|
|
|
40
41
|
from datarobot_genai.core.agents.base import BaseAgent
|
|
41
42
|
from datarobot_genai.core.agents.base import InvokeReturn
|
|
@@ -44,6 +45,9 @@ from datarobot_genai.core.agents.base import extract_user_prompt_content
|
|
|
44
45
|
from datarobot_genai.core.agents.base import is_streaming
|
|
45
46
|
from datarobot_genai.langgraph.mcp import mcp_tools_context
|
|
46
47
|
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from ragas import MultiTurnSample
|
|
50
|
+
|
|
47
51
|
logger = logging.getLogger(__name__)
|
|
48
52
|
|
|
49
53
|
|
|
@@ -337,5 +341,9 @@ class LangGraphAgent(BaseAgent[BaseTool], abc.ABC):
|
|
|
337
341
|
if v is not None:
|
|
338
342
|
messages.extend(v.get("messages", []))
|
|
339
343
|
messages = [m for m in messages if not isinstance(m, ToolMessage)]
|
|
344
|
+
# Lazy import to reduce memory overhead when ragas is not used
|
|
345
|
+
from ragas import MultiTurnSample
|
|
346
|
+
from ragas.integrations.langgraph import convert_to_ragas_messages
|
|
347
|
+
|
|
340
348
|
ragas_trace = convert_to_ragas_messages(messages)
|
|
341
349
|
return MultiTurnSample(user_input=ragas_trace)
|
|
@@ -3,8 +3,8 @@
|
|
|
3
3
|
from datarobot_genai.core.mcp.common import MCPConfig
|
|
4
4
|
|
|
5
5
|
from .agent import DataRobotLiteLLM
|
|
6
|
+
from .agent import LlamaIndexAgent
|
|
6
7
|
from .agent import create_pipeline_interactions_from_events
|
|
7
|
-
from .base import LlamaIndexAgent
|
|
8
8
|
from .mcp import load_mcp_tools
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
@@ -11,17 +11,32 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
from __future__ import annotations
|
|
14
15
|
|
|
16
|
+
import abc
|
|
17
|
+
import inspect
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from typing import TYPE_CHECKING
|
|
20
|
+
from typing import Any
|
|
15
21
|
from typing import cast
|
|
16
22
|
|
|
17
23
|
from llama_index.core.base.llms.types import LLMMetadata
|
|
24
|
+
from llama_index.core.tools import BaseTool
|
|
18
25
|
from llama_index.core.workflow import Event
|
|
19
26
|
from llama_index.llms.litellm import LiteLLM
|
|
20
|
-
from
|
|
21
|
-
|
|
22
|
-
from
|
|
23
|
-
from
|
|
24
|
-
from
|
|
27
|
+
from openai.types.chat import CompletionCreateParams
|
|
28
|
+
|
|
29
|
+
from datarobot_genai.core.agents.base import BaseAgent
|
|
30
|
+
from datarobot_genai.core.agents.base import InvokeReturn
|
|
31
|
+
from datarobot_genai.core.agents.base import UsageMetrics
|
|
32
|
+
from datarobot_genai.core.agents.base import default_usage_metrics
|
|
33
|
+
from datarobot_genai.core.agents.base import extract_user_prompt_content
|
|
34
|
+
from datarobot_genai.core.agents.base import is_streaming
|
|
35
|
+
|
|
36
|
+
from .mcp import load_mcp_tools
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from ragas import MultiTurnSample
|
|
25
40
|
|
|
26
41
|
|
|
27
42
|
class DataRobotLiteLLM(LiteLLM):
|
|
@@ -44,7 +59,271 @@ def create_pipeline_interactions_from_events(
|
|
|
44
59
|
) -> MultiTurnSample | None:
|
|
45
60
|
if not events:
|
|
46
61
|
return None
|
|
62
|
+
# Lazy import to reduce memory overhead when ragas is not used
|
|
63
|
+
from ragas import MultiTurnSample
|
|
64
|
+
from ragas.integrations.llama_index import convert_to_ragas_messages
|
|
65
|
+
from ragas.messages import AIMessage
|
|
66
|
+
from ragas.messages import HumanMessage
|
|
67
|
+
from ragas.messages import ToolMessage
|
|
68
|
+
|
|
47
69
|
# convert_to_ragas_messages expects a list[Event]
|
|
48
70
|
ragas_trace = convert_to_ragas_messages(list(events))
|
|
49
71
|
ragas_messages = cast(list[HumanMessage | AIMessage | ToolMessage], ragas_trace)
|
|
50
72
|
return MultiTurnSample(user_input=ragas_messages)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class LlamaIndexAgent(BaseAgent[BaseTool], abc.ABC):
|
|
76
|
+
"""Abstract base agent for LlamaIndex workflows."""
|
|
77
|
+
|
|
78
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
79
|
+
super().__init__(*args, **kwargs)
|
|
80
|
+
self._mcp_tools: list[Any] = []
|
|
81
|
+
|
|
82
|
+
def set_mcp_tools(self, tools: list[Any]) -> None:
|
|
83
|
+
"""Set MCP tools for this agent."""
|
|
84
|
+
self._mcp_tools = tools
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def mcp_tools(self) -> list[Any]:
|
|
88
|
+
"""Return the list of MCP tools available to this agent.
|
|
89
|
+
|
|
90
|
+
Subclasses can use this to wire tools into LlamaIndex agents during
|
|
91
|
+
workflow construction inside ``build_workflow``.
|
|
92
|
+
"""
|
|
93
|
+
return self._mcp_tools
|
|
94
|
+
|
|
95
|
+
@abc.abstractmethod
|
|
96
|
+
def build_workflow(self) -> Any:
|
|
97
|
+
"""Return an AgentWorkflow instance ready to run."""
|
|
98
|
+
raise NotImplementedError
|
|
99
|
+
|
|
100
|
+
@abc.abstractmethod
|
|
101
|
+
def extract_response_text(self, result_state: Any, events: list[Any]) -> str:
|
|
102
|
+
"""Extract final response text from workflow state and/or events."""
|
|
103
|
+
raise NotImplementedError
|
|
104
|
+
|
|
105
|
+
def make_input_message(self, completion_create_params: CompletionCreateParams) -> str:
|
|
106
|
+
"""Create an input string for the workflow from the user prompt."""
|
|
107
|
+
user_prompt_content = extract_user_prompt_content(completion_create_params)
|
|
108
|
+
return str(user_prompt_content)
|
|
109
|
+
|
|
110
|
+
async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
|
|
111
|
+
"""Run the LlamaIndex workflow with the provided completion parameters."""
|
|
112
|
+
input_message = self.make_input_message(completion_create_params)
|
|
113
|
+
|
|
114
|
+
# Load MCP tools (if configured) asynchronously before building workflow
|
|
115
|
+
mcp_tools = await load_mcp_tools(
|
|
116
|
+
authorization_context=self._authorization_context,
|
|
117
|
+
forwarded_headers=self.forwarded_headers,
|
|
118
|
+
)
|
|
119
|
+
self.set_mcp_tools(mcp_tools)
|
|
120
|
+
|
|
121
|
+
# Preserve prior template startup print for CLI parity
|
|
122
|
+
try:
|
|
123
|
+
print(
|
|
124
|
+
"Running agent with user prompt:",
|
|
125
|
+
extract_user_prompt_content(completion_create_params),
|
|
126
|
+
flush=True,
|
|
127
|
+
)
|
|
128
|
+
except Exception:
|
|
129
|
+
# Printing is best-effort; proceed regardless
|
|
130
|
+
pass
|
|
131
|
+
|
|
132
|
+
workflow = self.build_workflow()
|
|
133
|
+
handler = workflow.run(user_msg=input_message)
|
|
134
|
+
|
|
135
|
+
usage_metrics: UsageMetrics = default_usage_metrics()
|
|
136
|
+
|
|
137
|
+
# Streaming parity with LangGraph: yield incremental deltas during event processing
|
|
138
|
+
if is_streaming(completion_create_params):
|
|
139
|
+
|
|
140
|
+
async def _gen() -> AsyncGenerator[tuple[str, MultiTurnSample | None, UsageMetrics]]:
|
|
141
|
+
events: list[Any] = []
|
|
142
|
+
current_agent_name: str | None = None
|
|
143
|
+
async for event in handler.stream_events():
|
|
144
|
+
events.append(event)
|
|
145
|
+
# Best-effort extraction of incremental text from LlamaIndex events
|
|
146
|
+
delta: str | None = None
|
|
147
|
+
# Agent switch banner if available on event
|
|
148
|
+
try:
|
|
149
|
+
if hasattr(event, "current_agent_name"):
|
|
150
|
+
new_agent = getattr(event, "current_agent_name")
|
|
151
|
+
if (
|
|
152
|
+
isinstance(new_agent, str)
|
|
153
|
+
and new_agent
|
|
154
|
+
and new_agent != current_agent_name
|
|
155
|
+
):
|
|
156
|
+
current_agent_name = new_agent
|
|
157
|
+
# Print banner for agent switch (do not emit as streamed content)
|
|
158
|
+
print("\n" + "=" * 50, flush=True)
|
|
159
|
+
print(f"🤖 Agent: {current_agent_name}", flush=True)
|
|
160
|
+
print("=" * 50 + "\n", flush=True)
|
|
161
|
+
except Exception:
|
|
162
|
+
pass
|
|
163
|
+
|
|
164
|
+
try:
|
|
165
|
+
if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
|
|
166
|
+
delta = getattr(event, "delta")
|
|
167
|
+
# Some event types may carry incremental text under "text" or similar
|
|
168
|
+
elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
|
|
169
|
+
delta = getattr(event, "text")
|
|
170
|
+
except Exception:
|
|
171
|
+
# Ignore malformed events and continue
|
|
172
|
+
delta = None
|
|
173
|
+
|
|
174
|
+
if delta:
|
|
175
|
+
# Yield token/content delta with current (accumulated) usage metrics
|
|
176
|
+
yield delta, None, usage_metrics
|
|
177
|
+
|
|
178
|
+
# Best-effort debug/event messages printed to CLI (do not stream as content)
|
|
179
|
+
try:
|
|
180
|
+
event_type = type(event).__name__
|
|
181
|
+
if event_type == "AgentInput" and hasattr(event, "input"):
|
|
182
|
+
print("📥 Input:", getattr(event, "input"), flush=True)
|
|
183
|
+
elif event_type == "AgentOutput":
|
|
184
|
+
# Output content
|
|
185
|
+
resp = getattr(event, "response", None)
|
|
186
|
+
if (
|
|
187
|
+
resp is not None
|
|
188
|
+
and hasattr(resp, "content")
|
|
189
|
+
and getattr(resp, "content")
|
|
190
|
+
):
|
|
191
|
+
print("📤 Output:", getattr(resp, "content"), flush=True)
|
|
192
|
+
# Planned tool calls
|
|
193
|
+
tcalls = getattr(event, "tool_calls", None)
|
|
194
|
+
if isinstance(tcalls, list) and tcalls:
|
|
195
|
+
names = []
|
|
196
|
+
for c in tcalls:
|
|
197
|
+
try:
|
|
198
|
+
nm = getattr(c, "tool_name", None) or (
|
|
199
|
+
c.get("tool_name") if isinstance(c, dict) else None
|
|
200
|
+
)
|
|
201
|
+
if nm:
|
|
202
|
+
names.append(str(nm))
|
|
203
|
+
except Exception:
|
|
204
|
+
pass
|
|
205
|
+
if names:
|
|
206
|
+
print("🛠️ Planning to use tools:", names, flush=True)
|
|
207
|
+
elif event_type == "ToolCallResult":
|
|
208
|
+
tname = getattr(event, "tool_name", None)
|
|
209
|
+
tkwargs = getattr(event, "tool_kwargs", None)
|
|
210
|
+
tout = getattr(event, "tool_output", None)
|
|
211
|
+
print(f"🔧 Tool Result ({tname}):", flush=True)
|
|
212
|
+
print(f" Arguments: {tkwargs}", flush=True)
|
|
213
|
+
print(f" Output: {tout}", flush=True)
|
|
214
|
+
elif event_type == "ToolCall":
|
|
215
|
+
tname = getattr(event, "tool_name", None)
|
|
216
|
+
tkwargs = getattr(event, "tool_kwargs", None)
|
|
217
|
+
print(f"🔨 Calling Tool: {tname}", flush=True)
|
|
218
|
+
print(f" With arguments: {tkwargs}", flush=True)
|
|
219
|
+
except Exception:
|
|
220
|
+
# Ignore best-effort debug rendering errors
|
|
221
|
+
pass
|
|
222
|
+
|
|
223
|
+
# After streaming completes, build final interactions and finish chunk
|
|
224
|
+
# Extract state from workflow context (supports sync/async get or attribute)
|
|
225
|
+
state = None
|
|
226
|
+
ctx = getattr(handler, "ctx", None)
|
|
227
|
+
try:
|
|
228
|
+
if ctx is not None:
|
|
229
|
+
get = getattr(ctx, "get", None)
|
|
230
|
+
if callable(get):
|
|
231
|
+
result = get("state")
|
|
232
|
+
state = await result if inspect.isawaitable(result) else result
|
|
233
|
+
elif hasattr(ctx, "state"):
|
|
234
|
+
state = getattr(ctx, "state")
|
|
235
|
+
except (AttributeError, TypeError):
|
|
236
|
+
state = None
|
|
237
|
+
|
|
238
|
+
# Run subclass-defined response extraction (not streamed) for completeness
|
|
239
|
+
_ = self.extract_response_text(state, events)
|
|
240
|
+
|
|
241
|
+
pipeline_interactions = create_pipeline_interactions_from_events(events)
|
|
242
|
+
# Final empty chunk indicates end of stream, carrying interactions and usage
|
|
243
|
+
yield "", pipeline_interactions, usage_metrics
|
|
244
|
+
|
|
245
|
+
return _gen()
|
|
246
|
+
|
|
247
|
+
# Non-streaming path: run to completion, emit debug prints, then return final response
|
|
248
|
+
events: list[Any] = []
|
|
249
|
+
current_agent_name: str | None = None
|
|
250
|
+
async for event in handler.stream_events():
|
|
251
|
+
events.append(event)
|
|
252
|
+
|
|
253
|
+
# Replicate prior template CLI prints for non-streaming mode
|
|
254
|
+
try:
|
|
255
|
+
if hasattr(event, "current_agent_name"):
|
|
256
|
+
new_agent = getattr(event, "current_agent_name")
|
|
257
|
+
if isinstance(new_agent, str) and new_agent and new_agent != current_agent_name:
|
|
258
|
+
current_agent_name = new_agent
|
|
259
|
+
print(f"\n{'=' * 50}", flush=True)
|
|
260
|
+
print(f"🤖 Agent: {current_agent_name}", flush=True)
|
|
261
|
+
print(f"{'=' * 50}\n", flush=True)
|
|
262
|
+
except Exception:
|
|
263
|
+
pass
|
|
264
|
+
|
|
265
|
+
try:
|
|
266
|
+
if hasattr(event, "delta") and isinstance(getattr(event, "delta"), str):
|
|
267
|
+
print(getattr(event, "delta"), end="", flush=True)
|
|
268
|
+
elif hasattr(event, "text") and isinstance(getattr(event, "text"), str):
|
|
269
|
+
print(getattr(event, "text"), end="", flush=True)
|
|
270
|
+
else:
|
|
271
|
+
event_type = type(event).__name__
|
|
272
|
+
if event_type == "AgentInput" and hasattr(event, "input"):
|
|
273
|
+
print("📥 Input:", getattr(event, "input"), flush=True)
|
|
274
|
+
elif event_type == "AgentOutput":
|
|
275
|
+
resp = getattr(event, "response", None)
|
|
276
|
+
if (
|
|
277
|
+
resp is not None
|
|
278
|
+
and hasattr(resp, "content")
|
|
279
|
+
and getattr(resp, "content")
|
|
280
|
+
):
|
|
281
|
+
print("📤 Output:", getattr(resp, "content"), flush=True)
|
|
282
|
+
tcalls = getattr(event, "tool_calls", None)
|
|
283
|
+
if isinstance(tcalls, list) and tcalls:
|
|
284
|
+
names: list[str] = []
|
|
285
|
+
for c in tcalls:
|
|
286
|
+
try:
|
|
287
|
+
nm = getattr(c, "tool_name", None) or (
|
|
288
|
+
c.get("tool_name") if isinstance(c, dict) else None
|
|
289
|
+
)
|
|
290
|
+
if nm:
|
|
291
|
+
names.append(str(nm))
|
|
292
|
+
except Exception:
|
|
293
|
+
pass
|
|
294
|
+
if names:
|
|
295
|
+
print("🛠️ Planning to use tools:", names, flush=True)
|
|
296
|
+
elif event_type == "ToolCallResult":
|
|
297
|
+
tname = getattr(event, "tool_name", None)
|
|
298
|
+
tkwargs = getattr(event, "tool_kwargs", None)
|
|
299
|
+
tout = getattr(event, "tool_output", None)
|
|
300
|
+
print(f"🔧 Tool Result ({tname}):", flush=True)
|
|
301
|
+
print(f" Arguments: {tkwargs}", flush=True)
|
|
302
|
+
print(f" Output: {tout}", flush=True)
|
|
303
|
+
elif event_type == "ToolCall":
|
|
304
|
+
tname = getattr(event, "tool_name", None)
|
|
305
|
+
tkwargs = getattr(event, "tool_kwargs", None)
|
|
306
|
+
print(f"🔨 Calling Tool: {tname}", flush=True)
|
|
307
|
+
print(f" With arguments: {tkwargs}", flush=True)
|
|
308
|
+
except Exception:
|
|
309
|
+
# Best-effort debug printing; continue on errors
|
|
310
|
+
pass
|
|
311
|
+
|
|
312
|
+
# Extract state from workflow context (supports sync/async get or attribute)
|
|
313
|
+
state = None
|
|
314
|
+
ctx = getattr(handler, "ctx", None)
|
|
315
|
+
try:
|
|
316
|
+
if ctx is not None:
|
|
317
|
+
get = getattr(ctx, "get", None)
|
|
318
|
+
if callable(get):
|
|
319
|
+
result = get("state")
|
|
320
|
+
state = await result if inspect.isawaitable(result) else result
|
|
321
|
+
elif hasattr(ctx, "state"):
|
|
322
|
+
state = getattr(ctx, "state")
|
|
323
|
+
except (AttributeError, TypeError):
|
|
324
|
+
state = None
|
|
325
|
+
response_text = self.extract_response_text(state, events)
|
|
326
|
+
|
|
327
|
+
pipeline_interactions = create_pipeline_interactions_from_events(events)
|
|
328
|
+
|
|
329
|
+
return response_text, pipeline_interactions, usage_metrics
|
datarobot_genai/nat/agent.py
CHANGED
|
@@ -11,9 +11,12 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
14
16
|
import asyncio
|
|
15
17
|
import logging
|
|
16
18
|
from collections.abc import AsyncGenerator
|
|
19
|
+
from typing import TYPE_CHECKING
|
|
17
20
|
from typing import Any
|
|
18
21
|
|
|
19
22
|
from nat.builder.context import Context
|
|
@@ -23,10 +26,6 @@ from nat.data_models.intermediate_step import IntermediateStep
|
|
|
23
26
|
from nat.data_models.intermediate_step import IntermediateStepType
|
|
24
27
|
from nat.utils.type_utils import StrPath
|
|
25
28
|
from openai.types.chat import CompletionCreateParams
|
|
26
|
-
from ragas import MultiTurnSample
|
|
27
|
-
from ragas.messages import AIMessage
|
|
28
|
-
from ragas.messages import HumanMessage
|
|
29
|
-
from ragas.messages import ToolMessage
|
|
30
29
|
|
|
31
30
|
from datarobot_genai.core.agents.base import BaseAgent
|
|
32
31
|
from datarobot_genai.core.agents.base import InvokeReturn
|
|
@@ -36,13 +35,22 @@ from datarobot_genai.core.agents.base import is_streaming
|
|
|
36
35
|
from datarobot_genai.core.mcp.common import MCPConfig
|
|
37
36
|
from datarobot_genai.nat.helpers import load_workflow
|
|
38
37
|
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from ragas import MultiTurnSample
|
|
40
|
+
from ragas.messages import AIMessage
|
|
41
|
+
from ragas.messages import HumanMessage
|
|
42
|
+
|
|
39
43
|
logger = logging.getLogger(__name__)
|
|
40
44
|
|
|
41
45
|
|
|
42
46
|
def convert_to_ragas_messages(
|
|
43
47
|
steps: list[IntermediateStep],
|
|
44
|
-
) -> list[HumanMessage | AIMessage
|
|
45
|
-
|
|
48
|
+
) -> list[HumanMessage | AIMessage]:
|
|
49
|
+
# Lazy import to reduce memory overhead when ragas is not used
|
|
50
|
+
from ragas.messages import AIMessage
|
|
51
|
+
from ragas.messages import HumanMessage
|
|
52
|
+
|
|
53
|
+
def _to_ragas(step: IntermediateStep) -> HumanMessage | AIMessage:
|
|
46
54
|
if step.event_type == IntermediateStepType.LLM_START:
|
|
47
55
|
return HumanMessage(content=_parse(step.data.input))
|
|
48
56
|
elif step.event_type == IntermediateStepType.LLM_END:
|
|
@@ -78,6 +86,9 @@ def create_pipeline_interactions_from_steps(
|
|
|
78
86
|
) -> MultiTurnSample | None:
|
|
79
87
|
if not steps:
|
|
80
88
|
return None
|
|
89
|
+
# Lazy import to reduce memory overhead when ragas is not used
|
|
90
|
+
from ragas import MultiTurnSample
|
|
91
|
+
|
|
81
92
|
ragas_trace = convert_to_ragas_messages(steps)
|
|
82
93
|
return MultiTurnSample(user_input=ragas_trace)
|
|
83
94
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: datarobot-genai
|
|
3
|
-
Version: 0.2.37
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: Generic helpers for GenAI
|
|
5
5
|
Project-URL: Homepage, https://github.com/datarobot-oss/datarobot-genai
|
|
6
6
|
Author: DataRobot, Inc.
|
|
@@ -43,9 +43,11 @@ Requires-Dist: opentelemetry-api<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
|
43
43
|
Requires-Dist: opentelemetry-exporter-otlp-proto-http<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
44
44
|
Requires-Dist: opentelemetry-exporter-otlp<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
45
45
|
Requires-Dist: opentelemetry-sdk<2.0.0,>=1.22.0; extra == 'drmcp'
|
|
46
|
+
Requires-Dist: perplexityai<1.0,>=0.27; extra == 'drmcp'
|
|
46
47
|
Requires-Dist: pydantic-settings<3.0.0,>=2.1.0; extra == 'drmcp'
|
|
47
48
|
Requires-Dist: pydantic<3.0.0,>=2.6.1; extra == 'drmcp'
|
|
48
49
|
Requires-Dist: python-dotenv<2.0.0,>=1.1.0; extra == 'drmcp'
|
|
50
|
+
Requires-Dist: tavily-python<1.0.0,>=0.7.20; extra == 'drmcp'
|
|
49
51
|
Provides-Extra: langgraph
|
|
50
52
|
Requires-Dist: langchain-mcp-adapters<0.2.0,>=0.1.12; extra == 'langgraph'
|
|
51
53
|
Requires-Dist: langgraph-prebuilt<0.7.0,>=0.2.3; extra == 'langgraph'
|