datarobot-genai 0.2.31__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as published.
- datarobot_genai/__init__.py +19 -0
- datarobot_genai/core/__init__.py +0 -0
- datarobot_genai/core/agents/__init__.py +43 -0
- datarobot_genai/core/agents/base.py +195 -0
- datarobot_genai/core/chat/__init__.py +19 -0
- datarobot_genai/core/chat/auth.py +146 -0
- datarobot_genai/core/chat/client.py +178 -0
- datarobot_genai/core/chat/responses.py +297 -0
- datarobot_genai/core/cli/__init__.py +18 -0
- datarobot_genai/core/cli/agent_environment.py +47 -0
- datarobot_genai/core/cli/agent_kernel.py +211 -0
- datarobot_genai/core/custom_model.py +141 -0
- datarobot_genai/core/mcp/__init__.py +0 -0
- datarobot_genai/core/mcp/common.py +218 -0
- datarobot_genai/core/telemetry_agent.py +126 -0
- datarobot_genai/core/utils/__init__.py +3 -0
- datarobot_genai/core/utils/auth.py +234 -0
- datarobot_genai/core/utils/urls.py +64 -0
- datarobot_genai/crewai/__init__.py +24 -0
- datarobot_genai/crewai/agent.py +42 -0
- datarobot_genai/crewai/base.py +159 -0
- datarobot_genai/crewai/events.py +117 -0
- datarobot_genai/crewai/mcp.py +59 -0
- datarobot_genai/drmcp/__init__.py +78 -0
- datarobot_genai/drmcp/core/__init__.py +13 -0
- datarobot_genai/drmcp/core/auth.py +165 -0
- datarobot_genai/drmcp/core/clients.py +180 -0
- datarobot_genai/drmcp/core/config.py +364 -0
- datarobot_genai/drmcp/core/config_utils.py +174 -0
- datarobot_genai/drmcp/core/constants.py +18 -0
- datarobot_genai/drmcp/core/credentials.py +190 -0
- datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
- datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
- datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
- datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
- datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
- datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
- datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
- datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
- datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
- datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
- datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
- datarobot_genai/drmcp/core/exceptions.py +25 -0
- datarobot_genai/drmcp/core/logging.py +98 -0
- datarobot_genai/drmcp/core/mcp_instance.py +515 -0
- datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
- datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
- datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
- datarobot_genai/drmcp/core/routes.py +439 -0
- datarobot_genai/drmcp/core/routes_utils.py +30 -0
- datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
- datarobot_genai/drmcp/core/telemetry.py +424 -0
- datarobot_genai/drmcp/core/tool_config.py +111 -0
- datarobot_genai/drmcp/core/tool_filter.py +117 -0
- datarobot_genai/drmcp/core/utils.py +138 -0
- datarobot_genai/drmcp/server.py +19 -0
- datarobot_genai/drmcp/test_utils/__init__.py +13 -0
- datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
- datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
- datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
- datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
- datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
- datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
- datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
- datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
- datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
- datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
- datarobot_genai/drmcp/test_utils/utils.py +91 -0
- datarobot_genai/drmcp/tools/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
- datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
- datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
- datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
- datarobot_genai/drmcp/tools/clients/jira.py +334 -0
- datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
- datarobot_genai/drmcp/tools/clients/s3.py +28 -0
- datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
- datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
- datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
- datarobot_genai/drmcp/tools/jira/tools.py +243 -0
- datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
- datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
- datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
- datarobot_genai/drmcp/tools/predictive/data.py +133 -0
- datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
- datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
- datarobot_genai/drmcp/tools/predictive/model.py +148 -0
- datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
- datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
- datarobot_genai/drmcp/tools/predictive/project.py +90 -0
- datarobot_genai/drmcp/tools/predictive/training.py +661 -0
- datarobot_genai/langgraph/__init__.py +0 -0
- datarobot_genai/langgraph/agent.py +341 -0
- datarobot_genai/langgraph/mcp.py +73 -0
- datarobot_genai/llama_index/__init__.py +16 -0
- datarobot_genai/llama_index/agent.py +50 -0
- datarobot_genai/llama_index/base.py +299 -0
- datarobot_genai/llama_index/mcp.py +79 -0
- datarobot_genai/nat/__init__.py +0 -0
- datarobot_genai/nat/agent.py +275 -0
- datarobot_genai/nat/datarobot_auth_provider.py +110 -0
- datarobot_genai/nat/datarobot_llm_clients.py +318 -0
- datarobot_genai/nat/datarobot_llm_providers.py +130 -0
- datarobot_genai/nat/datarobot_mcp_client.py +266 -0
- datarobot_genai/nat/helpers.py +87 -0
- datarobot_genai/py.typed +0 -0
- datarobot_genai-0.2.31.dist-info/METADATA +145 -0
- datarobot_genai-0.2.31.dist-info/RECORD +125 -0
- datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
- datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
- datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
- datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/training.py
@@ -0,0 +1,661 @@
+# Copyright 2025 DataRobot, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for analyzing datasets and suggesting ML use cases."""
+
+import json
+import logging
+from dataclasses import asdict
+from dataclasses import dataclass
+from typing import Annotated
+
+import pandas as pd
+from fastmcp.exceptions import ToolError
+from fastmcp.tools.tool import ToolResult
+
+from datarobot_genai.drmcp.core.clients import get_sdk_client
+from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UseCaseSuggestion:
+    """Represents a suggested use case based on dataset analysis."""
+
+    name: str
+    description: str
+    suggested_target: str
+    problem_type: str
+    confidence: float
+    reasoning: str
+
+
+@dataclass
+class DatasetInsight:
+    """Contains insights about a dataset for use case discovery."""
+
+    total_columns: int
+    total_rows: int
+    numerical_columns: list[str]
+    categorical_columns: list[str]
+    datetime_columns: list[str]
+    text_columns: list[str]
+    potential_targets: list[str]
+    missing_data_summary: dict[str, float]
+
+
+@dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "dataset"})
+async def analyze_dataset(
+    *,
+    dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+) -> ToolError | ToolResult:
+    """Analyze a dataset to understand its structure and potential use cases."""
+    if not dataset_id:
+        return ToolError("Dataset ID must be provided")
+
+    client = get_sdk_client()
+    dataset = client.Dataset.get(dataset_id)
+    df = dataset.get_as_dataframe()
+
+    # Analyze dataset structure
+    numerical_cols = df.select_dtypes(include=["int64", "float64"]).columns.tolist()
+    categorical_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
+    datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
+
+    # Identify potential text columns (categorical columns with long average values).
+    # Iterate over a copy so that removing items does not skip elements of the
+    # list being mutated.
+    text_cols = []
+    for col in list(categorical_cols):
+        if df[col].str.len().mean() > 20:  # Text detection
+            text_cols.append(col)
+            categorical_cols.remove(col)  # Remove from categorical columns
+
+    # Calculate missing data
+    missing_data = {}
+    for col in df.columns:
+        missing_pct = (df[col].isnull().sum() / len(df)) * 100
+        if missing_pct > 0:
+            missing_data[col] = missing_pct
+
+    # Identify potential target columns
+    potential_targets = _identify_potential_targets(df, numerical_cols, categorical_cols)
+
+    insights = DatasetInsight(
+        total_columns=len(df.columns),
+        total_rows=len(df),
+        numerical_columns=numerical_cols,
+        categorical_columns=categorical_cols,
+        datetime_columns=datetime_cols,
+        text_columns=text_cols,
+        potential_targets=potential_targets,
+        missing_data_summary=missing_data,
+    )
+    insights_dict = asdict(insights)
+
+    return ToolResult(
+        content=json.dumps(insights_dict, indent=2),
+        structured_content=insights_dict,
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "usecase"})
+async def suggest_use_cases(
+    *,
+    dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+) -> ToolError | ToolResult:
+    """Analyze a dataset and suggest potential machine learning use cases."""
+    if not dataset_id:
+        return ToolError("Dataset ID must be provided")
+
+    client = get_sdk_client()
+    dataset = client.Dataset.get(dataset_id)
+    df = dataset.get_as_dataframe()
+
+    # Get dataset insights first. analyze_dataset takes keyword-only arguments
+    # and returns a ToolResult, so read the parsed payload from
+    # structured_content instead of json.loads-ing the result object.
+    insights_result = await analyze_dataset(dataset_id=dataset_id)
+    if isinstance(insights_result, ToolError):
+        return insights_result
+    insights = insights_result.structured_content
+
+    suggestions = []
+    for target_col in insights["potential_targets"]:
+        target_suggestions = _analyze_target_for_use_cases(df, target_col)
+        suggestions.extend([asdict(s) for s in target_suggestions])
+
+    # Sort by confidence score
+    suggestions.sort(key=lambda x: x["confidence"], reverse=True)
+
+    return ToolResult(
+        content=json.dumps(suggestions, indent=2),
+        structured_content={"use_case_suggestions": suggestions},
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "eda"})
+async def get_exploratory_insights(
+    *,
+    dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+    target_col: Annotated[str, "Optional target column to focus EDA insights on"] | None = None,
+) -> ToolError | ToolResult:
+    """Generate exploratory data insights for a dataset."""
+    if not dataset_id:
+        return ToolError("Dataset ID must be provided")
+
+    client = get_sdk_client()
+    dataset = client.Dataset.get(dataset_id)
+    df = dataset.get_as_dataframe()
+
+    # Get dataset insights first (keyword-only call; the payload lives in
+    # structured_content, as in suggest_use_cases above).
+    insights_result = await analyze_dataset(dataset_id=dataset_id)
+    if isinstance(insights_result, ToolError):
+        return insights_result
+    insights = insights_result.structured_content
+
+    eda_insights = {
+        "dataset_summary": {
+            "total_rows": int(insights["total_rows"]),  # Convert to native Python int
+            "total_columns": int(insights["total_columns"]),  # Convert to native Python int
+            "memory_usage": int(df.memory_usage().sum()),  # Convert to native Python int
+        },
+        "target_analysis": {},
+        "feature_correlations": {},
+        "missing_data": insights["missing_data_summary"],
+        "data_types": {
+            "numerical": insights["numerical_columns"],
+            "categorical": insights["categorical_columns"],
+            "datetime": insights["datetime_columns"],
+            "text": insights["text_columns"],
+        },
+    }
+
+    # Target-specific analysis
+    if target_col and target_col in df.columns:
+        target_data = df[target_col]
+        target_analysis = {
+            "column_name": target_col,
+            "data_type": str(target_data.dtype),
+            "unique_values": int(target_data.nunique()),  # Convert to native Python int
+            "missing_values": int(target_data.isnull().sum()),  # Convert to native Python int
+            "missing_percentage": float(target_data.isnull().sum() / len(df) * 100),
+        }
+
+        if pd.api.types.is_numeric_dtype(target_data):
+            target_analysis.update(
+                {
+                    "min_value": float(target_data.min()),  # Convert to native Python float
+                    "max_value": float(target_data.max()),  # Convert to native Python float
+                    "mean": float(target_data.mean()),  # Convert to native Python float
+                    "median": float(target_data.median()),  # Convert to native Python float
+                    "std_dev": float(target_data.std()),  # Convert to native Python float
+                }
+            )
+        else:
+            value_counts = target_data.value_counts()
+            target_analysis.update(
+                {
+                    "value_counts": {
+                        str(k): int(v) for k, v in value_counts.items()
+                    },  # Convert both key and value
+                    "most_common": str(value_counts.index[0]) if len(value_counts) > 0 else None,
+                }
+            )
+
+        eda_insights["target_analysis"] = target_analysis
+
+        # Feature correlations with target (for numerical features)
+        if pd.api.types.is_numeric_dtype(target_data):
+            numerical_features = [col for col in insights["numerical_columns"] if col != target_col]
+            if numerical_features:
+                correlations = {}
+                for feature in numerical_features:
+                    corr = df[feature].corr(target_data)
+                    if not pd.isna(corr):
+                        correlations[feature] = float(corr)  # Convert to native Python float
+
+                # Sort by absolute correlation
+                eda_insights["feature_correlations"] = dict(
+                    sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True)
+                )
+
+    return ToolResult(
+        content=json.dumps(eda_insights, indent=2),
+        structured_content=eda_insights,
+    )
+
+
+def _identify_potential_targets(
+    df: pd.DataFrame, numerical_cols: list[str], categorical_cols: list[str]
+) -> list[str]:
+    """Identify columns that could potentially be targets."""
+    potential_targets = []
+
+    # Look for common target column names
+    target_keywords = [
+        "target",
+        "label",
+        "class",
+        "outcome",
+        "result",
+        "prediction",
+        "predict",
+        "sales",
+        "revenue",
+        "price",
+        "amount",
+        "value",
+        "score",
+        "rating",
+        "churn",
+        "conversion",
+        "fraud",
+        "default",
+        "failure",
+        "success",
+        "risk",
+        "probability",
+        "likelihood",
+        "status",
+        "category",
+        "type",
+    ]
+
+    for col in df.columns:
+        col_lower = col.lower()
+
+        # Check if column name contains target keywords
+        if any(keyword in col_lower for keyword in target_keywords):
+            potential_targets.append(col)
+            continue
+
+        # For numerical columns, check if they might be targets
+        if col in numerical_cols:
+            # Skip ID-like columns
+            if "id" in col_lower or col_lower.endswith("_id"):
+                continue
+
+            # Check for bounded values (might be scores/ratings)
+            if df[col].min() >= 0 and df[col].max() <= 100:
+                potential_targets.append(col)
+                continue
+
+            # Check for binary-like numerical values
+            unique_vals = df[col].nunique()
+            if unique_vals == 2:
+                potential_targets.append(col)
+                continue
+
+        # For categorical columns, check cardinality
+        if col in categorical_cols:
+            unique_vals = df[col].nunique()
+            # Good targets have reasonable cardinality (2-20 classes typically)
+            if 2 <= unique_vals <= 20:
+                potential_targets.append(col)
+
+    return potential_targets
+
+
+def _analyze_target_for_use_cases(df: pd.DataFrame, target_col: str) -> list[UseCaseSuggestion]:
+    """Analyze a specific target column and suggest use cases."""
+    suggestions: list[UseCaseSuggestion] = []
+
+    target_data = df[target_col].dropna()
+    if len(target_data) == 0:
+        return suggestions
+
+    # Determine if it's numerical or categorical
+    if pd.api.types.is_numeric_dtype(target_data):
+        unique_count = target_data.nunique()
+
+        if unique_count == 2:
+            # Binary classification
+            values = sorted(target_data.unique())
+            suggestions.append(
+                UseCaseSuggestion(
+                    name="Binary Classification",
+                    description=f"Predict whether {target_col} will be {values[0]} or {values[1]}",
+                    suggested_target=target_col,
+                    problem_type="Binary Classification",
+                    confidence=0.8,
+                    reasoning=f"Column {target_col} has exactly 2 unique values, suggesting "
+                    "binary classification",
+                )
+            )
+        elif unique_count <= 10:
+            # Multiclass classification
+            suggestions.append(
+                UseCaseSuggestion(
+                    name="Multiclass Classification",
+                    description=f"Classify {target_col} into {unique_count} categories",
+                    suggested_target=target_col,
+                    problem_type="Multiclass Classification",
+                    confidence=0.7,
+                    reasoning=f"Column {target_col} has {unique_count} discrete values, "
+                    "suggesting classification",
+                )
+            )
+
+        # Always suggest regression for numeric columns with more than 2 unique values
+        if unique_count > 2:
+            suggestions.append(
+                UseCaseSuggestion(
+                    name="Regression Modeling",
+                    description=f"Predict the value of {target_col}",
+                    suggested_target=target_col,
+                    problem_type="Regression",
+                    # Higher confidence for columns with more unique values
+                    confidence=0.6 + (0.1 if unique_count > 10 else 0),
+                    reasoning=(
+                        f"Column {target_col} is numerical with {unique_count} unique values, "
+                        "suggesting regression"
+                    ),
+                )
+            )
+
+    else:
+        # Categorical target
+        unique_count = target_data.nunique()
+
+        if unique_count == 2:
+            suggestions.append(
+                UseCaseSuggestion(
+                    name="Binary Classification",
+                    description=f"Predict the category of {target_col}",
+                    suggested_target=target_col,
+                    problem_type="Binary Classification",
+                    confidence=0.9,
+                    reasoning=f"Column {target_col} is categorical with 2 classes",
+                )
+            )
+        elif unique_count <= 20:
+            suggestions.append(
+                UseCaseSuggestion(
+                    name="Multiclass Classification",
+                    description=f"Classify {target_col} into {unique_count} categories",
+                    suggested_target=target_col,
+                    problem_type="Multiclass Classification",
+                    confidence=0.8,
+                    reasoning=f"Column {target_col} is categorical with {unique_count} "
+                    "manageable classes",
+                )
+            )
+
+    # Add specific use case suggestions based on column names
+    col_lower = target_col.lower()
+    if "sales" in col_lower or "revenue" in col_lower:
+        suggestions.append(
+            UseCaseSuggestion(
+                name="Sales Forecasting",
+                description=f"Forecast future {target_col} values",
+                suggested_target=target_col,
+                problem_type="Regression",
+                confidence=0.9,
+                reasoning="Sales/revenue data is ideal for forecasting models",
+            )
+        )
+    elif "churn" in col_lower:
+        suggestions.append(
+            UseCaseSuggestion(
+                name="Customer Churn Prediction",
+                description="Predict which customers are likely to churn",
+                suggested_target=target_col,
+                problem_type="Binary Classification",
+                confidence=0.95,
+                reasoning="Churn prediction is a classic binary classification problem",
+            )
+        )
+    elif "fraud" in col_lower:
+        suggestions.append(
+            UseCaseSuggestion(
+                name="Fraud Detection",
+                description="Detect fraudulent transactions or activities",
+                suggested_target=target_col,
+                problem_type="Binary Classification",
+                confidence=0.95,
+                reasoning="Fraud detection is typically a binary classification problem",
+            )
+        )
+    elif "price" in col_lower or "cost" in col_lower:
+        suggestions.append(
+            UseCaseSuggestion(
+                name="Price Prediction",
+                description=f"Predict optimal {target_col}",
+                suggested_target=target_col,
+                problem_type="Regression",
+                confidence=0.85,
+                reasoning="Price prediction is a common regression use case",
+            )
+        )
+
+    return suggestions
+
+
+@dr_mcp_tool(tags={"predictive", "training", "write", "autopilot", "model"})
+async def start_autopilot(
+    *,
+    target: Annotated[str, "Name of the target column for modeling"] | None = None,
+    project_id: Annotated[
+        str,
+        "Optional, the ID of an existing DataRobot project; "
+        "a new project is created if no ID is provided",
+    ]
+    | None = None,
+    mode: Annotated[str, "Optional, Autopilot mode ('quick', 'comprehensive', or 'manual')"]
+    | None = "quick",
+    dataset_url: Annotated[
+        str,
+        "Optional, the URL of the dataset to upload for a new project "
+        "(optional if dataset_id is provided)",
+    ]
+    | None = None,
+    dataset_id: Annotated[
+        str,
+        "Optional, the ID of an existing dataset in the AI Catalog for a new project "
+        "(optional if dataset_url is provided)",
+    ]
+    | None = None,
+    project_name: Annotated[
+        str,
+        "Optional, name for the new project created when no project ID is provided",
+    ]
+    | None = "MCP Project",
+    use_case_id: Annotated[
+        str,
+        "Optional, ID of the use case to associate this project with "
+        "(required for the next-gen platform)",
+    ]
+    | None = None,
+) -> ToolError | ToolResult:
+    """Start automated model training (Autopilot) for a project."""
+    client = get_sdk_client()
+
+    if not project_id:
+        if not dataset_url and not dataset_id:
+            return ToolError("Either dataset_url or dataset_id must be provided")
+        if dataset_url and dataset_id:
+            return ToolError("Please provide either dataset_url or dataset_id, not both")
+
+        if dataset_url:
+            dataset = client.Dataset.create_from_url(dataset_url)
+        else:
+            dataset = client.Dataset.get(dataset_id)
+
+        project = client.Project.create_from_dataset(
+            dataset.id, project_name=project_name, use_case=use_case_id
+        )
+    else:
+        project = client.Project.get(project_id)
+
+    if not target:
+        return ToolError("Target variable must be specified")
+
+    try:
+        # Start modeling
+        project.analyze_and_model(target=target, mode=mode)
+
+        result = {
+            "project_id": project.id,
+            "target": target,
+            "mode": mode,
+            "status": project.get_status(),
+            "use_case_id": project.use_case_id,
+        }
+
+        return ToolResult(
+            content=json.dumps(result, indent=2),
+            structured_content=result,
+        )
+
+    except Exception as e:
+        # ToolError takes a plain message, not a content= keyword
+        return ToolError(
+            json.dumps(
+                {
+                    "error": f"Failed to start Autopilot: {str(e)}",
+                    "project_id": project.id if project else None,
+                    "target": target,
+                    "mode": mode,
+                },
+                indent=2,
+            )
+        )
+
+
+@dr_mcp_tool(tags={"prediction", "training", "read", "model", "evaluation"})
+async def get_model_roc_curve(
+    *,
+    project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+    model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+    source: Annotated[
+        str,
+        "The source of the data to use for the ROC curve "
+        "('validation', 'holdout', or 'crossValidation')",
+    ] = "validation",
+) -> ToolError | ToolResult:
+    """Get detailed ROC curve for a specific model."""
+    if not project_id:
+        return ToolError("Project ID must be provided")
+    if not model_id:
+        return ToolError("Model ID must be provided")
+
+    client = get_sdk_client()
+    project = client.Project.get(project_id)
+    model = client.Model.get(project=project, model_id=model_id)
+
+    try:
+        roc_curve = model.get_roc_curve(source=source)
+        roc_data = {
+            "roc_points": [
+                {
+                    "accuracy": point.get("accuracy", 0),
+                    "f1_score": point.get("f1_score", 0),
+                    "false_negative_score": point.get("false_negative_score", 0),
+                    "true_negative_score": point.get("true_negative_score", 0),
+                    "true_negative_rate": point.get("true_negative_rate", 0),
+                    "matthews_correlation_coefficient": point.get(
+                        "matthews_correlation_coefficient", 0
+                    ),
+                    "true_positive_score": point.get("true_positive_score", 0),
+                    "positive_predictive_value": point.get("positive_predictive_value", 0),
+                    "false_positive_score": point.get("false_positive_score", 0),
+                    "false_positive_rate": point.get("false_positive_rate", 0),
+                    "negative_predictive_value": point.get("negative_predictive_value", 0),
+                    "true_positive_rate": point.get("true_positive_rate", 0),
+                    "threshold": point.get("threshold", 0),
+                }
+                for point in roc_curve.roc_points
+            ],
+            "negative_class_predictions": roc_curve.negative_class_predictions,
+            "positive_class_predictions": roc_curve.positive_class_predictions,
+            "source": source,
+        }
+
+        return ToolResult(
+            content=json.dumps({"data": roc_data}, indent=2),
+            structured_content={"data": roc_data},
+        )
+    except Exception as e:
+        return ToolError(f"Failed to get ROC curve: {str(e)}")
+
+
+@dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
+async def get_model_feature_impact(
+    *,
+    project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+    model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+) -> ToolError | ToolResult:
+    """Get detailed feature impact for a specific model."""
+    if not project_id:
+        return ToolError("Project ID must be provided")
+    if not model_id:
+        return ToolError("Model ID must be provided")
+
+    client = get_sdk_client()
+    project = client.Project.get(project_id)
+    model = client.Model.get(project=project, model_id=model_id)
+    # get_or_request_feature_impact computes feature impact when it is not yet
+    # available, so a separate request_feature_impact call beforehand is
+    # redundant.
+    feature_impact = model.get_or_request_feature_impact()
+
+    return ToolResult(
+        content=json.dumps({"data": feature_impact}, indent=2),
+        structured_content={"data": feature_impact},
+    )
+
+
+@dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
+async def get_model_lift_chart(
+    *,
+    project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+    model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+    source: Annotated[
+        str,
+        "The source of the data to use for the lift chart "
+        "('validation', 'holdout', or 'crossValidation')",
+    ] = "validation",
+) -> ToolError | ToolResult:
+    """Get detailed lift chart for a specific model."""
+    if not project_id:
+        return ToolError("Project ID must be provided")
+    if not model_id:
+        return ToolError("Model ID must be provided")
+
+    client = get_sdk_client()
+    project = client.Project.get(project_id)
+    model = client.Model.get(project=project, model_id=model_id)
+
+    # Get lift chart
+    lift_chart = model.get_lift_chart(source=source)
+
+    lift_chart_data = {
+        "bins": [
+            {
+                "actual": chart_bin["actual"],
+                "predicted": chart_bin["predicted"],
+                "bin_weight": chart_bin["bin_weight"],
+            }
+            for chart_bin in lift_chart.bins  # avoid shadowing the bin() builtin
+        ],
+        "source_model_id": lift_chart.source_model_id,
+        "target_class": lift_chart.target_class,
+    }
+
+    return ToolResult(
+        content=json.dumps({"data": lift_chart_data}, indent=2),
+        structured_content={"data": lift_chart_data},
+    )
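
The target-discovery heuristic in _identify_potential_targets is purely name- and cardinality-based, so it can flag feature-like columns alongside real targets. A minimal sketch of its behavior on a toy frame, assuming the wheel is installed (the helper is private, so this import is for illustration only):

import pandas as pd

from datarobot_genai.drmcp.tools.predictive.training import _identify_potential_targets

df = pd.DataFrame(
    {
        "customer_id": [101, 102, 103],  # ID-like numeric: skipped
        "churn": [0, 1, 0],              # keyword match: flagged
        "age": [34, 51, 29],             # bounded 0-100: flagged (a false positive)
        "segment": ["a", "b", "a"],      # low-cardinality categorical: flagged
    }
)
print(
    _identify_potential_targets(
        df,
        numerical_cols=["customer_id", "churn", "age"],
        categorical_cols=["segment"],
    )
)
# -> ['churn', 'age', 'segment']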
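
The per-target analysis layers name-keyed suggestions on top of the generic ones, and suggest_use_cases then sorts everything by confidence. A small sketch under the same installed-wheel assumption, showing how a two-class churn column yields both a generic and a domain-specific suggestion:

import pandas as pd

from datarobot_genai.drmcp.tools.predictive.training import _analyze_target_for_use_cases

df = pd.DataFrame({"churn": ["yes", "no", "yes", "no"]})
for s in _analyze_target_for_use_cases(df, "churn"):
    print(f"{s.name}: {s.problem_type} (confidence={s.confidence})")
# Binary Classification: Binary Classification (confidence=0.9)
# Customer Churn Prediction: Binary Classification (confidence=0.95)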
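
Because each tool is registered via dr_mcp_tool, the natural way to exercise them is through an MCP client rather than direct calls. A hedged sketch using fastmcp's in-memory client; the `mcp` import from datarobot_genai.drmcp.server is hypothetical, since this diff does not show that module's exports:

import asyncio

from fastmcp import Client

from datarobot_genai.drmcp.server import mcp  # hypothetical export name


async def main() -> None:
    # In-memory transport: the client talks directly to the server object.
    async with Client(mcp) as client:
        result = await client.call_tool(
            "analyze_dataset", {"dataset_id": "<dataset-id>"}
        )
        print(result)


asyncio.run(main())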