datarobot-genai 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +250 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +542 -0
- datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +436 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_filter.py +108 -0
- datarobot_genai/drmcp/core/utils.py +131 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
- datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +97 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +72 -0
- datarobot_genai/drmcp/tools/predictive/training.py +651 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +258 -0
- datarobot_genai/nat/datarobot_llm_clients.py +249 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.0.dist-info/METADATA +139 -0
- datarobot_genai-0.2.0.dist-info/RECORD +101 -0
- datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
- datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/training.py
@@ -0,0 +1,651 @@
# Copyright 2025 DataRobot, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for analyzing datasets and suggesting ML use cases."""

import json
import logging
from dataclasses import asdict
from dataclasses import dataclass

import pandas as pd

from datarobot_genai.drmcp.core.clients import get_sdk_client
from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool

logger = logging.getLogger(__name__)


@dataclass
class UseCaseSuggestion:
    """Represents a suggested use case based on dataset analysis."""

    name: str
    description: str
    suggested_target: str
    problem_type: str
    confidence: float
    reasoning: str


@dataclass
class DatasetInsight:
    """Contains insights about a dataset for use case discovery."""

    total_columns: int
    total_rows: int
    numerical_columns: list[str]
    categorical_columns: list[str]
    datetime_columns: list[str]
    text_columns: list[str]
    potential_targets: list[str]
    missing_data_summary: dict[str, float]

@dr_mcp_tool(tags={"training", "analysis", "dataset"})
async def analyze_dataset(dataset_id: str) -> str:
    """
    Analyze a dataset to understand its structure and potential use cases.

    Args:
        dataset_id: The ID of the DataRobot dataset to analyze

    Returns
    -------
    JSON string containing dataset insights including:
    - Basic statistics (rows, columns)
    - Column types (numerical, categorical, datetime, text)
    - Potential target columns
    - Missing data summary
    """
    client = get_sdk_client()
    dataset = client.Dataset.get(dataset_id)
    df = dataset.get_as_dataframe()

    # Analyze dataset structure
    numerical_cols = df.select_dtypes(include=["int64", "float64"]).columns.tolist()
    categorical_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
    datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()

    # Identify potential text columns (categorical columns with long average values).
    # Iterate over a copy: removing items from categorical_cols while iterating
    # over it directly would skip the element that follows each removal.
    text_cols = []
    for col in list(categorical_cols):
        if df[col].str.len().mean() > 20:  # Text detection
            text_cols.append(col)
            categorical_cols.remove(col)  # Reclassify from categorical to text

    # Calculate missing data
    missing_data = {}
    for col in df.columns:
        missing_pct = (df[col].isnull().sum() / len(df)) * 100
        if missing_pct > 0:
            missing_data[col] = missing_pct

    # Identify potential target columns
    potential_targets = _identify_potential_targets(df, numerical_cols, categorical_cols)

    insights = DatasetInsight(
        total_columns=len(df.columns),
        total_rows=len(df),
        numerical_columns=numerical_cols,
        categorical_columns=categorical_cols,
        datetime_columns=datetime_cols,
        text_columns=text_cols,
        potential_targets=potential_targets,
        missing_data_summary=missing_data,
    )

    return json.dumps(asdict(insights), indent=2)
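
# Illustrative usage sketch, not part of the packaged file: how the JSON
# returned by analyze_dataset can be consumed. Assumes DataRobot credentials
# are already configured for get_sdk_client(); the dataset ID is a placeholder.
async def _example_inspect_dataset() -> None:
    insights = json.loads(await analyze_dataset("<dataset-id>"))
    # Keys mirror the DatasetInsight dataclass fields.
    print(insights["potential_targets"])
    print(insights["missing_data_summary"])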

@dr_mcp_tool(tags={"training", "analysis", "usecase"})
async def suggest_use_cases(dataset_id: str) -> str:
    """
    Analyze a dataset and suggest potential machine learning use cases.

    Args:
        dataset_id: The ID of the DataRobot dataset to analyze

    Returns
    -------
    JSON string containing suggested use cases with:
    - Use case name and description
    - Suggested target column
    - Problem type
    - Confidence score
    - Reasoning for the suggestion
    """
    client = get_sdk_client()
    dataset = client.Dataset.get(dataset_id)
    df = dataset.get_as_dataframe()

    # Get dataset insights first
    insights_json = await analyze_dataset(dataset_id)
    insights = json.loads(insights_json)

    suggestions = []
    for target_col in insights["potential_targets"]:
        target_suggestions = _analyze_target_for_use_cases(df, target_col)
        suggestions.extend([asdict(s) for s in target_suggestions])

    # Sort by confidence score
    suggestions.sort(key=lambda x: x["confidence"], reverse=True)
    return json.dumps(suggestions, indent=2)

@dr_mcp_tool(tags={"training", "analysis", "eda"})
async def get_exploratory_insights(dataset_id: str, target_col: str | None = None) -> str:
    """
    Generate exploratory data insights for a dataset.

    Args:
        dataset_id: The ID of the DataRobot dataset to analyze
        target_col: Optional target column to focus EDA insights on

    Returns
    -------
    JSON string containing EDA insights including:
    - Dataset summary statistics
    - Target variable analysis (if specified)
    - Feature correlations with target
    - Missing data analysis
    - Data type distribution
    """
    client = get_sdk_client()
    dataset = client.Dataset.get(dataset_id)
    df = dataset.get_as_dataframe()

    # Get dataset insights first
    insights_json = await analyze_dataset(dataset_id)
    insights = json.loads(insights_json)

    eda_insights = {
        "dataset_summary": {
            "total_rows": int(insights["total_rows"]),  # Convert to native Python int
            "total_columns": int(insights["total_columns"]),  # Convert to native Python int
            "memory_usage": int(df.memory_usage().sum()),  # Convert to native Python int
        },
        "target_analysis": {},
        "feature_correlations": {},
        "missing_data": insights["missing_data_summary"],
        "data_types": {
            "numerical": insights["numerical_columns"],
            "categorical": insights["categorical_columns"],
            "datetime": insights["datetime_columns"],
            "text": insights["text_columns"],
        },
    }

    # Target-specific analysis
    if target_col and target_col in df.columns:
        target_data = df[target_col]
        target_analysis = {
            "column_name": target_col,
            "data_type": str(target_data.dtype),
            "unique_values": int(target_data.nunique()),  # Convert to native Python int
            "missing_values": int(target_data.isnull().sum()),  # Convert to native Python int
            "missing_percentage": float(target_data.isnull().sum() / len(df) * 100),
        }

        if pd.api.types.is_numeric_dtype(target_data):
            target_analysis.update(
                {
                    # Convert to native Python floats
                    "min_value": float(target_data.min()),
                    "max_value": float(target_data.max()),
                    "mean": float(target_data.mean()),
                    "median": float(target_data.median()),
                    "std_dev": float(target_data.std()),
                }
            )
        else:
            value_counts = target_data.value_counts()
            target_analysis.update(
                {
                    # Convert both key and value to native Python types
                    "value_counts": {str(k): int(v) for k, v in value_counts.items()},
                    "most_common": str(value_counts.index[0]) if len(value_counts) > 0 else None,
                }
            )

        eda_insights["target_analysis"] = target_analysis

        # Feature correlations with target (for numerical features)
        if pd.api.types.is_numeric_dtype(target_data):
            numerical_features = [
                col for col in insights["numerical_columns"] if col != target_col
            ]
            if numerical_features:
                correlations = {}
                for feature in numerical_features:
                    corr = df[feature].corr(target_data)
                    if not pd.isna(corr):
                        correlations[feature] = float(corr)  # Convert to native Python float

                # Sort by absolute correlation
                eda_insights["feature_correlations"] = dict(
                    sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True)
                )

    eda_insights["ui_panel"] = ["eda"]
    return json.dumps(eda_insights, indent=2)

def _identify_potential_targets(
    df: pd.DataFrame, numerical_cols: list[str], categorical_cols: list[str]
) -> list[str]:
    """Identify columns that could potentially be targets."""
    potential_targets = []

    # Look for common target column names
    target_keywords = [
        "target",
        "label",
        "class",
        "outcome",
        "result",
        "prediction",
        "predict",
        "sales",
        "revenue",
        "price",
        "amount",
        "value",
        "score",
        "rating",
        "churn",
        "conversion",
        "fraud",
        "default",
        "failure",
        "success",
        "risk",
        "probability",
        "likelihood",
        "status",
        "category",
        "type",
    ]

    for col in df.columns:
        col_lower = col.lower()

        # Check if column name contains target keywords
        if any(keyword in col_lower for keyword in target_keywords):
            potential_targets.append(col)
            continue

        # For numerical columns, check if they might be targets
        if col in numerical_cols:
            # Skip ID-like columns
            if "id" in col_lower or col_lower.endswith("_id"):
                continue

            # Check for bounded values (might be scores/ratings)
            if df[col].min() >= 0 and df[col].max() <= 100:
                potential_targets.append(col)
                continue

            # Check for binary-like numerical values
            unique_vals = df[col].nunique()
            if unique_vals == 2:
                potential_targets.append(col)
                continue

        # For categorical columns, check cardinality
        if col in categorical_cols:
            unique_vals = df[col].nunique()
            # Good targets have reasonable cardinality (2-20 classes typically)
            if 2 <= unique_vals <= 20:
                potential_targets.append(col)

    return potential_targets
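
# Illustrative sketch, not part of the packaged file: the heuristic above
# applied to a tiny in-memory frame. All column names here are invented.
def _example_target_heuristic() -> None:
    frame = pd.DataFrame(
        {
            "customer_id": [1, 2, 3],  # skipped: ID-like numerical column
            "churn": [0, 1, 0],  # picked up by the keyword check
            "monthly_spend": [120.5, 80.0, 310.2],  # unbounded with >2 values: skipped
            "plan": ["basic", "pro", "basic"],  # low-cardinality categorical: kept
        }
    )
    targets = _identify_potential_targets(
        frame,
        numerical_cols=["customer_id", "churn", "monthly_spend"],
        categorical_cols=["plan"],
    )
    print(targets)  # ["churn", "plan"]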

def _analyze_target_for_use_cases(df: pd.DataFrame, target_col: str) -> list[UseCaseSuggestion]:
    """Analyze a specific target column and suggest use cases."""
    suggestions: list[UseCaseSuggestion] = []

    target_data = df[target_col].dropna()
    if len(target_data) == 0:
        return suggestions

    # Determine if it's numerical or categorical
    if pd.api.types.is_numeric_dtype(target_data):
        unique_count = target_data.nunique()

        if unique_count == 2:
            # Binary classification
            values = sorted(target_data.unique())
            suggestions.append(
                UseCaseSuggestion(
                    name="Binary Classification",
                    description=f"Predict whether {target_col} will be {values[0]} or {values[1]}",
                    suggested_target=target_col,
                    problem_type="Binary Classification",
                    confidence=0.8,
                    reasoning=f"Column {target_col} has exactly 2 unique values, suggesting binary "
                    "classification",
                )
            )
        elif unique_count <= 10:
            # Multiclass classification
            suggestions.append(
                UseCaseSuggestion(
                    name="Multiclass Classification",
                    description=f"Classify {target_col} into {unique_count} categories",
                    suggested_target=target_col,
                    problem_type="Multiclass Classification",
                    confidence=0.7,
                    reasoning=f"Column {target_col} has {unique_count} discrete values, suggesting "
                    "classification",
                )
            )

        # Always suggest regression for numeric columns with more than 2 unique values
        if unique_count > 2:
            suggestions.append(
                UseCaseSuggestion(
                    name="Regression Modeling",
                    description=f"Predict the value of {target_col}",
                    suggested_target=target_col,
                    problem_type="Regression",
                    # Higher confidence for columns with more unique values
                    confidence=0.6 + (0.1 if unique_count > 10 else 0),
                    reasoning=(
                        f"Column {target_col} is numerical with {unique_count} unique values, "
                        "suggesting regression"
                    ),
                )
            )

    else:
        # Categorical target
        unique_count = target_data.nunique()

        if unique_count == 2:
            suggestions.append(
                UseCaseSuggestion(
                    name="Binary Classification",
                    description=f"Predict the category of {target_col}",
                    suggested_target=target_col,
                    problem_type="Binary Classification",
                    confidence=0.9,
                    reasoning=f"Column {target_col} is categorical with 2 classes",
                )
            )
        elif unique_count <= 20:
            suggestions.append(
                UseCaseSuggestion(
                    name="Multiclass Classification",
                    description=f"Classify {target_col} into {unique_count} categories",
                    suggested_target=target_col,
                    problem_type="Multiclass Classification",
                    confidence=0.8,
                    reasoning=f"Column {target_col} is categorical with {unique_count} manageable "
                    "classes",
                )
            )

    # Add specific use case suggestions based on column names
    col_lower = target_col.lower()
    if "sales" in col_lower or "revenue" in col_lower:
        suggestions.append(
            UseCaseSuggestion(
                name="Sales Forecasting",
                description=f"Forecast future {target_col} values",
                suggested_target=target_col,
                problem_type="Regression",
                confidence=0.9,
                reasoning="Sales/revenue data is ideal for forecasting models",
            )
        )
    elif "churn" in col_lower:
        suggestions.append(
            UseCaseSuggestion(
                name="Customer Churn Prediction",
                description="Predict which customers are likely to churn",
                suggested_target=target_col,
                problem_type="Binary Classification",
                confidence=0.95,
                reasoning="Churn prediction is a classic binary classification problem",
            )
        )
    elif "fraud" in col_lower:
        suggestions.append(
            UseCaseSuggestion(
                name="Fraud Detection",
                description="Detect fraudulent transactions or activities",
                suggested_target=target_col,
                problem_type="Binary Classification",
                confidence=0.95,
                reasoning="Fraud detection is typically a binary classification problem",
            )
        )
    elif "price" in col_lower or "cost" in col_lower:
        suggestions.append(
            UseCaseSuggestion(
                name="Price Prediction",
                description=f"Predict optimal {target_col}",
                suggested_target=target_col,
                problem_type="Regression",
                confidence=0.85,
                reasoning="Price prediction is a common regression use case",
            )
        )

    return suggestions
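
# Illustrative sketch, not part of the packaged file: what the analysis above
# yields for a binary "churn" column; the generic binary-classification
# suggestion and the churn-specific suggestion are both returned.
def _example_churn_suggestions() -> None:
    frame = pd.DataFrame({"churn": ["yes", "no", "yes", "no"]})
    for suggestion in _analyze_target_for_use_cases(frame, "churn"):
        print(suggestion.name, suggestion.confidence)
    # Binary Classification 0.9
    # Customer Churn Prediction 0.95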

@dr_mcp_tool(tags={"training", "autopilot", "model"})
async def start_autopilot(
    target: str,
    project_id: str | None = None,
    mode: str | None = "quick",
    dataset_url: str | None = None,
    dataset_id: str | None = None,
    project_name: str | None = "MCP Project",
    use_case_id: str | None = None,
) -> str:
    """
    Start automated model training (Autopilot) for a project.

    Args:
        target: Name of the target column for modeling.
        project_id: Optional, the ID of the DataRobot project; a new project is created
            if no ID is provided.
        mode: Optional, Autopilot mode ('quick', 'comprehensive', or 'manual').
        dataset_url: Optional, the URL of the dataset to upload for a new project
            (not needed if dataset_id is provided).
        dataset_id: Optional, the ID of an existing dataset in the AI Catalog for a new
            project (not needed if dataset_url is provided).
        project_name: Optional, name for the new project created when no project_id is
            provided.
        use_case_id: Optional, ID of the use case to associate this project with (required
            for the next-gen platform).

    Returns
    -------
    JSON string containing:
    - project_id: Project ID
    - target: Target column name
    - mode: Selected Autopilot mode
    - status: Current project status
    - ui_panel: List of recommended UI panels for visualization
    """
    client = get_sdk_client()

    if not project_id:
        if not dataset_url and not dataset_id:
            return "Error: Either dataset_url or dataset_id must be provided"
        if dataset_url and dataset_id:
            return "Error: Please provide either dataset_url or dataset_id, not both"

        if dataset_url:
            dataset = client.Dataset.create_from_url(dataset_url)
        else:
            dataset = client.Dataset.get(dataset_id)

        project = client.Project.create_from_dataset(
            dataset.id, project_name=project_name, use_case=use_case_id
        )
    else:
        project = client.Project.get(project_id)

    if not target:
        return "Error: Target variable must be specified"

    try:
        # Start modeling
        project.analyze_and_model(target=target, mode=mode)

        result = {
            "project_id": project.id,
            "target": target,
            "mode": mode,
            "status": project.get_status(),
            "ui_panel": ["eda", "model-training", "leaderboard"],
            "use_case_id": project.use_case_id,
        }

        return json.dumps(result, indent=2)
    except Exception as e:
        return json.dumps(
            {
                "error": f"Failed to start Autopilot: {str(e)}",
                "project_id": project.id,
                "target": target,
                "mode": mode,
            },
            indent=2,
        )
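
# Illustrative sketch, not part of the packaged file: starting Autopilot
# against an existing AI Catalog dataset. The dataset ID and target are
# placeholders; credentials are assumed to be configured for get_sdk_client().
async def _example_start_autopilot() -> None:
    result = json.loads(
        await start_autopilot(
            target="churn",
            dataset_id="<dataset-id>",
            project_name="Churn demo",
            mode="quick",
        )
    )
    print(result["project_id"], result["status"])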

@dr_mcp_tool(tags={"training", "model", "evaluation"})
async def get_model_roc_curve(project_id: str, model_id: str, source: str = "validation") -> str:
    """
    Get detailed ROC curve for a specific model.

    Args:
        project_id: The ID of the DataRobot project
        model_id: The ID of the model to analyze
        source: The source of the data to use for the ROC curve ('validation', 'holdout', or
            'crossValidation')

    Returns
    -------
    JSON string containing:
    - roc_curve: ROC curve data
    - ui_panel: List of recommended UI panels for visualization
    """
    client = get_sdk_client()
    project = client.Project.get(project_id)
    model = client.Model.get(project=project, model_id=model_id)

    try:
        roc_curve = model.get_roc_curve(source=source)
        roc_data = {
            "roc_points": [
                {
                    "accuracy": point.get("accuracy", 0),
                    "f1_score": point.get("f1_score", 0),
                    "false_negative_score": point.get("false_negative_score", 0),
                    "true_negative_score": point.get("true_negative_score", 0),
                    "true_negative_rate": point.get("true_negative_rate", 0),
                    "matthews_correlation_coefficient": point.get(
                        "matthews_correlation_coefficient", 0
                    ),
                    "true_positive_score": point.get("true_positive_score", 0),
                    "positive_predictive_value": point.get("positive_predictive_value", 0),
                    "false_positive_score": point.get("false_positive_score", 0),
                    "false_positive_rate": point.get("false_positive_rate", 0),
                    "negative_predictive_value": point.get("negative_predictive_value", 0),
                    "true_positive_rate": point.get("true_positive_rate", 0),
                    "threshold": point.get("threshold", 0),
                }
                for point in roc_curve.roc_points
            ],
            "negative_class_predictions": roc_curve.negative_class_predictions,
            "positive_class_predictions": roc_curve.positive_class_predictions,
            "source": source,
        }

        return json.dumps({"data": roc_data, "ui_panel": ["roc-curve"]}, indent=2)
    except Exception as e:
        return json.dumps({"error": f"Failed to get ROC curve: {str(e)}"}, indent=2)

@dr_mcp_tool(tags={"training", "model", "evaluation"})
async def get_model_feature_impact(project_id: str, model_id: str) -> str:
    """
    Get detailed feature impact for a specific model.

    Args:
        project_id: The ID of the DataRobot project
        model_id: The ID of the model to analyze

    Returns
    -------
    JSON string containing:
    - feature_impact: Feature importance scores
    - ui_panel: List of recommended UI panels for visualization
    """
    client = get_sdk_client()
    project = client.Project.get(project_id)
    model = client.Model.get(project=project, model_id=model_id)
    # Get feature impact. get_or_request_feature_impact() starts the job only
    # when needed, so a separate request_feature_impact() call (which raises if
    # the job was already requested) is not required first.
    feature_impact = model.get_or_request_feature_impact()

    return json.dumps({"data": feature_impact, "ui_panel": ["feature-impact"]}, indent=2)

@dr_mcp_tool(tags={"training", "model", "evaluation"})
async def get_model_lift_chart(project_id: str, model_id: str, source: str = "validation") -> str:
    """
    Get detailed lift chart for a specific model.

    Args:
        project_id: The ID of the DataRobot project
        model_id: The ID of the model to analyze
        source: The source of the data to use for the lift chart ('validation', 'holdout', or
            'crossValidation')

    Returns
    -------
    JSON string containing:
    - lift_chart: Lift chart data
    - ui_panel: List of recommended UI panels for visualization
    """
    client = get_sdk_client()
    project = client.Project.get(project_id)
    model = client.Model.get(project=project, model_id=model_id)

    # Get lift chart
    lift_chart = model.get_lift_chart(source=source)

    lift_chart_data = {
        "bins": [
            {
                "actual": bin["actual"],
                "predicted": bin["predicted"],
                "bin_weight": bin["bin_weight"],
            }
            for bin in lift_chart.bins
        ],
        "source_model_id": lift_chart.source_model_id,
        "target_class": lift_chart.target_class,
    }

    return json.dumps({"data": lift_chart_data, "ui_panel": ["lift-chart"]}, indent=2)
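
# Illustrative sketch, not part of the packaged file: pulling evaluation
# artifacts once Autopilot has produced models. The IDs are placeholders and
# the project is assumed to be a completed binary-classification project.
async def _example_model_evaluation() -> None:
    roc = json.loads(await get_model_roc_curve("<project-id>", "<model-id>", source="holdout"))
    impact = json.loads(await get_model_feature_impact("<project-id>", "<model-id>"))
    print(len(roc["data"]["roc_points"]), "ROC points")
    print(impact["ui_panel"])  # ["feature-impact"]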