datarobot-genai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +250 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +542 -0
  56. datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
  57. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  58. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  59. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  60. datarobot_genai/drmcp/core/routes.py +436 -0
  61. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  62. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  63. datarobot_genai/drmcp/core/telemetry.py +424 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +108 -0
  65. datarobot_genai/drmcp/core/utils.py +131 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
  69. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
  70. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
  71. datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
  72. datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
  73. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  74. datarobot_genai/drmcp/tools/__init__.py +14 -0
  75. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  76. datarobot_genai/drmcp/tools/predictive/data.py +97 -0
  77. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  78. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  79. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  80. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  81. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  82. datarobot_genai/drmcp/tools/predictive/project.py +72 -0
  83. datarobot_genai/drmcp/tools/predictive/training.py +651 -0
  84. datarobot_genai/langgraph/__init__.py +0 -0
  85. datarobot_genai/langgraph/agent.py +341 -0
  86. datarobot_genai/langgraph/mcp.py +73 -0
  87. datarobot_genai/llama_index/__init__.py +16 -0
  88. datarobot_genai/llama_index/agent.py +50 -0
  89. datarobot_genai/llama_index/base.py +299 -0
  90. datarobot_genai/llama_index/mcp.py +79 -0
  91. datarobot_genai/nat/__init__.py +0 -0
  92. datarobot_genai/nat/agent.py +258 -0
  93. datarobot_genai/nat/datarobot_llm_clients.py +249 -0
  94. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  95. datarobot_genai/py.typed +0 -0
  96. datarobot_genai-0.2.0.dist-info/METADATA +139 -0
  97. datarobot_genai-0.2.0.dist-info/RECORD +101 -0
  98. datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
  99. datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
  100. datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
  101. datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/predict.py
@@ -0,0 +1,254 @@
+ # Copyright 2025 DataRobot, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import uuid
+ from typing import Any
+
+ import datarobot as dr
+ from fastmcp.resources import HttpResource
+ from fastmcp.resources import ResourceManager
+
+ from datarobot_genai.drmcp.core.clients import get_credentials
+ from datarobot_genai.drmcp.core.clients import get_s3_bucket_info
+ from datarobot_genai.drmcp.core.clients import get_sdk_client
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+ from datarobot_genai.drmcp.core.utils import generate_presigned_url
+
+ logger = logging.getLogger(__name__)
+
+
+ def _handle_prediction_resource(
+     job: Any, bucket: str, key: str, deployment_id: str, input_desc: str
+ ) -> str:
+     s3_url = generate_presigned_url(bucket, key)
+     resource_manager = ResourceManager()
+     resource = HttpResource(
+         uri=s3_url,  # type: ignore[arg-type]
+         url=s3_url,
+         name=f"Predictions for {deployment_id}",
+         mime_type="text/csv",
+     )
+     resource_manager.add_resource(resource)
+     return (
+         f"Finished Batch Prediction job ID {job.id} for deployment ID {deployment_id}. "
+         f"{input_desc} Results uploaded to {s3_url}. "
+         f"Job status: {job.status}. You can find the job in the DataRobot UI at "
+         f"/deployments/batch-jobs."
+     )
+
+
+ def get_or_create_s3_credential() -> Any:
+     existing_creds = dr.Credential.list()
+     for cred in existing_creds:
+         if cred.name == "dr_mcp_server_temp_storage_s3_cred":
+             return cred
+
+     if get_credentials().has_aws_credentials():
+         aws_access_key_id, aws_secret_access_key, aws_session_token = (
+             get_credentials().get_aws_credentials()
+         )
+         cred = dr.Credential.create_s3(
+             name="dr_mcp_server_temp_storage_s3_cred",
+             aws_access_key_id=aws_access_key_id,
+             aws_secret_access_key=aws_secret_access_key,
+             aws_session_token=aws_session_token,
+         )
+         return cred
+
+     raise Exception("No AWS credentials found in your MCP deployment.")
+
+
+ def make_output_settings(cred: Any) -> tuple[dict[str, Any], str, str]:
+     bucket_info = get_s3_bucket_info()
+     s3_bucket = bucket_info["bucket"]
+     s3_prefix = bucket_info["prefix"]
+     s3_key = f"{s3_prefix}{uuid.uuid4()}.csv"
+     s3_url = f"s3://{s3_bucket}/{s3_key}"
+
+     return (
+         {
+             "type": "s3",
+             "url": s3_url,
+             "credential_id": cred.credential_id,
+         },
+         s3_bucket,
+         s3_key,
+     )
+
+
+ def wait_for_preds_and_cache_results(
+     job: Any, bucket: str, key: str, deployment_id: str, input_desc: str, timeout: int
+ ) -> str:
+     job.wait_for_completion(timeout)
+     if job.status in ["ERROR", "FAILED", "ABORTED"]:
+         logger.error(f"Job failed with status {job.status}")
+         return f"Job failed with status {job.status}"
+     return _handle_prediction_resource(job, bucket, key, deployment_id, input_desc)
+
+
+ @dr_mcp_tool(tags={"prediction", "scoring", "batch"})
+ async def predict_by_file_path(
+     deployment_id: str,
+     file_path: str,
+     timeout: int = 600,
+ ) -> str:
+     """
+     Make predictions with a DataRobot deployment and a local CSV file using the DataRobot
+     Python SDK. Use this tool to score large amounts of data; for small amounts of data,
+     use the predict_realtime tool.
+
+     Args:
+         deployment_id: The ID of the DataRobot deployment to use for prediction.
+         file_path: Path to a CSV file to use as input data.
+         timeout: Timeout in seconds for the batch prediction job (default 600).
+
+     Returns
+     -------
+     A string summary of the batch prediction job and download link if available.
+     """
+     output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
+     job = dr.BatchPredictionJob.score(
+         deployment=deployment_id,
+         intake_settings={  # type: ignore[arg-type]
+             "type": "localFile",
+             "file": file_path,
+         },
+         output_settings=output_settings,  # type: ignore[arg-type]
+     )
+     return wait_for_preds_and_cache_results(
+         job, bucket, key, deployment_id, f"Scoring file {file_path}.", timeout
+     )
+
+
+ @dr_mcp_tool(tags={"prediction", "scoring", "batch"})
+ async def predict_by_ai_catalog(
+     deployment_id: str,
+     dataset_id: str,
+     timeout: int = 600,
+ ) -> str:
+     """
+     Make predictions with a DataRobot deployment and an AI Catalog dataset using the
+     DataRobot Python SDK.
+
+     Use this tool when asked to score data stored in AI Catalog by dataset ID.
+
+     Args:
+         deployment_id: The ID of the DataRobot deployment to use for prediction.
+         dataset_id: ID of an AI Catalog item to use as input data.
+         timeout: Timeout in seconds for the batch prediction job (default 600).
+
+     Returns
+     -------
+     A string summary of the batch prediction job and download link if available.
+     """
+     output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     job = dr.BatchPredictionJob.score(
+         deployment=deployment_id,
+         intake_settings={  # type: ignore[arg-type]
+             "type": "dataset",
+             "dataset": dataset,
+         },
+         output_settings=output_settings,  # type: ignore[arg-type]
+     )
+     return wait_for_preds_and_cache_results(
+         job, bucket, key, deployment_id, f"Scoring dataset {dataset_id}.", timeout
+     )
+
+
+ @dr_mcp_tool(tags={"prediction", "scoring", "batch"})
+ async def predict_from_project_data(
+     deployment_id: str,
+     project_id: str,
+     dataset_id: str | None = None,
+     partition: str | None = None,
+     timeout: int = 600,
+ ) -> str:
+     """
+     Make predictions with a DataRobot deployment using the training data associated with the
+     project that created the deployment.
+
+     Use this tool to score the holdout, validation, or allBacktest partitions of the training
+     data. You can request a specific partition of the data, or use an external dataset (with
+     dataset_id) stored in AI Catalog.
+
+     Args:
+         deployment_id: (Required) The ID of the DataRobot deployment to use for prediction.
+         project_id: (Required) The ID of the DataRobot project to use for prediction. Can be
+             found by using the get_model_info_from_deployment tool.
+         dataset_id: (Optional) The ID of the external dataset, usually stored in AI Catalog,
+             to use for prediction.
+         partition: (Optional) The partition of the DataRobot dataset to use for prediction:
+             'holdout', 'validation', or 'allBacktest'.
+         timeout: (Optional) Timeout in seconds for the batch prediction job (default 600).
+
+     Returns
+     -------
+     A string summary of the batch prediction job and download link if available.
+     """
+     output_settings, bucket, key = make_output_settings(get_or_create_s3_credential())
+     intake_settings: dict[str, Any] = {
+         "type": "dss",
+         "project_id": project_id,
+     }
+     if partition:
+         intake_settings["partition"] = partition
+     if dataset_id:
+         intake_settings["dataset_id"] = dataset_id
+     job = dr.BatchPredictionJob.score(
+         deployment=deployment_id,
+         intake_settings=intake_settings,  # type: ignore[arg-type]
+         output_settings=output_settings,  # type: ignore[arg-type]
+     )
+     return wait_for_preds_and_cache_results(
+         job, bucket, key, deployment_id, f"Scoring project {project_id}.", timeout
+     )
+
+
+ # FIXME
+ # @dr_mcp_tool(tags={"prediction", "explanations", "shap"})
+ async def get_prediction_explanations(
+     project_id: str,
+     model_id: str,
+     dataset_id: str,
+     max_explanations: int = 100,
+ ) -> str:
+     """
+     Calculate prediction explanations (SHAP values) for a given model and dataset.
+
+     Args:
+         project_id: The ID of the DataRobot project.
+         model_id: The ID of the model to use for explanations.
+         dataset_id: The ID of the dataset to explain predictions for.
+         max_explanations: Maximum number of explanations per row (default 100).
+
+     Returns
+     -------
+     JSON string containing the prediction explanations for each row.
+     """
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+     try:
+         explanations = model.get_or_request_prediction_explanations(
+             dataset_id=dataset_id, max_explanations=max_explanations
+         )
+         return json.dumps(
+             {"explanations": explanations, "ui_panel": ["prediction-distribution"]},
+             indent=2,
+         )
+     except Exception as e:
+         logger.error(f"Error in get_prediction_explanations: {type(e).__name__}: {e}")
+         return json.dumps(
+             {"error": f"Error in get_prediction_explanations: {type(e).__name__}: {e}"}
+         )
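
The three batch tools in this file share one pipeline: resolve (or create) the reusable S3 credential, build S3 output settings with a fresh UUID key, submit dr.BatchPredictionJob.score with the appropriate intake settings, then wait and publish the result as a presigned-URL resource. A minimal standalone sketch of that flow using only the public DataRobot SDK — the deployment ID, bucket, and credential ID are hypothetical placeholders, not values shipped with this package:

    import datarobot as dr

    # Assumes DATAROBOT_API_TOKEN / DATAROBOT_ENDPOINT are set in the environment.
    dr.Client()

    job = dr.BatchPredictionJob.score(
        deployment="YOUR_DEPLOYMENT_ID",  # hypothetical
        intake_settings={"type": "localFile", "file": "scoring_data.csv"},
        output_settings={
            "type": "s3",
            "url": "s3://your-bucket/predictions/run-1.csv",  # hypothetical
            "credential_id": "YOUR_S3_CREDENTIAL_ID",  # hypothetical
        },
    )
    job.wait_for_completion(600)  # same default timeout the tools use
    print(job.status)

The tools add one step on top of this: wrapping the finished output in an MCP HttpResource so clients can fetch the CSV by presigned URL.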
datarobot_genai/drmcp/tools/predictive/predict_realtime.py
@@ -0,0 +1,307 @@
+ # Copyright 2025 DataRobot, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import io
+ import json
+ import logging
+ import uuid
+ from datetime import datetime
+
+ import pandas as pd
+ from datarobot_predict import TimeSeriesType
+ from datarobot_predict.deployment import predict as dr_predict
+ from pydantic import BaseModel
+
+ from datarobot_genai.drmcp.core.clients import get_s3_bucket_info
+ from datarobot_genai.drmcp.core.clients import get_sdk_client
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+ from datarobot_genai.drmcp.core.utils import PredictionResponse
+ from datarobot_genai.drmcp.core.utils import predictions_result_response
+
+ logger = logging.getLogger(__name__)
+
+
+ class BucketInfo(BaseModel):
+     bucket: str
+     key: str
+
+
+ def make_output_settings() -> BucketInfo:
+     bucket_info = get_s3_bucket_info()
+     s3_key = f"{bucket_info['prefix']}{uuid.uuid4()}.csv"
+     return BucketInfo(bucket=bucket_info["bucket"], key=s3_key)
+
+
+ @dr_mcp_tool(tags={"prediction", "realtime", "scoring"})
+ async def predict_by_ai_catalog_rt(
+     deployment_id: str,
+     dataset_id: str,
+     timeout: int = 600,
+ ) -> PredictionResponse:
+     """
+     Make real-time predictions with a DataRobot deployment and an AI Catalog dataset using
+     the datarobot-predict library.
+
+     Use this for fast results when your data is not huge (i.e., not gigabytes). Results
+     larger than 1MB will be returned as a resource ID and S3 URL; smaller results will be
+     returned inline as a CSV string.
+
+     Args:
+         deployment_id: The ID of the DataRobot deployment to use for prediction.
+         dataset_id: ID of an AI Catalog item to use as input data.
+         timeout: Timeout in seconds for the prediction job (default 600).
+
+     Returns
+     -------
+     dict: {"type": "inline", "data": csv_str} for small results (<1MB), or
+     {"type": "resource", "resource_id": ..., "s3_url": ...} for large results (>=1MB).
+     """
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+
+     # 1. Preferred: built-in DataFrame helper (newer SDKs)
+     if hasattr(dataset, "get_as_dataframe"):
+         df = dataset.get_as_dataframe()
+
+     # 2. Next: a download method that returns a local file path
+     elif hasattr(dataset, "download"):
+         path = dataset.download("dataset.csv")
+         df = pd.read_csv(path)
+
+     # 3. Next: an alternate accessor that returns a local file path
+     elif hasattr(dataset, "get_file"):
+         path = dataset.get_file()
+         df = pd.read_csv(path)
+
+     # 4. Bytes fallback
+     elif hasattr(dataset, "get_bytes"):
+         raw = dataset.get_bytes()
+         df = pd.read_csv(io.BytesIO(raw))
+
+     # 5. Last resort: expose the URL, then fetch it manually
+     else:
+         url = dataset.url
+         df = pd.read_csv(url)
+
+     deployment = client.Deployment.get(deployment_id=deployment_id)
+     result = dr_predict(deployment, df, timeout=timeout)
+     predictions = result.dataframe
+     bucket_info = make_output_settings()
+     return predictions_result_response(
+         predictions,
+         bucket_info.bucket,
+         bucket_info.key,
+         f"pred_{deployment_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+         True,
+     )
+
+
+ @dr_mcp_tool(tags={"prediction", "realtime", "scoring"})
+ async def predict_realtime(
+     deployment_id: str,
+     file_path: str | None = None,
+     dataset: str | None = None,
+     forecast_point: str | None = None,
+     forecast_range_start: str | None = None,
+     forecast_range_end: str | None = None,
+     series_id_column: str | None = None,
+     max_explanations: int | str = 0,
+     max_ngram_explanations: int | str | None = None,
+     threshold_high: float | None = None,
+     threshold_low: float | None = None,
+     passthrough_columns: str | None = None,
+     explanation_algorithm: str | None = None,
+     prediction_endpoint: str | None = None,
+     timeout: int = 600,
+ ) -> PredictionResponse:
+     """
+     Make real-time predictions with a DataRobot deployment and a local CSV file or a dataset
+     string.
+
+     This is the unified prediction function that supports:
+     - Regular classification/regression predictions
+     - Time series forecasting with advanced parameters
+     - Prediction explanations (SHAP/XEMP)
+     - Text explanations for NLP models
+     - Custom thresholds and passthrough columns
+
+     For regular predictions: just provide deployment_id and file_path or dataset.
+     For time series: add forecast_point OR forecast_range_start/end.
+     For explanations: set max_explanations > 0 and optionally explanation_algorithm.
+     For text models: use max_ngram_explanations for text feature explanations.
+
+     When using this tool, always consider feature importance. For features with high
+     importance, try to infer or ask for a reasonable value, using frequent values or domain
+     knowledge if available. For less important features, you may leave them blank.
+
+     Args:
+         deployment_id: The ID of the DataRobot deployment to use for prediction.
+         file_path: Path to a CSV file to use as input data. For time series with
+             forecast_point, must have at least 4 historical values within the feature
+             derivation window.
+         dataset: (Optional) CSV or JSON string representing the input data. If provided,
+             this takes precedence over file_path.
+         forecast_point: (Time Series) Date to start forecasting from (e.g., "2024-06-01").
+             If provided, triggers time series FORECAST mode. Uses the most recent date if
+             None.
+         forecast_range_start: (Time Series) Start date for historical predictions (e.g.,
+             "2024-06-01"). Must be used with forecast_range_end for HISTORICAL mode.
+         forecast_range_end: (Time Series) End date for historical predictions (e.g.,
+             "2024-06-07"). Must be used with forecast_range_start for HISTORICAL mode.
+         series_id_column: (Multiseries Time Series) Column name identifying different series
+             (e.g., "store_id", "region"). Must exist in the input data.
+         max_explanations: Number of prediction explanations to return per prediction.
+             - 0: No explanations (default)
+             - Positive integer: Specific number of explanations
+             - "all": All available explanations (SHAP only)
+             Note: For SHAP, 0 means all explanations; for XEMP, 0 means none.
+         max_ngram_explanations: (Text Models) Maximum number of text explanations per
+             prediction. Recommended: "all" for text models. None disables text explanations.
+         threshold_high: Only compute explanations for predictions above this threshold
+             (0.0-1.0). Useful for focusing explanations on high-confidence predictions.
+         threshold_low: Only compute explanations for predictions below this threshold
+             (0.0-1.0). Useful for focusing explanations on low-confidence predictions.
+         passthrough_columns: Input columns to include in the output alongside predictions.
+             - "all": Include all input columns
+             - "column1,column2": Comma-separated list of specific columns
+             - None: No passthrough columns (default)
+         explanation_algorithm: Algorithm for computing explanations.
+             - "shap": SHAP explanations (default for most models)
+             - "xemp": XEMP explanations (faster, less accurate)
+             - None: Use deployment default
+         prediction_endpoint: Override the prediction server endpoint URL. Useful for custom
+             prediction servers or the Portable Prediction Server.
+         timeout: Request timeout in seconds (default 600).
+
+     Returns
+     -------
+     dict: Prediction response with the following structure:
+         - {"type": "inline", "data": "csv_string"} for results < 1MB
+         - {"type": "resource", "resource_id": "...", "s3_url": "..."} for results >= 1MB
+
+     The CSV data contains:
+         - Prediction columns (e.g., class probabilities, regression values)
+         - Explanation columns (if max_explanations > 0)
+         - Passthrough columns (if specified)
+         - Time series metadata (for forecasting: FORECAST_POINT, FORECAST_DISTANCE, etc.)
+
+     Examples
+     --------
+     # Regular binary classification
+     predict_realtime(deployment_id="abc123", file_path="data.csv")
+
+     # With SHAP explanations
+     predict_realtime(deployment_id="abc123", file_path="data.csv",
+                      max_explanations=10, explanation_algorithm="shap")
+
+     # Time series forecasting
+     predict_realtime(deployment_id="abc123", file_path="ts_data.csv",
+                      forecast_point="2024-06-01")
+
+     # Multiseries time series
+     predict_realtime(deployment_id="abc123", file_path="multiseries.csv",
+                      forecast_point="2024-06-01", series_id_column="store_id")
+
+     # Historical time series predictions
+     predict_realtime(deployment_id="abc123", file_path="ts_data.csv",
+                      forecast_range_start="2024-06-01",
+                      forecast_range_end="2024-06-07")
+
+     # Text model with explanations and passthrough
+     predict_realtime(deployment_id="abc123", file_path="text_data.csv",
+                      max_explanations="all", max_ngram_explanations="all",
+                      passthrough_columns="document_id,customer_id")
+     """
+     # Load input data from the dataset string or file_path
+     if dataset is not None:
+         # Try CSV first
+         try:
+             df = pd.read_csv(io.StringIO(dataset))
+         except Exception:
+             # Then try JSON
+             try:
+                 data = json.loads(dataset)
+                 df = pd.DataFrame(data)
+             except Exception as e:
+                 raise ValueError(f"Could not parse dataset string as CSV or JSON: {e}") from e
+     elif file_path is not None:
+         df = pd.read_csv(file_path)
+     else:
+         raise ValueError("Either file_path or dataset must be provided.")
+
+     if series_id_column and series_id_column not in df.columns:
+         raise ValueError(f"series_id_column '{series_id_column}' not found in input data.")
+
+     client = get_sdk_client()
+     deployment = client.Deployment.get(deployment_id=deployment_id)
+
+     # Check whether this is a time series prediction or a regular prediction
+     is_time_series = bool(forecast_point or (forecast_range_start and forecast_range_end))
+
+     # Start with the base prediction parameters
+     predict_kwargs = {
+         "deployment": deployment,
+         "data_frame": df,
+         "timeout": timeout,
+     }
+
+     # Add time series parameters if applicable
+     if is_time_series:
+         if forecast_point:
+             forecast_point_dt = pd.to_datetime(forecast_point)
+             predict_kwargs["time_series_type"] = TimeSeriesType.FORECAST
+             predict_kwargs["forecast_point"] = forecast_point_dt
+         elif forecast_range_start and forecast_range_end:
+             predictions_start_date_dt = pd.to_datetime(forecast_range_start)
+             predictions_end_date_dt = pd.to_datetime(forecast_range_end)
+             predict_kwargs["time_series_type"] = TimeSeriesType.HISTORICAL
+             predict_kwargs["predictions_start_date"] = predictions_start_date_dt
+             predict_kwargs["predictions_end_date"] = predictions_end_date_dt
+
+     # Add explanation parameters
+     if max_explanations != 0:
+         predict_kwargs["max_explanations"] = max_explanations
+     if max_ngram_explanations is not None:
+         predict_kwargs["max_ngram_explanations"] = max_ngram_explanations
+     if threshold_high is not None:
+         predict_kwargs["threshold_high"] = threshold_high
+     if threshold_low is not None:
+         predict_kwargs["threshold_low"] = threshold_low
+     if explanation_algorithm is not None:
+         predict_kwargs["explanation_algorithm"] = explanation_algorithm
+
+     # Add passthrough columns
+     if passthrough_columns is not None:
+         if passthrough_columns == "all":
+             predict_kwargs["passthrough_columns"] = "all"
+         else:
+             # Convert the comma-separated string to a set
+             columns_set = {col.strip() for col in passthrough_columns.split(",")}
+             predict_kwargs["passthrough_columns"] = columns_set
+
+     # Add a custom prediction endpoint
+     if prediction_endpoint is not None:
+         predict_kwargs["prediction_endpoint"] = prediction_endpoint
+
+     # Run the prediction
+     result = dr_predict(**predict_kwargs)
+     predictions = result.dataframe
+     bucket_info = make_output_settings()
+     return predictions_result_response(
+         predictions,
+         bucket_info.bucket,
+         bucket_info.key,
+         f"pred_{deployment_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+         max_explanations not in {0, "0"},
+     )
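
predict_realtime ultimately forwards these keyword arguments to datarobot_predict.deployment.predict, whose import and parameters this file already demonstrates. A minimal sketch of the forecast path it builds, calling the same library directly — the deployment ID and CSV path are hypothetical placeholders:

    import datarobot as dr
    import pandas as pd
    from datarobot_predict import TimeSeriesType
    from datarobot_predict.deployment import predict

    dr.Client()  # assumes DataRobot credentials in the environment
    deployment = dr.Deployment.get(deployment_id="YOUR_DEPLOYMENT_ID")  # hypothetical

    # History must cover the model's feature derivation window.
    df = pd.read_csv("ts_history.csv")  # hypothetical input file
    result = predict(
        deployment,
        data_frame=df,
        time_series_type=TimeSeriesType.FORECAST,
        forecast_point=pd.to_datetime("2024-06-01"),
        max_explanations=3,
        timeout=600,
    )
    print(result.dataframe.head())  # predictions plus FORECAST_POINT / FORECAST_DISTANCE columns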
datarobot_genai/drmcp/tools/predictive/project.py
@@ -0,0 +1,72 @@
+ # Copyright 2025 DataRobot, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+
+ from datarobot_genai.drmcp.core.clients import get_sdk_client
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+
+ logger = logging.getLogger(__name__)
+
+
+ @dr_mcp_tool(tags={"project", "management", "list"})
+ async def list_projects() -> str:
+     """
+     List all DataRobot projects for the authenticated user.
+
+     Returns
+     -------
+     A string summary of the user's DataRobot projects.
+     """
+     client = get_sdk_client()
+     projects = client.Project.list()
+     if not projects:
+         return "No projects found."
+     return "\n".join(f"{p.id}: {p.project_name}" for p in projects)
+
+
+ @dr_mcp_tool(tags={"project", "data", "info"})
+ async def get_project_dataset_by_name(project_id: str, dataset_name: str) -> str:
+     """
+     Get a dataset ID by name for a given project.
+
+     Args:
+         project_id: The ID of the DataRobot project.
+         dataset_name: The name of the dataset to find (e.g., 'training', 'holdout').
+
+     Returns
+     -------
+     The dataset ID and dataset type (source or prediction) as a string, or an error message.
+     """
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     all_datasets = []
+     source_dataset = project.get_dataset()
+     if source_dataset:
+         all_datasets.append({"type": "source", "dataset": source_dataset})
+     prediction_datasets = project.get_datasets()
+     if prediction_datasets:
+         all_datasets.extend([{"type": "prediction", "dataset": ds} for ds in prediction_datasets])
+     for ds in all_datasets:
+         if dataset_name.lower() in ds["dataset"].name.lower():
+             return json.dumps(
+                 {
+                     "dataset_id": ds["dataset"].id,
+                     "dataset_type": ds["type"],
+                     "ui_panel": ["dataset"],
+                 },
+                 indent=2,
+             )
+     return f"Dataset with name containing '{dataset_name}' not found in project {project_id}."