datarobot-genai 0.1.68__py3-none-any.whl → 0.1.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/core/mcp/common.py +82 -48
- datarobot_genai/crewai/base.py +33 -53
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +1 -1
- datarobot_genai/langgraph/agent.py +32 -15
- datarobot_genai/nat/datarobot_llm_clients.py +66 -7
- datarobot_genai/nat/datarobot_llm_providers.py +32 -0
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/METADATA +1 -1
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/RECORD +12 -12
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/licenses/LICENSE +0 -0
datarobot_genai/core/mcp/common.py
CHANGED

@@ -13,16 +13,20 @@
 # limitations under the License.

 import json
+import logging
 import re
+from http import HTTPStatus
 from typing import Any
 from typing import Literal
-from urllib.parse import urlparse

+import requests
 from datarobot.core.config import DataRobotAppFrameworkBaseSettings
 from pydantic import field_validator

 from datarobot_genai.core.utils.auth import AuthContextHeaderHandler

+logger = logging.getLogger(__name__)
+

 class MCPConfig(DataRobotAppFrameworkBaseSettings):
     """Configuration for MCP server connection.

@@ -39,6 +43,7 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
     datarobot_api_token: str | None = None
     authorization_context: dict[str, Any] | None = None
     forwarded_headers: dict[str, str] | None = None
+    mcp_server_port: int | None = None

     _auth_context_handler: AuthContextHeaderHandler | None = None
     _server_config: dict[str, Any] | None = None

@@ -49,17 +54,14 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
         if value is None:
             return None

-        if not isinstance(value, str):
-            msg = "external_mcp_headers must be a JSON string"
-            raise TypeError(msg)
-
         candidate = value.strip()

         try:
             json.loads(candidate)
-        except json.JSONDecodeError
+        except json.JSONDecodeError:
             msg = "external_mcp_headers must be valid JSON"
-
+            logger.warning(msg)
+            return None

         return candidate

@@ -69,15 +71,12 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
         if value is None:
             return None

-        if not isinstance(value, str):
-            msg = "mcp_deployment_id must be a string"
-            raise TypeError(msg)
-
         candidate = value.strip()

         if not re.fullmatch(r"[0-9a-fA-F]{24}", candidate):
             msg = "mcp_deployment_id must be a valid 24-character hex ID"
-
+            logger.warning(msg)
+            return None

         return candidate
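Both field validators above now degrade gracefully: instead of raising on bad input they log a warning and coerce the value to None, so a malformed setting no longer blocks construction of MCPConfig. A minimal sketch of that pattern (class and field names here are illustrative, not the package's code):

```python
# Illustrative sketch of the warn-and-drop validator behaviour, assuming pydantic v2.
import json
import logging

from pydantic import BaseModel, field_validator

logger = logging.getLogger(__name__)


class ExampleSettings(BaseModel):
    external_mcp_headers: str | None = None

    @field_validator("external_mcp_headers", mode="before")
    @classmethod
    def validate_headers(cls, value: str | None) -> str | None:
        if value is None:
            return None
        candidate = value.strip()
        try:
            json.loads(candidate)
        except json.JSONDecodeError:
            # Warn and drop the value instead of raising, so construction still succeeds.
            logger.warning("external_mcp_headers must be valid JSON")
            return None
        return candidate


print(ExampleSettings(external_mcp_headers="not json").external_mcp_headers)   # None
print(ExampleSettings(external_mcp_headers='{"x": "1"}').external_mcp_headers)  # '{"x": "1"}'
```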
@@ -112,6 +111,45 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
             # Authorization context not available (e.g., in tests)
             return {}

+    def _build_authenticated_headers(self) -> dict[str, str]:
+        """Build headers for authenticated requests.
+
+        Returns
+        -------
+        Dictionary containing forwarded headers (if available) and authentication headers.
+        """
+        headers: dict[str, str] = {}
+        if self.forwarded_headers:
+            headers.update(self.forwarded_headers)
+        headers.update(self._authorization_bearer_header())
+        headers.update(self._authorization_context_header())
+        return headers
+
+    def _check_localhost_server(self, url: str, timeout: float = 2.0) -> bool:
+        """Check if MCP server is running on localhost.
+
+        Parameters
+        ----------
+        url : str
+            The URL to check.
+        timeout : float, optional
+            Request timeout in seconds (default: 2.0).
+
+        Returns
+        -------
+        bool
+            True if server is running and responding with OK status, False otherwise.
+        """
+        try:
+            response = requests.get(url, timeout=timeout)
+            return (
+                response.status_code == HTTPStatus.OK
+                and response.json().get("message") == "DataRobot MCP Server is running"
+            )
+        except requests.RequestException as e:
+            logger.debug(f"Failed to connect to MCP server at {url}: {e}")
+            return False
+
     def _build_server_config(self) -> dict[str, Any] | None:
         """
         Get MCP server configuration.

@@ -121,34 +159,6 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
         Server configuration dict with url, transport, and optional headers,
         or None if not configured.
         """
-        if self.external_mcp_url:
-            # External MCP URL - no authentication needed
-            headers: dict[str, str] = {}
-
-            # Forward headers for localhost connections
-            if self.forwarded_headers:
-                try:
-                    parsed_url = urlparse(self.external_mcp_url)
-                    hostname = parsed_url.hostname or ""
-                    # Check if hostname is localhost or 127.0.0.1
-                    if hostname in ("localhost", "127.0.0.1", "::1"):
-                        headers.update(self.forwarded_headers)
-                except Exception:
-                    # If URL parsing fails, fall back to simple string check
-                    if "localhost" in self.external_mcp_url or "127.0.0.1" in self.external_mcp_url:
-                        headers.update(self.forwarded_headers)
-
-            # Merge external headers if provided
-            if self.external_mcp_headers:
-                external_headers = json.loads(self.external_mcp_headers)
-                headers.update(external_headers)
-
-            return {
-                "url": self.external_mcp_url.rstrip("/"),
-                "transport": self.external_mcp_transport,
-                "headers": headers,
-            }
-
         if self.mcp_deployment_id:
             # DataRobot deployment ID - requires authentication
             if self.datarobot_endpoint is None:

@@ -165,15 +175,9 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
                 base_url = f"{base_url}/api/v2"

             url = f"{base_url}/deployments/{self.mcp_deployment_id}/directAccess/mcp"
+            headers = self._build_authenticated_headers()

-
-            headers = {}
-            if self.forwarded_headers:
-                headers.update(self.forwarded_headers)
-
-            # Add authentication headers
-            headers.update(self._authorization_bearer_header())
-            headers.update(self._authorization_context_header())
+            logger.info(f"Using DataRobot hosted MCP deployment: {url}")

             return {
                 "url": url,

@@ -181,4 +185,34 @@ class MCPConfig(DataRobotAppFrameworkBaseSettings):
                 "headers": headers,
             }

+        if self.external_mcp_url:
+            # External MCP URL - no authentication needed
+            headers = {}
+
+            # Merge external headers if provided
+            if self.external_mcp_headers:
+                external_headers = json.loads(self.external_mcp_headers)
+                headers.update(external_headers)
+
+            logger.info(f"Using external MCP URL: {self.external_mcp_url}")
+
+            return {
+                "url": self.external_mcp_url.rstrip("/"),
+                "transport": self.external_mcp_transport,
+                "headers": headers,
+            }
+
+        # No MCP configuration found, setup localhost if running locally
+        if self.mcp_server_port:
+            url = f"http://localhost:{self.mcp_server_port}"
+            if self._check_localhost_server(url):
+                headers = self._build_authenticated_headers()
+                logger.info(f"Using localhost MCP server: {url}")
+                return {
+                    "url": f"{url}/mcp",
+                    "transport": "streamable-http",
+                    "headers": headers,
+                }
+            logger.warning(f"MCP server is not running or not responding at {url}")
+
         return None
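Net effect of these hunks: `_build_server_config()` now prefers a DataRobot-hosted MCP deployment, then an external MCP URL, and finally falls back to probing a localhost server on `mcp_server_port`. A rough standalone sketch of that new localhost probe (the function name and port are illustrative, not part of the package API):

```python
# Sketch of the localhost fallback that _build_server_config() now performs
# when only mcp_server_port is configured; not the package's own code.
from http import HTTPStatus

import requests


def probe_local_mcp(port: int, timeout: float = 2.0) -> dict | None:
    """Return a server-config dict if a local DataRobot MCP server answers the health check."""
    url = f"http://localhost:{port}"
    try:
        response = requests.get(url, timeout=timeout)
        running = (
            response.status_code == HTTPStatus.OK
            and response.json().get("message") == "DataRobot MCP Server is running"
        )
    except requests.RequestException:
        running = False
    if not running:
        return None
    # The real config also attaches forwarded and authentication headers here.
    return {"url": f"{url}/mcp", "transport": "streamable-http", "headers": {}}


if __name__ == "__main__":
    print(probe_local_mcp(8080))  # None unless a local MCP server is listening on 8080
```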
datarobot_genai/crewai/base.py
CHANGED

@@ -80,6 +80,37 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
         """
         raise NotImplementedError

+    def _extract_pipeline_interactions(self) -> MultiTurnSample | None:
+        """Extract pipeline interactions from event listener if available."""
+        if not hasattr(self, "event_listener"):
+            return None
+        try:
+            listener = getattr(self, "event_listener", None)
+            messages = getattr(listener, "messages", None) if listener is not None else None
+            return create_pipeline_interactions_from_messages(messages)
+        except Exception:
+            return None
+
+    def _extract_usage_metrics(self, crew_output: Any) -> UsageMetrics:
+        """Extract usage metrics from crew output."""
+        token_usage = getattr(crew_output, "token_usage", None)
+        if token_usage is not None:
+            return {
+                "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
+                "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
+                "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
+            }
+        return default_usage_metrics()
+
+    def _process_crew_output(
+        self, crew_output: Any
+    ) -> tuple[str, MultiTurnSample | None, UsageMetrics]:
+        """Process crew output into response tuple."""
+        response_text = str(crew_output.raw)
+        pipeline_interactions = self._extract_pipeline_interactions()
+        usage_metrics = self._extract_usage_metrics(crew_output)
+        return response_text, pipeline_interactions, usage_metrics
+
     async def invoke(self, completion_create_params: CompletionCreateParams) -> InvokeReturn:
         """Run the CrewAI workflow with the provided completion parameters."""
         user_prompt_content = extract_user_prompt_content(completion_create_params)

@@ -116,64 +147,13 @@ class CrewAIAgent(BaseAgent[BaseTool], abc.ABC):
             async def _gen() -> AsyncGenerator[
                 tuple[str, MultiTurnSample | None, UsageMetrics]
             ]:
-                # Run kickoff in a worker thread.
                 crew_output = await asyncio.to_thread(
                     crew.kickoff,
                     inputs=self.make_kickoff_inputs(user_prompt_content),
                 )
-
-                pipeline_interactions = None
-                if hasattr(self, "event_listener"):
-                    try:
-                        listener = getattr(self, "event_listener", None)
-                        messages = (
-                            getattr(listener, "messages", None)
-                            if listener is not None
-                            else None
-                        )
-                        pipeline_interactions = create_pipeline_interactions_from_messages(
-                            messages
-                        )
-                    except Exception:
-                        pipeline_interactions = None
-
-                token_usage = getattr(crew_output, "token_usage", None)
-                if token_usage is not None:
-                    usage_metrics: UsageMetrics = {
-                        "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
-                        "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
-                        "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
-                    }
-                else:
-                    usage_metrics = default_usage_metrics()
-
-                # Finalize stream with empty chunk carrying interactions and usage
-                yield "", pipeline_interactions, usage_metrics
+                yield self._process_crew_output(crew_output)

             return _gen()

-        # Non-streaming: run to completion and return final result
         crew_output = crew.kickoff(inputs=self.make_kickoff_inputs(user_prompt_content))
-
-        response_text = str(crew_output.raw)
-
-        pipeline_interactions = None
-        if hasattr(self, "event_listener"):
-            try:
-                listener = getattr(self, "event_listener", None)
-                messages = getattr(listener, "messages", None) if listener is not None else None
-                pipeline_interactions = create_pipeline_interactions_from_messages(messages)
-            except Exception:
-                pipeline_interactions = None
-
-        token_usage = getattr(crew_output, "token_usage", None)
-        if token_usage is not None:
-            usage_metrics: UsageMetrics = {
-                "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
-                "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
-                "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
-            }
-        else:
-            usage_metrics = default_usage_metrics()
-
-        return response_text, pipeline_interactions, usage_metrics
+        return self._process_crew_output(crew_output)
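The streaming and non-streaming paths previously duplicated the interaction and token-usage extraction; both now call the shared `_process_crew_output()` helper. The token-usage part of that helper in isolation (the `FakeUsage` and `FakeOutput` objects below are made up for the example):

```python
# Toy illustration of the getattr-based extraction used by _extract_usage_metrics;
# the helper names and fake objects are for demonstration only.
from typing import Any


def default_usage_metrics() -> dict[str, int]:
    # Stand-in for the package helper of the same name.
    return {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}


def extract_usage_metrics(crew_output: Any) -> dict[str, int]:
    """Fall back to zeros when the crew output carries no token_usage attribute."""
    token_usage = getattr(crew_output, "token_usage", None)
    if token_usage is None:
        return default_usage_metrics()
    return {
        "completion_tokens": int(getattr(token_usage, "completion_tokens", 0)),
        "prompt_tokens": int(getattr(token_usage, "prompt_tokens", 0)),
        "total_tokens": int(getattr(token_usage, "total_tokens", 0)),
    }


class FakeUsage:
    completion_tokens = 7
    prompt_tokens = 5
    total_tokens = 12


class FakeOutput:
    token_usage = FakeUsage()


print(extract_usage_metrics(FakeOutput()))  # {'completion_tokens': 7, 'prompt_tokens': 5, 'total_tokens': 12}
print(extract_usage_metrics(object()))      # all zeros
```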
datarobot_genai/drmcp/core/dr_mcp_server_logo.py
CHANGED

@@ -38,7 +38,7 @@ def _apply_green(text: str) -> str:
     return "\n".join(colored_lines)


-DR_LOGO_ASCII = _apply_green("""
+DR_LOGO_ASCII = _apply_green(r"""
  ____        _        ____       _           _
 |  _ \ __ _| |_ __ _|  _ \ ___ | |__   ___ | |_
 | | | |/ _` | __/ _` | |_) / _ \| '_ \ / _ \| __|
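The only change here is the `r` prefix on the logo literal: the ASCII art contains bare backslashes, which are invalid escape sequences in a normal string and raise warnings on recent Python versions. A minimal illustration of the equivalence (this snippet is illustrative, not from the module):

```python
# The raw string keeps the ASCII art's "\ " and "\_" verbatim;
# the non-raw form would need every backslash doubled.
art_raw = r"| _ \ __ _| |_"
art_escaped = "| _ \\ __ _| |_"
assert art_raw == art_escaped
print(art_raw)
```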
datarobot_genai/langgraph/agent.py
CHANGED

@@ -84,21 +84,38 @@ class LangGraphAgent(BaseAgent[BaseTool], abc.ABC):
             async def wrapped_generator() -> AsyncGenerator[
                 tuple[str, Any | None, UsageMetrics], None
             ]:
- (15 removed lines; their content is not rendered in the source diff view)
+                try:
+                    async with mcp_tools_context(
+                        authorization_context=self._authorization_context,
+                        forwarded_headers=self.forwarded_headers,
+                    ) as mcp_tools:
+                        self.set_mcp_tools(mcp_tools)
+                        result = await self._invoke(completion_create_params)
+
+                        # Yield all items from the result generator
+                        # The context will be closed when this generator is exhausted
+                        # Cast to async generator since we know stream=True means it's a generator
+                        result_generator = cast(
+                            AsyncGenerator[tuple[str, Any | None, UsageMetrics], None], result
+                        )
+                        async for item in result_generator:
+                            yield item
+                except RuntimeError as e:
+                    error_message = str(e).lower()
+                    if "different task" in error_message and "cancel scope" in error_message:
+                        # Due to anyio task group constraints when consuming async generators
+                        # across task boundaries, we cannot always clean up properly.
+                        # The underlying HTTP client/connection pool should handle resource cleanup
+                        # via timeouts and connection pooling, but this
+                        # may lead to delayed resource release.
+                        logger.debug(
+                            "MCP context cleanup attempted in different task. "
+                            "This is a limitation when consuming async generators "
+                            "across task boundaries."
+                        )
+                    else:
+                        # Re-raise if it's a different RuntimeError
+                        raise

             return wrapped_generator()
         else:
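The new wrapper keeps the MCP tools context open while the inner generator streams, and then deliberately swallows only the anyio "cancel scope in a different task" RuntimeError that can surface during cleanup. The filtering pattern in isolation (a sketch, not the agent code; the function name is made up):

```python
# Sketch of the error-filtering pattern: ignore only the known anyio
# cross-task cancel-scope error, re-raise every other RuntimeError.
import logging

logger = logging.getLogger(__name__)


def ignore_cross_task_cleanup_error(exc: RuntimeError) -> None:
    message = str(exc).lower()
    if "different task" in message and "cancel scope" in message:
        logger.debug("MCP context cleanup attempted in a different task; ignoring.")
        return
    raise exc


# An unrelated RuntimeError still propagates.
try:
    ignore_cross_task_cleanup_error(RuntimeError("boom"))
except RuntimeError as err:
    print(f"re-raised: {err}")

# The anyio-style message is silently ignored.
ignore_cross_task_cleanup_error(
    RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
)
```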
datarobot_genai/nat/datarobot_llm_clients.py
CHANGED

@@ -23,6 +23,7 @@ from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.cli.register_workflow import register_llm_client

+from ..nat.datarobot_llm_providers import DataRobotLLMComponentModelConfig
 from ..nat.datarobot_llm_providers import DataRobotLLMDeploymentModelConfig
 from ..nat.datarobot_llm_providers import DataRobotLLMGatewayModelConfig
 from ..nat.datarobot_llm_providers import DataRobotNIMModelConfig

@@ -75,6 +76,7 @@ async def datarobot_llm_gateway_langchain(
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
     config["base_url"] = config["base_url"] + "/genai/llmgw"
     config["stream_options"] = {"include_usage": True}
+    config["model"] = config["model"].removeprefix("datarobot/")
     yield DataRobotChatOpenAI(**config)


@@ -85,7 +87,8 @@ async def datarobot_llm_gateway_crewai(
     llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["base_url"] = config["base_url"].removesuffix("/api/v2")
     yield LLM(**config)

@@ -97,7 +100,8 @@ async def datarobot_llm_gateway_llamaindex(
     llm_config: DataRobotLLMGatewayModelConfig, builder: Builder
 ) -> AsyncGenerator[LLM]:
     config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
     yield DataRobotLiteLLM(**config)

@@ -109,11 +113,12 @@ async def datarobot_llm_deployment_langchain(
     llm_config: DataRobotLLMDeploymentModelConfig, builder: Builder
 ) -> AsyncGenerator[ChatOpenAI]:
     config = llm_config.model_dump(
-        exclude={"type", "thinking"
+        exclude={"type", "thinking"},
         by_alias=True,
         exclude_none=True,
     )
     config["stream_options"] = {"include_usage": True}
+    config["model"] = config["model"].removeprefix("datarobot/")
     yield DataRobotChatOpenAI(**config)


@@ -128,7 +133,8 @@ async def datarobot_llm_deployment_crewai(
         by_alias=True,
         exclude_none=True,
     )
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
     yield LLM(**config)

@@ -144,7 +150,8 @@ async def datarobot_llm_deployment_llamaindex(
         by_alias=True,
         exclude_none=True,
     )
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
     yield DataRobotLiteLLM(**config)

@@ -159,6 +166,7 @@ async def datarobot_nim_langchain(
         exclude_none=True,
     )
     config["stream_options"] = {"include_usage": True}
+    config["model"] = config["model"].removeprefix("datarobot/")
     yield DataRobotChatOpenAI(**config)


@@ -171,7 +179,8 @@ async def datarobot_nim_crewai(
         by_alias=True,
         exclude_none=True,
     )
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
     yield LLM(**config)

@@ -185,6 +194,56 @@ async def datarobot_nim_llamaindex(
         by_alias=True,
         exclude_none=True,
     )
-    config["model"]
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
     config["api_base"] = config.pop("base_url") + "/chat/completions"
     yield DataRobotLiteLLM(**config)
+
+
+@register_llm_client(
+    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN
+)
+async def datarobot_llm_component_langchain(
+    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
+) -> AsyncGenerator[ChatOpenAI]:
+    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
+    if config["use_datarobot_llm_gateway"]:
+        config["base_url"] = config["base_url"] + "/genai/llmgw"
+    config["stream_options"] = {"include_usage": True}
+    config["model"] = config["model"].removeprefix("datarobot/")
+    config.pop("use_datarobot_llm_gateway")
+    yield DataRobotChatOpenAI(**config)
+
+
+@register_llm_client(
+    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI
+)
+async def datarobot_llm_component_crewai(
+    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
+) -> AsyncGenerator[LLM]:
+    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
+    if config["use_datarobot_llm_gateway"]:
+        config["base_url"] = config["base_url"].removesuffix("/api/v2")
+    else:
+        config["api_base"] = config.pop("base_url") + "/chat/completions"
+    config.pop("use_datarobot_llm_gateway")
+    yield LLM(**config)
+
+
+@register_llm_client(
+    config_type=DataRobotLLMComponentModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX
+)
+async def datarobot_llm_component_llamaindex(
+    llm_config: DataRobotLLMComponentModelConfig, builder: Builder
+) -> AsyncGenerator[LLM]:
+    config = llm_config.model_dump(exclude={"type", "thinking"}, by_alias=True, exclude_none=True)
+    if not config["model"].startswith("datarobot/"):
+        config["model"] = "datarobot/" + config["model"]
+    if config["use_datarobot_llm_gateway"]:
+        config["api_base"] = config.pop("base_url").removesuffix("/api/v2")
+    else:
+        config["api_base"] = config.pop("base_url") + "/chat/completions"
+    config.pop("use_datarobot_llm_gateway")
+    yield DataRobotLiteLLM(**config)
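Across all of these clients the recurring change is model-name normalization: the LangChain/OpenAI-style clients strip the `datarobot/` prefix, while the LiteLLM-backed CrewAI and LlamaIndex clients ensure the prefix is present exactly once. The two directions in isolation (helper names and the sample model string below are illustrative):

```python
def for_openai_client(model: str) -> str:
    """OpenAI-compatible clients get the bare model name."""
    return model.removeprefix("datarobot/")


def for_litellm_client(model: str) -> str:
    """LiteLLM-routed clients need the provider prefix, added only if missing."""
    return model if model.startswith("datarobot/") else "datarobot/" + model


name = "datarobot/some-model"  # hypothetical model identifier
assert for_openai_client(name) == "some-model"
assert for_litellm_client("some-model") == name
assert for_litellm_client(name) == name  # already prefixed: unchanged
```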
datarobot_genai/nat/datarobot_llm_providers.py
CHANGED

@@ -32,11 +32,43 @@ class Config(DataRobotAppFrameworkBaseSettings):
     datarobot_api_token: str | None = None
     llm_deployment_id: str | None = None
     nim_deployment_id: str | None = None
+    use_datarobot_llm_gateway: bool = False
+    llm_default_model: str | None = None


 config = Config()


+class DataRobotLLMComponentModelConfig(OpenAIModelConfig, name="datarobot-llm-component"):  # type: ignore[call-arg]
+    """A DataRobot LLM provider to be used with an LLM client."""
+
+    api_key: str | None = Field(
+        default=config.datarobot_api_token, description="DataRobot API key."
+    )
+    base_url: str | None = Field(
+        default=config.datarobot_endpoint.rstrip("/")
+        if config.use_datarobot_llm_gateway
+        else config.datarobot_endpoint + f"/deployments/{config.llm_deployment_id}",
+        description="DataRobot LLM URL.",
+    )
+    model_name: str = Field(
+        validation_alias=AliasChoices("model_name", "model"),
+        serialization_alias="model",
+        description="The model name.",
+        default=config.llm_default_model or "datarobot-deployed-llm",
+    )
+    use_datarobot_llm_gateway: bool = config.use_datarobot_llm_gateway
+
+
+@register_llm_provider(config_type=DataRobotLLMComponentModelConfig)
+async def datarobot_llm_component(
+    config: DataRobotLLMComponentModelConfig, _builder: Builder
+) -> LLMProviderInfo:
+    yield LLMProviderInfo(
+        config=config, description="DataRobot LLM Component for use with an LLM client."
+    )
+
+
 class DataRobotLLMGatewayModelConfig(OpenAIModelConfig, name="datarobot-llm-gateway"):  # type: ignore[call-arg]
     """A DataRobot LLM provider to be used with an LLM client."""

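The new `DataRobotLLMComponentModelConfig` derives its default `base_url` from whether the LLM Gateway is enabled; otherwise it points at the configured deployment. A tiny sketch of that selection (the endpoint and deployment ID values are placeholders, and the helper is not part of the package):

```python
def component_base_url(endpoint: str, use_llm_gateway: bool, llm_deployment_id: str | None) -> str:
    """Mirror of the default base_url choice in DataRobotLLMComponentModelConfig."""
    if use_llm_gateway:
        return endpoint.rstrip("/")
    return endpoint + f"/deployments/{llm_deployment_id}"


endpoint = "https://app.datarobot.com/api/v2"  # placeholder endpoint
print(component_base_url(endpoint, True, None))
print(component_base_url(endpoint, False, "0123456789abcdef01234567"))  # placeholder deployment ID
```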
{datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/RECORD
CHANGED

@@ -13,13 +13,13 @@ datarobot_genai/core/cli/__init__.py,sha256=B93Yb6VavoZpatrh8ltCL6YglIfR5FHgytXb
 datarobot_genai/core/cli/agent_environment.py,sha256=BJzQoiDvZF5gW4mFE71U0yeg-l72C--kxiE-fv6W194,1662
 datarobot_genai/core/cli/agent_kernel.py,sha256=3XX58DQ6XPpWB_tn5m3iGb3XTfhZf5X3W9tc6ADieU4,7790
 datarobot_genai/core/mcp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datarobot_genai/core/mcp/common.py,sha256=
+datarobot_genai/core/mcp/common.py,sha256=Y8SjuquUODKEfI7T9X-QuTMKdIlpCWFI1b3xs6tmHFA,7812
 datarobot_genai/core/utils/__init__.py,sha256=VxtRUz6iwb04eFQQy0zqTNXLAkYpPXcJxVoKV0nOdXk,59
 datarobot_genai/core/utils/auth.py,sha256=Xo1PxVr6oMgtMHkmHdS02klDKK1cyDpjGvIMF4Tx0Lo,7874
 datarobot_genai/core/utils/urls.py,sha256=tk0t13duDEPcmwz2OnS4vwEdatruiuX8lnxMMhSaJik,2289
 datarobot_genai/crewai/__init__.py,sha256=MtFnHA3EtmgiK_GjwUGPgQQ6G1MCEzz1SDBwQi9lE8M,706
 datarobot_genai/crewai/agent.py,sha256=vp8_2LExpeLls7Fpzo0R6ud5I6Ryfu3n3oVTN4Yyi6A,1417
-datarobot_genai/crewai/base.py,sha256=
+datarobot_genai/crewai/base.py,sha256=JLljEN7sj8zaH8OamYoevFBZzza5BjZ4f0CGHRp2jUU,6447
 datarobot_genai/crewai/events.py,sha256=K67bO1zwPrxmppz2wh8dFGNbVebyWGXAMD7oodFE2sQ,5462
 datarobot_genai/crewai/mcp.py,sha256=AJTrs-8KdiRSjRECfBT1lJOsszWMoFoN9NIa1p5_wsM,2115
 datarobot_genai/drmcp/__init__.py,sha256=JE83bfpGU7v77VzrDdlb0l8seM5OwUsUbaQErJ2eisc,2983

@@ -32,7 +32,7 @@ datarobot_genai/drmcp/core/config_utils.py,sha256=U-aieWw7MyP03cGDFIp97JH99ZUfr3
 datarobot_genai/drmcp/core/constants.py,sha256=lUwoW_PTrbaBGqRJifKqCn3EoFacoEgdO-CpoFVrUoU,739
 datarobot_genai/drmcp/core/credentials.py,sha256=PYEUDNMVw1BoMzZKLkPVTypNkVevEPtmk3scKnE-zYg,6706
 datarobot_genai/drmcp/core/dr_mcp_server.py,sha256=7mu5UXHQmKNbIpNoQE0lPJaUI7AZa03avfHZRRtpjNI,12841
-datarobot_genai/drmcp/core/dr_mcp_server_logo.py,sha256=
+datarobot_genai/drmcp/core/dr_mcp_server_logo.py,sha256=hib-nfR1SNTW6CnpFsFCkL9H_OMwa4YYyinV7VNOuLk,4708
 datarobot_genai/drmcp/core/exceptions.py,sha256=eqsGI-lxybgvWL5w4BFhbm3XzH1eU5tetwjnhJxelpc,905
 datarobot_genai/drmcp/core/logging.py,sha256=Y_hig4eBWiXGaVV7B_3wBcaYVRNH4ydptbEQhrP9-mY,3414
 datarobot_genai/drmcp/core/mcp_instance.py,sha256=wMsP39xqTmNBYqd49olEQb5UHTSsxj6BOIoIElorRB0,19235

@@ -83,7 +83,7 @@ datarobot_genai/drmcp/tools/predictive/predict_realtime.py,sha256=t7f28y_ealZoA6
 datarobot_genai/drmcp/tools/predictive/project.py,sha256=KaMDAvJY4s12j_4ybA7-KcCS1yMOj-KPIKNBgCSE2iM,2536
 datarobot_genai/drmcp/tools/predictive/training.py,sha256=kxeDVLqUh9ajDk8wK7CZRRydDK8UNuTVZCB3huUihF8,23660
 datarobot_genai/langgraph/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datarobot_genai/langgraph/agent.py,sha256=
+datarobot_genai/langgraph/agent.py,sha256=P_vqNO-gFOLREHh8NdUXwc9cRlz6dWUY3B_z0r2vpQQ,10945
 datarobot_genai/langgraph/mcp.py,sha256=iA2_j46mZAaNaL7ntXT-LW6C-NMJkzr3VfKDDfe7mh8,2851
 datarobot_genai/llama_index/__init__.py,sha256=JEMkLQLuP8n14kNE3bZ2j08NdajnkJMfYjDQYqj7C0c,407
 datarobot_genai/llama_index/agent.py,sha256=V6ZsD9GcBDJS-RJo1tJtIHhyW69_78gM6_fOHFV-Piw,1829

@@ -91,11 +91,11 @@ datarobot_genai/llama_index/base.py,sha256=ovcQQtC-djD_hcLrWdn93jg23AmD6NBEj7xtw
 datarobot_genai/llama_index/mcp.py,sha256=leXqF1C4zhuYEKFwNEfZHY4dsUuGZk3W7KArY-zxVL8,2645
 datarobot_genai/nat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datarobot_genai/nat/agent.py,sha256=siBLDWAff2-JwZ8Q3iNpM_e4_IoSwG9IvY0hyEjNenw,10292
-datarobot_genai/nat/datarobot_llm_clients.py,sha256=
-datarobot_genai/nat/datarobot_llm_providers.py,sha256=
-datarobot_genai-0.1.
-datarobot_genai-0.1.
-datarobot_genai-0.1.
-datarobot_genai-0.1.
-datarobot_genai-0.1.
-datarobot_genai-0.1.
+datarobot_genai/nat/datarobot_llm_clients.py,sha256=STzAZ4OF8U-Y_cUTywxmKBGVotwsnbGP6vTojnu6q0g,9921
+datarobot_genai/nat/datarobot_llm_providers.py,sha256=aDoQcTeGI-odqydPXEX9OGGNFbzAtpqzTvHHEkmJuEQ,4963
+datarobot_genai-0.1.71.dist-info/METADATA,sha256=Mt1gNWIU1Jp3Vu6NkXO78vKjjXBnNToA0Vrjj0hLH-I,5918
+datarobot_genai-0.1.71.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+datarobot_genai-0.1.71.dist-info/entry_points.txt,sha256=CZhmZcSyt_RBltgLN_b9xasJD6J5SaDc_z7K0wuOY9Y,150
+datarobot_genai-0.1.71.dist-info/licenses/AUTHORS,sha256=isJGUXdjq1U7XZ_B_9AH8Qf0u4eX0XyQifJZ_Sxm4sA,80
+datarobot_genai-0.1.71.dist-info/licenses/LICENSE,sha256=U2_VkLIktQoa60Nf6Tbt7E4RMlfhFSjWjcJJfVC-YCE,11341
+datarobot_genai-0.1.71.dist-info/RECORD,,

{datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/WHEEL
File without changes

{datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/entry_points.txt
File without changes

{datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/licenses/AUTHORS
File without changes

{datarobot_genai-0.1.68.dist-info → datarobot_genai-0.1.71.dist-info}/licenses/LICENSE
File without changes