datarobot_genai-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/predict.py
@@ -0,0 +1,254 @@
# Copyright 2025 DataRobot, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import uuid
from typing import Any

import datarobot as dr
from fastmcp.resources import HttpResource
from fastmcp.resources import ResourceManager

from datarobot_genai.drmcp.core.clients import get_credentials
from datarobot_genai.drmcp.core.clients import get_s3_bucket_info
from datarobot_genai.drmcp.core.clients import get_sdk_client
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
from datarobot_genai.drmcp.core.utils import generate_presigned_url

logger = logging.getLogger(__name__)


def _handle_prediction_resource(
    job: Any, bucket: str, key: str, deployment_id: str, input_desc: str
) -> str:
    s3_url = generate_presigned_url(bucket, key)
    resource_manager = ResourceManager()
    resource = HttpResource(
        uri=s3_url,  # type: ignore[arg-type]
        url=s3_url,
        name=f"Predictions for {deployment_id}",
        mime_type="text/csv",
    )
    resource_manager.add_resource(resource)
    return (
        f"Finished Batch Prediction job ID {job.id} for deployment ID {deployment_id}. "
        f"{input_desc} Results uploaded to {s3_url}. "
        f"Job status: {job.status} and you can find the job on the DataRobot UI at "
        f"/deployments/batch-jobs. "
    )


def get_or_create_s3_credential() -> Any:
    existing_creds = dr.Credential.list()
    for cred in existing_creds:
        if cred.name == "dr_mcp_server_temp_storage_s3_cred":
            return cred

    if get_credentials().has_aws_credentials():
        aws_access_key_id, aws_secret_access_key, aws_session_token = (
            get_credentials().get_aws_credentials()
        )
        cred = dr.Credential.create_s3(
            name="dr_mcp_server_temp_storage_s3_cred",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        return cred

    raise Exception("No AWS credentials found in your MCP deployment.")


def make_output_settings(cred: Any) -> tuple[dict[str, Any], str, str]:
    bucket_info = get_s3_bucket_info()
    s3_bucket = bucket_info["bucket"]
    s3_prefix = bucket_info["prefix"]
    s3_key = f"{s3_prefix}{uuid.uuid4()}.csv"
    s3_url = f"s3://{s3_bucket}/{s3_key}"

    return (
        {
            "type": "s3",
            "url": s3_url,
            "credential_id": cred.credential_id,
        },
        s3_bucket,
        s3_key,
    )


def wait_for_preds_and_cache_results(
    job: Any, bucket: str, key: str, deployment_id: str, input_desc: str, timeout: int
) -> str:
    job.wait_for_completion(timeout)
    if job.status in ["ERROR", "FAILED", "ABORTED"]:
        logger.error(f"Job failed with status {job.status}")
        return f"Job failed with status {job.status}"
    return _handle_prediction_resource(job, bucket, key, deployment_id, input_desc)


@dr_mcp_tool(tags={"prediction", "scoring", "batch"})
async def predict_by_file_path(
    deployment_id: str,
    file_path: str,
    timeout: int = 600,
) -> str:
    """
    Make predictions using a DataRobot deployment and a local CSV file using the DataRobot Python
    SDK. Use this tool to score large amounts of data; for small amounts of data use the
    predict_realtime tool.

    Args:
        deployment_id: The ID of the DataRobot deployment to use for prediction.
        file_path: Path to a CSV file to use as input data.
        timeout: Timeout in seconds for the batch prediction job (default 600).

    Returns
    -------
    A string summary of the batch prediction job and download link if available.
    """
    output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
    job = dr.BatchPredictionJob.score(
        deployment=deployment_id,
        intake_settings={  # type: ignore[arg-type]
            "type": "localFile",
            "file": file_path,
        },
        output_settings=output_settings,  # type: ignore[arg-type]
    )
    return wait_for_preds_and_cache_results(
        job, bucket, key, deployment_id, f"Scoring file {file_path}.", timeout
    )


@dr_mcp_tool(tags={"prediction", "scoring", "batch"})
async def predict_by_ai_catalog(
    deployment_id: str,
    dataset_id: str,
    timeout: int = 600,
) -> str:
    """
    Make predictions using a DataRobot deployment and an AI Catalog dataset using the DataRobot
    Python SDK.

    Use this tool when asked to score data stored in AI Catalog by dataset ID.

    Args:
        deployment_id: The ID of the DataRobot deployment to use for prediction.
        dataset_id: ID of an AI Catalog item to use as input data.
        timeout: Timeout in seconds for the batch prediction job (default 600).

    Returns
    -------
    A string summary of the batch prediction job and download link if available.
    """
    output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
    client = get_sdk_client()
    dataset = client.Dataset.get(dataset_id)
    job = dr.BatchPredictionJob.score(
        deployment=deployment_id,
        intake_settings={  # type: ignore[arg-type]
            "type": "dataset",
            "dataset": dataset,
        },
        output_settings=output_settings,  # type: ignore[arg-type]
    )
    return wait_for_preds_and_cache_results(
        job, bucket, key, deployment_id, f"Scoring dataset {dataset_id}.", timeout
    )


@dr_mcp_tool(tags={"prediction", "scoring", "batch"})
async def predict_from_project_data(
    deployment_id: str,
    project_id: str,
    dataset_id: str | None = None,
    partition: str | None = None,
    timeout: int = 600,
) -> str:
    """
    Make predictions using a DataRobot deployment using the training data associated with the
    project that created the deployment.

    Use this tool to score holdout, validation, or allBacktest partitions of the training data.
    You can request a specific partition of the data, or use an external dataset (with dataset_id)
    stored in AI Catalog.

    Args:
        deployment_id: (Required) The ID of the DataRobot deployment to use for prediction.
        project_id: (Required) The ID of the DataRobot project to use for prediction. Can be found
            by using the get_model_info_from_deployment tool.
        dataset_id: (Optional) The ID of the external dataset, usually stored in AI Catalog, to
            use for prediction.
        partition: (Optional) The partition of the DataRobot dataset to use for prediction; can be
            'holdout', 'validation', or 'allBacktest'.
        timeout: (Optional) Timeout in seconds for the batch prediction job (default 600).

    Returns
    -------
    A string summary of the batch prediction job and download link if available.
    """
    output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
    intake_settings: dict[str, Any] = {
        "type": "dss",
        "project_id": project_id,
    }
    if partition:
        intake_settings["partition"] = partition
    if dataset_id:
        intake_settings["dataset_id"] = dataset_id
    job = dr.BatchPredictionJob.score(
        deployment=deployment_id,
        intake_settings=intake_settings,  # type: ignore[arg-type]
        output_settings=output_settings,  # type: ignore[arg-type]
    )
    return wait_for_preds_and_cache_results(
        job, bucket, key, deployment_id, f"Scoring project {project_id}.", timeout
    )


# FIXME
# @dr_mcp_tool(tags={"prediction", "explanations", "shap"})
async def get_prediction_explanations(
    project_id: str,
    model_id: str,
    dataset_id: str,
    max_explanations: int = 100,
) -> str:
    """
    Calculate prediction explanations (SHAP values) for a given model and dataset.

    Args:
        project_id: The ID of the DataRobot project.
        model_id: The ID of the model to use for explanations.
        dataset_id: The ID of the dataset to explain predictions for.
        max_explanations: Maximum number of explanations per row (default 100).

    Returns
    -------
    JSON string containing the prediction explanations for each row.
    """
    client = get_sdk_client()
    project = client.Project.get(project_id)
    model = client.Model.get(project=project, model_id=model_id)
    try:
        explanations = model.get_or_request_prediction_explanations(
            dataset_id=dataset_id, max_explanations=max_explanations
        )
        return json.dumps(
            {"explanations": explanations, "ui_panel": ["prediction-distribution"]},
            indent=2,
        )
    except Exception as e:
        logger.error(f"Error in get_prediction_explanations: {type(e).__name__}: {e}")
        return json.dumps(
            {"error": f"Error in get_prediction_explanations: {type(e).__name__}: {e}"}
        )
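The batch scoring functions above are registered as MCP tools via dr_mcp_tool, so they would normally be invoked through an MCP client rather than called directly. A hedged usage sketch with the FastMCP client, assuming the server from this package is running at a placeholder URL and exposes the tool under its function name; the deployment ID and file path are placeholders.

    import asyncio

    from fastmcp import Client

    async def main() -> None:
        # Placeholder URL and arguments; point these at a running DataRobot MCP server,
        # a real deployment ID, and a CSV file accessible to that server.
        async with Client("http://localhost:8000/mcp") as client:
            result = await client.call_tool(
                "predict_by_file_path",
                {"deployment_id": "<deployment-id>", "file_path": "to_score.csv"},
            )
            print(result)

    asyncio.run(main())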
datarobot_genai/drmcp/tools/predictive/predict_realtime.py
@@ -0,0 +1,307 @@
# Copyright 2025 DataRobot, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import logging
import uuid
from datetime import datetime

import pandas as pd
from datarobot_predict import TimeSeriesType
from datarobot_predict.deployment import predict as dr_predict
from pydantic import BaseModel

from datarobot_genai.drmcp.core.clients import get_s3_bucket_info
from datarobot_genai.drmcp.core.clients import get_sdk_client
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
from datarobot_genai.drmcp.core.utils import PredictionResponse
from datarobot_genai.drmcp.core.utils import predictions_result_response

logger = logging.getLogger(__name__)


class BucketInfo(BaseModel):
    bucket: str
    key: str


def make_output_settings() -> BucketInfo:
    bucket_info = get_s3_bucket_info()
    s3_key = f"{bucket_info['prefix']}{uuid.uuid4()}.csv"
    return BucketInfo(bucket=bucket_info["bucket"], key=s3_key)


@dr_mcp_tool(tags={"prediction", "realtime", "scoring"})
async def predict_by_ai_catalog_rt(
    deployment_id: str,
    dataset_id: str,
    timeout: int = 600,
) -> PredictionResponse:
    """
    Make real-time predictions using a DataRobot deployment and an AI Catalog dataset using the
    datarobot-predict library.
    Use this for fast results when your data is not huge (not gigabytes). Results larger than 1MB
    will be returned as a resource ID and S3 URL; smaller results will be returned inline as a CSV
    string.

    Args:
        deployment_id: The ID of the DataRobot deployment to use for prediction.
        dataset_id: ID of an AI Catalog item to use as input data.
        timeout: Timeout in seconds for the prediction job (default 600).

    Returns
    -------
    dict: {"type": "inline", "data": csv_str} for small results (<1MB), or {"type": "resource",
    "resource_id": ..., "s3_url": ...} for large results (>=1MB).
    """
    client = get_sdk_client()
    dataset = client.Dataset.get(dataset_id)

    # 1. Preferred: built-in DataFrame helper (newer SDKs)
    if hasattr(dataset, "get_as_dataframe"):
        df = dataset.get_as_dataframe()

    # 2. Next: a download method that writes to a local file path
    elif hasattr(dataset, "download"):
        path = dataset.download("dataset.csv")
        df = pd.read_csv(path)

    # 3. Next: a method returning a local file path
    elif hasattr(dataset, "get_file"):
        path = dataset.get_file()
        df = pd.read_csv(path)

    # 4. Bytes fallback
    elif hasattr(dataset, "get_bytes"):
        raw = dataset.get_bytes()
        df = pd.read_csv(io.BytesIO(raw))

    # 5. Last resort: expose URL then fetch manually
    else:
        url = dataset.url
        df = pd.read_csv(url)

    deployment = client.Deployment.get(deployment_id=deployment_id)
    result = dr_predict(deployment, df, timeout=timeout)
    predictions = result.dataframe
    bucket_info = make_output_settings()
    return predictions_result_response(
        predictions,
        bucket_info.bucket,
        bucket_info.key,
        f"pred_{deployment_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        True,
    )


@dr_mcp_tool(tags={"prediction", "realtime", "scoring"})
async def predict_realtime(
    deployment_id: str,
    file_path: str | None = None,
    dataset: str | None = None,
    forecast_point: str | None = None,
    forecast_range_start: str | None = None,
    forecast_range_end: str | None = None,
    series_id_column: str | None = None,
    max_explanations: int | str = 0,
    max_ngram_explanations: int | str | None = None,
    threshold_high: float | None = None,
    threshold_low: float | None = None,
    passthrough_columns: str | None = None,
    explanation_algorithm: str | None = None,
    prediction_endpoint: str | None = None,
    timeout: int = 600,
) -> PredictionResponse:
    """
    Make real-time predictions using a DataRobot deployment and a local CSV file or a dataset
    string.

    This is the unified prediction function that supports:
    - Regular classification/regression predictions
    - Time series forecasting with advanced parameters
    - Prediction explanations (SHAP/XEMP)
    - Text explanations for NLP models
    - Custom thresholds and passthrough columns

    For regular predictions: just provide deployment_id and file_path or dataset.
    For time series: add forecast_point OR forecast_range_start/end.
    For explanations: set max_explanations > 0 and optionally explanation_algorithm.
    For text models: use max_ngram_explanations for text feature explanations.

    When using this tool, always consider feature importance. For features with high importance,
    try to infer or ask for a reasonable value, using frequent values or domain knowledge if
    available. For less important features, you may leave them blank.

    Args:
        deployment_id: The ID of the DataRobot deployment to use for prediction.
        file_path: Path to a CSV file to use as input data. For time series with forecast_point,
            must have at least 4 historical values within the feature derivation window.
        dataset: (Optional) CSV or JSON string representing the input data. If provided, this
            takes precedence over file_path.
        forecast_point: (Time Series) Date to start forecasting from (e.g., "2024-06-01").
            If provided, triggers time series FORECAST mode. Uses the most recent date if None.
        forecast_range_start: (Time Series) Start date for historical predictions (e.g.,
            "2024-06-01"). Must be used with forecast_range_end for HISTORICAL mode.
        forecast_range_end: (Time Series) End date for historical predictions (e.g., "2024-06-07").
            Must be used with forecast_range_start for HISTORICAL mode.
        series_id_column: (Multiseries Time Series) Column name identifying different series
            (e.g., "store_id", "region"). Must exist in the input data.
        max_explanations: Number of prediction explanations to return per prediction.
            - 0: No explanations (default)
            - Positive integer: Specific number of explanations
            - "all": All available explanations (SHAP only)
            Note: For SHAP, 0 means all explanations; for XEMP, 0 means none.
        max_ngram_explanations: (Text Models) Maximum number of text explanations per prediction.
            Recommended: "all" for text models. None disables text explanations.
        threshold_high: Only compute explanations for predictions above this threshold (0.0-1.0).
            Useful for focusing explanations on high-confidence predictions.
        threshold_low: Only compute explanations for predictions below this threshold (0.0-1.0).
            Useful for focusing explanations on low-confidence predictions.
        passthrough_columns: Input columns to include in output alongside predictions.
            - "all": Include all input columns
            - "column1,column2": Comma-separated list of specific columns
            - None: No passthrough columns (default)
        explanation_algorithm: Algorithm for computing explanations.
            - "shap": SHAP explanations (default for most models)
            - "xemp": XEMP explanations (faster, less accurate)
            - None: Use deployment default
        prediction_endpoint: Override the prediction server endpoint URL.
            Useful for custom prediction servers or Portable Prediction Server.
        timeout: Request timeout in seconds (default 600).

    Returns
    -------
    dict: Prediction response with the following structure:
        - {"type": "inline", "data": "csv_string"} for results < 1MB
        - {"type": "resource", "resource_id": "...", "s3_url": "..."} for results >= 1MB

    The CSV data contains:
        - Prediction columns (e.g., class probabilities, regression values)
        - Explanation columns (if max_explanations > 0)
        - Passthrough columns (if specified)
        - Time series metadata (for forecasting: FORECAST_POINT, FORECAST_DISTANCE, etc.)

    Examples
    --------
    # Regular binary classification
    predict_realtime(deployment_id="abc123", file_path="data.csv")

    # With SHAP explanations
    predict_realtime(deployment_id="abc123", file_path="data.csv",
                     max_explanations=10, explanation_algorithm="shap")

    # Time series forecasting
    predict_realtime(deployment_id="abc123", file_path="ts_data.csv",
                     forecast_point="2024-06-01")

    # Multiseries time series
    predict_realtime(deployment_id="abc123", file_path="multiseries.csv",
                     forecast_point="2024-06-01", series_id_column="store_id")

    # Historical time series predictions
    predict_realtime(deployment_id="abc123", file_path="ts_data.csv",
                     forecast_range_start="2024-06-01",
                     forecast_range_end="2024-06-07")

    # Text model with explanations and passthrough
    predict_realtime(deployment_id="abc123", file_path="text_data.csv",
                     max_explanations="all", max_ngram_explanations="all",
                     passthrough_columns="document_id,customer_id")
    """
    # Load input data from dataset string or file_path
    if dataset is not None:
        # Try CSV first
        try:
            df = pd.read_csv(io.StringIO(dataset))
        except Exception:
            # Try JSON
            try:
                data = json.loads(dataset)
                df = pd.DataFrame(data)
            except Exception as e:
                raise ValueError(f"Could not parse dataset string as CSV or JSON: {e}")
    elif file_path is not None:
        df = pd.read_csv(file_path)
    else:
        raise ValueError("Either file_path or dataset must be provided.")

    if series_id_column and series_id_column not in df.columns:
        raise ValueError(f"series_id_column '{series_id_column}' not found in input data.")

    client = get_sdk_client()
    deployment = client.Deployment.get(deployment_id=deployment_id)

    # Check if this is a time series prediction or regular prediction
    is_time_series = bool(forecast_point or (forecast_range_start and forecast_range_end))

    # Start with base prediction parameters
    predict_kwargs = {
        "deployment": deployment,
        "data_frame": df,
        "timeout": timeout,
    }

    # Add time series parameters if applicable
    if is_time_series:
        if forecast_point:
            forecast_point_dt = pd.to_datetime(forecast_point)
            predict_kwargs["time_series_type"] = TimeSeriesType.FORECAST
            predict_kwargs["forecast_point"] = forecast_point_dt
        elif forecast_range_start and forecast_range_end:
            predictions_start_date_dt = pd.to_datetime(forecast_range_start)
            predictions_end_date_dt = pd.to_datetime(forecast_range_end)
            predict_kwargs["time_series_type"] = TimeSeriesType.HISTORICAL
            predict_kwargs["predictions_start_date"] = predictions_start_date_dt
            predict_kwargs["predictions_end_date"] = predictions_end_date_dt

    # Add explanation parameters
    if max_explanations != 0:
        predict_kwargs["max_explanations"] = max_explanations
    if max_ngram_explanations is not None:
        predict_kwargs["max_ngram_explanations"] = max_ngram_explanations
    if threshold_high is not None:
        predict_kwargs["threshold_high"] = threshold_high
    if threshold_low is not None:
        predict_kwargs["threshold_low"] = threshold_low
    if explanation_algorithm is not None:
        predict_kwargs["explanation_algorithm"] = explanation_algorithm

    # Add passthrough columns
    if passthrough_columns is not None:
        if passthrough_columns == "all":
            predict_kwargs["passthrough_columns"] = "all"
        else:
            # Convert comma-separated string to set
            columns_set = {col.strip() for col in passthrough_columns.split(",")}
            predict_kwargs["passthrough_columns"] = columns_set

    # Add custom prediction endpoint
    if prediction_endpoint is not None:
        predict_kwargs["prediction_endpoint"] = prediction_endpoint

    # Run prediction
    result = dr_predict(**predict_kwargs)
    predictions = result.dataframe
    bucket_info = make_output_settings()
    return predictions_result_response(
        predictions,
        bucket_info.bucket,
        bucket_info.key,
        f"pred_{deployment_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        max_explanations not in {0, "0"},
    )
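The dataset-string branch in predict_realtime accepts either CSV or JSON text and only falls back to JSON when CSV parsing raises. A standalone sketch of that same parsing order, runnable without any DataRobot setup; the helper name is illustrative and not part of the package.

    import io
    import json

    import pandas as pd

    def parse_dataset_string(dataset: str) -> pd.DataFrame:
        # Mirrors predict_realtime's input handling: try CSV first, then JSON.
        try:
            return pd.read_csv(io.StringIO(dataset))
        except Exception:
            try:
                return pd.DataFrame(json.loads(dataset))
            except Exception as e:
                raise ValueError(f"Could not parse dataset string as CSV or JSON: {e}")

    csv_text = "a,b\n1,2\n3,4"
    # Multi-line JSON fails the CSV tokenizer, so it takes the JSON fallback path.
    json_text = '[\n  {"a": 1, "b": 2},\n  {"a": 3, "b": 4}\n]'
    print(parse_dataset_string(csv_text))
    print(parse_dataset_string(json_text))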
datarobot_genai/drmcp/tools/predictive/project.py
@@ -0,0 +1,72 @@
# Copyright 2025 DataRobot, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from datarobot_genai.drmcp.core.clients import get_sdk_client
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool

logger = logging.getLogger(__name__)


@dr_mcp_tool(tags={"project", "management", "list"})
async def list_projects() -> str:
    """
    List all DataRobot projects for the authenticated user.

    Returns
    -------
    A string summary of the user's DataRobot projects.
    """
    client = get_sdk_client()
    projects = client.Project.list()
    if not projects:
        return "No projects found."
    return "\n".join(f"{p.id}: {p.project_name}" for p in projects)


@dr_mcp_tool(tags={"project", "data", "info"})
async def get_project_dataset_by_name(project_id: str, dataset_name: str) -> str:
    """
    Get a dataset ID by name for a given project.

    Args:
        project_id: The ID of the DataRobot project.
        dataset_name: The name of the dataset to find (e.g., 'training', 'holdout').

    Returns
    -------
    The dataset ID and the dataset type (source or prediction) as a string, or an error message.
    """
    client = get_sdk_client()
    project = client.Project.get(project_id)
    all_datasets = []
    source_dataset = project.get_dataset()
    if source_dataset:
        all_datasets.append({"type": "source", "dataset": source_dataset})
    prediction_datasets = project.get_datasets()
    if prediction_datasets:
        all_datasets.extend([{"type": "prediction", "dataset": ds} for ds in prediction_datasets])
    for ds in all_datasets:
        if dataset_name.lower() in ds["dataset"].name.lower():
            return json.dumps(
                {
                    "dataset_id": ds["dataset"].id,
                    "dataset_type": ds["type"],
                    "ui_panel": ["dataset"],
                },
                indent=2,
            )
    return f"Dataset with name containing '{dataset_name}' not found in project {project_id}."
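For comparison, the "<id>: <name>" summary that list_projects builds can be produced directly with the DataRobot SDK these tools wrap. A minimal sketch, assuming DataRobot credentials are already configured (for example via DATAROBOT_ENDPOINT and DATAROBOT_API_TOKEN) so the SDK client can pick them up.

    import datarobot as dr

    # Reads the endpoint and API token from the environment or drconfig.yaml.
    dr.Client()

    projects = dr.Project.list()
    if not projects:
        print("No projects found.")
    else:
        # Same "<id>: <name>" lines the list_projects tool returns above.
        print("\n".join(f"{p.id}: {p.project_name}" for p in projects))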