datarobot-genai 0.2.31__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +364 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +350 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +70 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +205 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +515 -0
  56. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  57. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  58. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  59. datarobot_genai/drmcp/core/routes.py +439 -0
  60. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  61. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  62. datarobot_genai/drmcp/core/telemetry.py +424 -0
  63. datarobot_genai/drmcp/core/tool_config.py +111 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +117 -0
  65. datarobot_genai/drmcp/core/utils.py +138 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/clients/__init__.py +0 -0
  69. datarobot_genai/drmcp/test_utils/clients/anthropic.py +68 -0
  70. datarobot_genai/drmcp/test_utils/clients/base.py +300 -0
  71. datarobot_genai/drmcp/test_utils/clients/dr_gateway.py +58 -0
  72. datarobot_genai/drmcp/test_utils/clients/openai.py +68 -0
  73. datarobot_genai/drmcp/test_utils/elicitation_test_tool.py +89 -0
  74. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +109 -0
  75. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +133 -0
  76. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +107 -0
  77. datarobot_genai/drmcp/test_utils/test_interactive.py +205 -0
  78. datarobot_genai/drmcp/test_utils/tool_base_ete.py +220 -0
  79. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  80. datarobot_genai/drmcp/tools/__init__.py +14 -0
  81. datarobot_genai/drmcp/tools/clients/__init__.py +14 -0
  82. datarobot_genai/drmcp/tools/clients/atlassian.py +188 -0
  83. datarobot_genai/drmcp/tools/clients/confluence.py +584 -0
  84. datarobot_genai/drmcp/tools/clients/gdrive.py +832 -0
  85. datarobot_genai/drmcp/tools/clients/jira.py +334 -0
  86. datarobot_genai/drmcp/tools/clients/microsoft_graph.py +479 -0
  87. datarobot_genai/drmcp/tools/clients/s3.py +28 -0
  88. datarobot_genai/drmcp/tools/confluence/__init__.py +14 -0
  89. datarobot_genai/drmcp/tools/confluence/tools.py +321 -0
  90. datarobot_genai/drmcp/tools/gdrive/__init__.py +0 -0
  91. datarobot_genai/drmcp/tools/gdrive/tools.py +347 -0
  92. datarobot_genai/drmcp/tools/jira/__init__.py +14 -0
  93. datarobot_genai/drmcp/tools/jira/tools.py +243 -0
  94. datarobot_genai/drmcp/tools/microsoft_graph/__init__.py +13 -0
  95. datarobot_genai/drmcp/tools/microsoft_graph/tools.py +198 -0
  96. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  97. datarobot_genai/drmcp/tools/predictive/data.py +133 -0
  98. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  99. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  100. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  101. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  102. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  103. datarobot_genai/drmcp/tools/predictive/project.py +90 -0
  104. datarobot_genai/drmcp/tools/predictive/training.py +661 -0
  105. datarobot_genai/langgraph/__init__.py +0 -0
  106. datarobot_genai/langgraph/agent.py +341 -0
  107. datarobot_genai/langgraph/mcp.py +73 -0
  108. datarobot_genai/llama_index/__init__.py +16 -0
  109. datarobot_genai/llama_index/agent.py +50 -0
  110. datarobot_genai/llama_index/base.py +299 -0
  111. datarobot_genai/llama_index/mcp.py +79 -0
  112. datarobot_genai/nat/__init__.py +0 -0
  113. datarobot_genai/nat/agent.py +275 -0
  114. datarobot_genai/nat/datarobot_auth_provider.py +110 -0
  115. datarobot_genai/nat/datarobot_llm_clients.py +318 -0
  116. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  117. datarobot_genai/nat/datarobot_mcp_client.py +266 -0
  118. datarobot_genai/nat/helpers.py +87 -0
  119. datarobot_genai/py.typed +0 -0
  120. datarobot_genai-0.2.31.dist-info/METADATA +145 -0
  121. datarobot_genai-0.2.31.dist-info/RECORD +125 -0
  122. datarobot_genai-0.2.31.dist-info/WHEEL +4 -0
  123. datarobot_genai-0.2.31.dist-info/entry_points.txt +5 -0
  124. datarobot_genai-0.2.31.dist-info/licenses/AUTHORS +2 -0
  125. datarobot_genai-0.2.31.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/training.py
@@ -0,0 +1,661 @@
+ # Copyright 2025 DataRobot, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tools for analyzing datasets and suggesting ML use cases."""
+
+ import json
+ import logging
+ from dataclasses import asdict
+ from dataclasses import dataclass
+ from typing import Annotated
+
+ import pandas as pd
+ from fastmcp.exceptions import ToolError
+ from fastmcp.tools.tool import ToolResult
+
+ from datarobot_genai.drmcp.core.clients import get_sdk_client
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class UseCaseSuggestion:
+     """Represents a suggested use case based on dataset analysis."""
+
+     name: str
+     description: str
+     suggested_target: str
+     problem_type: str
+     confidence: float
+     reasoning: str
+
+
+ @dataclass
+ class DatasetInsight:
+     """Contains insights about a dataset for use case discovery."""
+
+     total_columns: int
+     total_rows: int
+     numerical_columns: list[str]
+     categorical_columns: list[str]
+     datetime_columns: list[str]
+     text_columns: list[str]
+     potential_targets: list[str]
+     missing_data_summary: dict[str, float]
+
+
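Both dataclasses are plain containers; the tools below serialize them with dataclasses.asdict so they embed cleanly in JSON tool output. A standalone illustration (the field values here are made up):

    from dataclasses import asdict
    import json

    suggestion = UseCaseSuggestion(
        name="Binary Classification",
        description="Predict whether churn will be 0 or 1",
        suggested_target="churn",
        problem_type="Binary Classification",
        confidence=0.8,
        reasoning="Column churn has exactly 2 unique values",
    )
    print(json.dumps(asdict(suggestion), indent=2))  # round-trips cleanly to JSON
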
+ @dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "dataset"})
+ async def analyze_dataset(
+     *,
+     dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+ ) -> ToolError | ToolResult:
+     """Analyze a dataset to understand its structure and potential use cases."""
+     if not dataset_id:
+         return ToolError("Dataset ID must be provided")
+
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Analyze dataset structure
+     numerical_cols = df.select_dtypes(include=["int64", "float64"]).columns.tolist()
+     categorical_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
+     datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
+
+     # Identify potential text columns (object columns with long average string length)
+     text_cols = []
+     for col in list(categorical_cols):  # iterate over a copy; the list is mutated below
+         if df[col].str.len().mean() > 20:  # Text detection
+             text_cols.append(col)
+             categorical_cols.remove(col)  # Remove from categorical columns
+
+     # Calculate missing data
+     missing_data = {}
+     for col in df.columns:
+         missing_pct = (df[col].isnull().sum() / len(df)) * 100
+         if missing_pct > 0:
+             missing_data[col] = missing_pct
+
+     # Identify potential target columns
+     potential_targets = _identify_potential_targets(df, numerical_cols, categorical_cols)
+
+     insights = DatasetInsight(
+         total_columns=len(df.columns),
+         total_rows=len(df),
+         numerical_columns=numerical_cols,
+         categorical_columns=categorical_cols,
+         datetime_columns=datetime_cols,
+         text_columns=text_cols,
+         potential_targets=potential_targets,
+         missing_data_summary=missing_data,
+     )
+     insights_dict = asdict(insights)
+
+     return ToolResult(
+         content=json.dumps(insights_dict, indent=2),
+         structured_content=insights_dict,
+     )
+
+
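Because analyze_dataset is registered through dr_mcp_tool rather than exported for direct use, a consumer typically reaches it over MCP. A minimal sketch, assuming the server is reachable at a placeholder URL and that fastmcp's Client/call_tool API is available in your version; the dataset ID is hypothetical:

    import asyncio
    from fastmcp import Client

    async def main() -> None:
        # Placeholder URL; point this at wherever the DataRobot MCP server runs
        async with Client("http://localhost:8000/mcp") as client:
            result = await client.call_tool(
                "analyze_dataset",
                {"dataset_id": "68abc0123456789012345678"},  # hypothetical dataset ID
            )
            # structured_content mirrors the DatasetInsight dict built above
            print(result.structured_content["potential_targets"])

    asyncio.run(main())
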
+ @dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "usecase"})
+ async def suggest_use_cases(
+     *,
+     dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+ ) -> ToolError | ToolResult:
+     """Analyze a dataset and suggest potential machine learning use cases."""
+     if not dataset_id:
+         return ToolError("Dataset ID must be provided")
+
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Get dataset insights first (analyze_dataset takes keyword-only arguments
+     # and returns a ToolResult, so read its structured content instead of
+     # parsing JSON text)
+     insights_result = await analyze_dataset(dataset_id=dataset_id)
+     if isinstance(insights_result, ToolError):
+         return insights_result
+     insights = insights_result.structured_content
+
+     suggestions = []
+     for target_col in insights["potential_targets"]:
+         target_suggestions = _analyze_target_for_use_cases(df, target_col)
+         suggestions.extend([asdict(s) for s in target_suggestions])
+
+     # Sort by confidence score
+     suggestions.sort(key=lambda x: x["confidence"], reverse=True)
+
+     return ToolResult(
+         content=json.dumps(suggestions, indent=2),
+         structured_content={"use_case_suggestions": suggestions},
+     )
+
+
+ @dr_mcp_tool(tags={"predictive", "training", "read", "analysis", "eda"})
+ async def get_exploratory_insights(
+     *,
+     dataset_id: Annotated[str, "The ID of the DataRobot dataset to analyze"] | None = None,
+     target_col: Annotated[str, "Optional target column to focus EDA insights on"] | None = None,
+ ) -> ToolError | ToolResult:
+     """Generate exploratory data insights for a dataset."""
+     if not dataset_id:
+         return ToolError("Dataset ID must be provided")
+
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Get dataset insights first (keyword-only call; read the structured content)
+     insights_result = await analyze_dataset(dataset_id=dataset_id)
+     if isinstance(insights_result, ToolError):
+         return insights_result
+     insights = insights_result.structured_content
+
+     eda_insights = {
+         "dataset_summary": {
+             "total_rows": int(insights["total_rows"]),  # Convert to native Python int
+             "total_columns": int(insights["total_columns"]),  # Convert to native Python int
+             "memory_usage": int(df.memory_usage().sum()),  # Convert to native Python int
+         },
+         "target_analysis": {},
+         "feature_correlations": {},
+         "missing_data": insights["missing_data_summary"],
+         "data_types": {
+             "numerical": insights["numerical_columns"],
+             "categorical": insights["categorical_columns"],
+             "datetime": insights["datetime_columns"],
+             "text": insights["text_columns"],
+         },
+     }
+
+     # Target-specific analysis
+     if target_col and target_col in df.columns:
+         target_data = df[target_col]
+         target_analysis = {
+             "column_name": target_col,
+             "data_type": str(target_data.dtype),
+             "unique_values": int(target_data.nunique()),  # Convert to native Python int
+             "missing_values": int(target_data.isnull().sum()),  # Convert to native Python int
+             "missing_percentage": float(
+                 target_data.isnull().sum() / len(df) * 100
+             ),  # Already float
+         }
+
+         if pd.api.types.is_numeric_dtype(target_data):
+             target_analysis.update(
+                 {
+                     "min_value": float(target_data.min()),  # Convert to native Python float
+                     "max_value": float(target_data.max()),  # Convert to native Python float
+                     "mean": float(target_data.mean()),  # Convert to native Python float
+                     "median": float(target_data.median()),  # Convert to native Python float
+                     "std_dev": float(target_data.std()),  # Convert to native Python float
+                 }
+             )
+         else:
+             value_counts = target_data.value_counts()
+             target_analysis.update(
+                 {
+                     "value_counts": {
+                         str(k): int(v) for k, v in value_counts.items()
+                     },  # Convert both key and value
+                     "most_common": str(value_counts.index[0]) if len(value_counts) > 0 else None,
+                 }
+             )
+
+         eda_insights["target_analysis"] = target_analysis
+
+         # Feature correlations with target (for numerical features)
+         if pd.api.types.is_numeric_dtype(target_data):
+             numerical_features = [col for col in insights["numerical_columns"] if col != target_col]
+             if numerical_features:
+                 correlations = {}
+                 for feature in numerical_features:
+                     corr = df[feature].corr(target_data)
+                     if not pd.isna(corr):
+                         correlations[feature] = float(corr)  # Convert to native Python float
+
+                 # Sort by absolute correlation
+                 eda_insights["feature_correlations"] = dict(
+                     sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True)
+                 )
+
+     return ToolResult(
+         content=json.dumps(eda_insights, indent=2),
+         structured_content=eda_insights,
+     )
+
+
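The repeated int()/float() casts above are not noise: pandas reductions return numpy scalars (numpy.int64, numpy.float64), which the standard-library json encoder rejects. A standalone illustration, separate from the package code:

    import json

    import numpy as np
    import pandas as pd

    s = pd.Series([1, 2, 3])
    total = s.sum()          # numpy.int64, not a builtin int
    print(type(total))       # <class 'numpy.int64'>

    try:
        json.dumps({"total": total})
    except TypeError as e:
        print(e)             # Object of type int64 is not JSON serializable

    json.dumps({"total": int(total)})  # works once cast to a native int
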
+ def _identify_potential_targets(
+     df: pd.DataFrame, numerical_cols: list[str], categorical_cols: list[str]
+ ) -> list[str]:
+     """Identify columns that could potentially be targets."""
+     potential_targets = []
+
+     # Look for common target column names
+     target_keywords = [
+         "target",
+         "label",
+         "class",
+         "outcome",
+         "result",
+         "prediction",
+         "predict",
+         "sales",
+         "revenue",
+         "price",
+         "amount",
+         "value",
+         "score",
+         "rating",
+         "churn",
+         "conversion",
+         "fraud",
+         "default",
+         "failure",
+         "success",
+         "risk",
+         "probability",
+         "likelihood",
+         "status",
+         "category",
+         "type",
+     ]
+
+     for col in df.columns:
+         col_lower = col.lower()
+
+         # Check if column name contains target keywords
+         if any(keyword in col_lower for keyword in target_keywords):
+             potential_targets.append(col)
+             continue
+
+         # For numerical columns, check if they might be targets
+         if col in numerical_cols:
+             # Skip ID-like columns
+             if "id" in col_lower or col_lower.endswith("_id"):
+                 continue
+
+             # Check for bounded values (might be scores/ratings)
+             if df[col].min() >= 0 and df[col].max() <= 100:
+                 potential_targets.append(col)
+                 continue
+
+             # Check for binary-like numerical values
+             unique_vals = df[col].nunique()
+             if unique_vals == 2:
+                 potential_targets.append(col)
+                 continue
+
+         # For categorical columns, check cardinality
+         if col in categorical_cols:
+             unique_vals = df[col].nunique()
+             # Good targets have reasonable cardinality (2-20 classes typically)
+             if 2 <= unique_vals <= 20:
+                 potential_targets.append(col)
+
+     return potential_targets
+
+
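To make the heuristics concrete, here is a toy run of the helper. The dataframe and values are hypothetical, and the import reaches into a private function, so treat this as illustrative only:

    import pandas as pd

    from datarobot_genai.drmcp.tools.predictive.training import _identify_potential_targets

    df = pd.DataFrame(
        {
            "customer_id": [101, 102, 103],     # skipped: ID-like numerical column
            "age": [34, 51, 29],                # kept: numerical, bounded 0-100
            "churn": [0, 1, 0],                 # kept: matches the "churn" keyword
            "plan": ["basic", "pro", "basic"],  # kept: low-cardinality categorical
        }
    )
    numerical = ["customer_id", "age", "churn"]  # normally derived via select_dtypes
    categorical = ["plan"]
    print(_identify_potential_targets(df, numerical, categorical))
    # ['age', 'churn', 'plan'] under the heuristics above
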
+ def _analyze_target_for_use_cases(df: pd.DataFrame, target_col: str) -> list[UseCaseSuggestion]:
+     """Analyze a specific target column and suggest use cases."""
+     suggestions: list[UseCaseSuggestion] = []
+
+     target_data = df[target_col].dropna()
+     if len(target_data) == 0:
+         return suggestions
+
+     # Determine if it's numerical or categorical
+     if pd.api.types.is_numeric_dtype(target_data):
+         unique_count = target_data.nunique()
+
+         if unique_count == 2:
+             # Binary classification
+             values = sorted(target_data.unique())
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Binary Classification",
+                     description=f"Predict whether {target_col} will be {values[0]} or {values[1]}",
+                     suggested_target=target_col,
+                     problem_type="Binary Classification",
+                     confidence=0.8,
+                     reasoning=f"Column {target_col} has exactly 2 unique values, suggesting binary "
+                     f"classification",
+                 )
+             )
+         elif unique_count <= 10:
+             # Multiclass classification
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Multiclass Classification",
+                     description=f"Classify {target_col} into {unique_count} categories",
+                     suggested_target=target_col,
+                     problem_type="Multiclass Classification",
+                     confidence=0.7,
+                     reasoning=f"Column {target_col} has {unique_count} discrete values, suggesting "
+                     f"classification",
+                 )
+             )
+
+         # Always suggest regression for numeric columns with more than 2 unique values
+         if unique_count > 2:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Regression Modeling",
+                     description=f"Predict the value of {target_col}",
+                     suggested_target=target_col,
+                     problem_type="Regression",
+                     confidence=0.6
+                     + (
+                         0.1 if unique_count > 10 else 0
+                     ),  # higher confidence for columns with more unique values for regression
+                     reasoning=(
+                         f"Column {target_col} is numerical with {unique_count} unique values, "
+                         f"suggesting regression"
+                     ),
+                 )
+             )
+
+     else:
+         # Categorical target
+         unique_count = target_data.nunique()
+
+         if unique_count == 2:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Binary Classification",
+                     description=f"Predict the category of {target_col}",
+                     suggested_target=target_col,
+                     problem_type="Binary Classification",
+                     confidence=0.9,
+                     reasoning=f"Column {target_col} is categorical with 2 classes",
+                 )
+             )
+         elif unique_count <= 20:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Multiclass Classification",
+                     description=f"Classify {target_col} into {unique_count} categories",
+                     suggested_target=target_col,
+                     problem_type="Multiclass Classification",
+                     confidence=0.8,
+                     reasoning=f"Column {target_col} is categorical with {unique_count} manageable "
+                     f"classes",
+                 )
+             )
+
+     # Add specific use case suggestions based on column names
+     col_lower = target_col.lower()
+     if "sales" in col_lower or "revenue" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Sales Forecasting",
+                 description=f"Forecast future {target_col} values",
+                 suggested_target=target_col,
+                 problem_type="Regression",
+                 confidence=0.9,
+                 reasoning="Sales/revenue data is ideal for forecasting models",
+             )
+         )
+     elif "churn" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Customer Churn Prediction",
+                 description="Predict which customers are likely to churn",
+                 suggested_target=target_col,
+                 problem_type="Binary Classification",
+                 confidence=0.95,
+                 reasoning="Churn prediction is a classic binary classification problem",
+             )
+         )
+     elif "fraud" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Fraud Detection",
+                 description="Detect fraudulent transactions or activities",
+                 suggested_target=target_col,
+                 problem_type="Binary Classification",
+                 confidence=0.95,
+                 reasoning="Fraud detection is typically a binary classification problem",
+             )
+         )
+     elif "price" in col_lower or "cost" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Price Prediction",
+                 description=f"Predict optimal {target_col}",
+                 suggested_target=target_col,
+                 problem_type="Regression",
+                 confidence=0.85,
+                 reasoning="Price prediction is a common regression use case",
+             )
+         )
+
+     return suggestions
+
+
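A quick illustration of the suggestion logic, again using the private helper and hypothetical data:

    import pandas as pd

    from datarobot_genai.drmcp.tools.predictive.training import _analyze_target_for_use_cases

    df = pd.DataFrame({"churn": [0, 1, 0, 1, 1]})
    for s in _analyze_target_for_use_cases(df, "churn"):
        print(s.name, "|", s.problem_type, "|", s.confidence)
    # Binary Classification | Binary Classification | 0.8   (exactly 2 unique values)
    # Customer Churn Prediction | Binary Classification | 0.95  (name matches "churn")
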
+ @dr_mcp_tool(tags={"predictive", "training", "write", "autopilot", "model"})
+ async def start_autopilot(
+     *,
+     target: Annotated[str, "Name of the target column for modeling"] | None = None,
+     project_id: Annotated[
+         str, "Optional, the ID of the DataRobot project or a new project if no id is provided"
+     ]
+     | None = None,
+     mode: Annotated[str, "Optional, Autopilot mode ('quick', 'comprehensive', or 'manual')"]
+     | None = "quick",
+     dataset_url: Annotated[
+         str,
+         """
+         Optional, The URL to the dataset to upload
+         (optional if dataset_id is provided) for a new project.
+         """,
+     ]
+     | None = None,
+     dataset_id: Annotated[
+         str,
+         """
+         Optional, The ID of an existing dataset in AI Catalog
+         (optional if dataset_url is provided) for a new project.
+         """,
+     ]
+     | None = None,
+     project_name: Annotated[
+         str, "Optional, name for the project if no id is provided, creates a new project"
+     ]
+     | None = "MCP Project",
+     use_case_id: Annotated[
+         str,
+         "Optional, ID of the use case to associate this project (required for next-gen platform)",
+     ]
+     | None = None,
+ ) -> ToolError | ToolResult:
+     """Start automated model training (Autopilot) for a project."""
+     client = get_sdk_client()
+
+     if not project_id:
+         if not dataset_url and not dataset_id:
+             return ToolError("Either dataset_url or dataset_id must be provided")
+         if dataset_url and dataset_id:
+             return ToolError("Please provide either dataset_url or dataset_id, not both")
+
+         if dataset_url:
+             dataset = client.Dataset.create_from_url(dataset_url)
+         else:
+             dataset = client.Dataset.get(dataset_id)
+
+         project = client.Project.create_from_dataset(
+             dataset.id, project_name=project_name, use_case=use_case_id
+         )
+     else:
+         project = client.Project.get(project_id)
+
+     if not target:
+         return ToolError("Target variable must be specified")
+
+     try:
+         # Start modeling
+         project.analyze_and_model(target=target, mode=mode)
+
+         result = {
+             "project_id": project.id,
+             "target": target,
+             "mode": mode,
+             "status": project.get_status(),
+             "use_case_id": project.use_case_id,
+         }
+
+         return ToolResult(
+             content=json.dumps(result, indent=2),
+             structured_content=result,
+         )
+
+     except Exception as e:
+         # ToolError is an exception: it takes a message, not a content keyword
+         return ToolError(
+             json.dumps(
+                 {
+                     "error": f"Failed to start Autopilot: {str(e)}",
+                     "project_id": project.id if project else None,
+                     "target": target,
+                     "mode": mode,
+                 },
+                 indent=2,
+             )
+         )
+
+
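For comparison, here is roughly the same flow using the DataRobot Python client directly, without the MCP layer. A sketch under stated assumptions: credentials come from the environment, and the dataset ID and target name are placeholders:

    import datarobot as dr

    # Reads DATAROBOT_ENDPOINT / DATAROBOT_API_TOKEN from the environment
    dr.Client()

    dataset = dr.Dataset.get("68abc0123456789012345678")  # hypothetical dataset ID
    project = dr.Project.create_from_dataset(dataset.id, project_name="MCP Project")
    project.analyze_and_model(target="churn", mode=dr.AUTOPILOT_MODE.QUICK)
    project.wait_for_autopilot()  # optional: block until Autopilot finishes
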
+ @dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
+ async def get_model_roc_curve(
+     *,
+     project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+     model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+     source: Annotated[
+         str,
+         """
+         The source of the data to use for the ROC curve
+         ('validation', 'holdout', or 'crossValidation')
+         """,
+     ] = "validation",
+ ) -> ToolError | ToolResult:
+     """Get detailed ROC curve for a specific model."""
+     if not project_id:
+         return ToolError("Project ID must be provided")
+     if not model_id:
+         return ToolError("Model ID must be provided")
+
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+
+     try:
+         roc_curve = model.get_roc_curve(source=source)
+         roc_data = {
+             "roc_points": [
+                 {
+                     "accuracy": point.get("accuracy", 0),
+                     "f1_score": point.get("f1_score", 0),
+                     "false_negative_score": point.get("false_negative_score", 0),
+                     "true_negative_score": point.get("true_negative_score", 0),
+                     "true_negative_rate": point.get("true_negative_rate", 0),
+                     "matthews_correlation_coefficient": point.get(
+                         "matthews_correlation_coefficient", 0
+                     ),
+                     "true_positive_score": point.get("true_positive_score", 0),
+                     "positive_predictive_value": point.get("positive_predictive_value", 0),
+                     "false_positive_score": point.get("false_positive_score", 0),
+                     "false_positive_rate": point.get("false_positive_rate", 0),
+                     "negative_predictive_value": point.get("negative_predictive_value", 0),
+                     "true_positive_rate": point.get("true_positive_rate", 0),
+                     "threshold": point.get("threshold", 0),
+                 }
+                 for point in roc_curve.roc_points
+             ],
+             "negative_class_predictions": roc_curve.negative_class_predictions,
+             "positive_class_predictions": roc_curve.positive_class_predictions,
+             "source": source,
+         }
+
+         return ToolResult(
+             content=json.dumps({"data": roc_data}, indent=2),
+             structured_content={"data": roc_data},
+         )
+     except Exception as e:
+         return ToolError(f"Failed to get ROC curve: {str(e)}")
+
+
+ @dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
+ async def get_model_feature_impact(
+     *,
+     project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+     model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+ ) -> ToolError | ToolResult:
+     """Get detailed feature impact for a specific model."""
+     if not project_id:
+         return ToolError("Project ID must be provided")
+     if not model_id:
+         return ToolError("Model ID must be provided")
+
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+     # Compute feature impact if needed, otherwise fetch the existing result;
+     # an unconditional request_feature_impact() would fail on repeat calls
+     feature_impact = model.get_or_request_feature_impact()
+
+     return ToolResult(
+         content=json.dumps({"data": feature_impact}, indent=2),
+         structured_content={"data": feature_impact},
+     )
+
+
+ @dr_mcp_tool(tags={"predictive", "training", "read", "model", "evaluation"})
+ async def get_model_lift_chart(
+     *,
+     project_id: Annotated[str, "The ID of the DataRobot project"] | None = None,
+     model_id: Annotated[str, "The ID of the model to analyze"] | None = None,
+     source: Annotated[
+         str,
+         """
+         The source of the data to use for the lift chart
+         ('validation', 'holdout', or 'crossValidation')
+         """,
+     ] = "validation",
+ ) -> ToolError | ToolResult:
+     """Get detailed lift chart for a specific model."""
+     if not project_id:
+         return ToolError("Project ID must be provided")
+     if not model_id:
+         return ToolError("Model ID must be provided")
+
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+
+     # Get lift chart
+     lift_chart = model.get_lift_chart(source=source)
+
+     lift_chart_data = {
+         "bins": [
+             {
+                 "actual": bin["actual"],
+                 "predicted": bin["predicted"],
+                 "bin_weight": bin["bin_weight"],
+             }
+             for bin in lift_chart.bins
+         ],
+         "source_model_id": lift_chart.source_model_id,
+         "target_class": lift_chart.target_class,
+     }
+
+     return ToolResult(
+         content=json.dumps({"data": lift_chart_data}, indent=2),
+         structured_content={"data": lift_chart_data},
+     )
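
Taken together, the three evaluation tools can drive a simple model-review loop from the client side. A hypothetical sketch reusing the fastmcp Client from the earlier example; the tool names match the functions above, while the URL and IDs are placeholders:

    import asyncio
    from fastmcp import Client

    async def review_model(project_id: str, model_id: str) -> None:
        async with Client("http://localhost:8000/mcp") as client:
            for tool in (
                "get_model_roc_curve",
                "get_model_feature_impact",
                "get_model_lift_chart",
            ):
                result = await client.call_tool(
                    tool, {"project_id": project_id, "model_id": model_id}
                )
                print(tool, "->", str(result.structured_content)[:120])

    # hypothetical 24-character DataRobot IDs
    asyncio.run(review_model("68abc0123456789012345678", "68def0123456789012345678"))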