datarobot-genai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. datarobot_genai/__init__.py +19 -0
  2. datarobot_genai/core/__init__.py +0 -0
  3. datarobot_genai/core/agents/__init__.py +43 -0
  4. datarobot_genai/core/agents/base.py +195 -0
  5. datarobot_genai/core/chat/__init__.py +19 -0
  6. datarobot_genai/core/chat/auth.py +146 -0
  7. datarobot_genai/core/chat/client.py +178 -0
  8. datarobot_genai/core/chat/responses.py +297 -0
  9. datarobot_genai/core/cli/__init__.py +18 -0
  10. datarobot_genai/core/cli/agent_environment.py +47 -0
  11. datarobot_genai/core/cli/agent_kernel.py +211 -0
  12. datarobot_genai/core/custom_model.py +141 -0
  13. datarobot_genai/core/mcp/__init__.py +0 -0
  14. datarobot_genai/core/mcp/common.py +218 -0
  15. datarobot_genai/core/telemetry_agent.py +126 -0
  16. datarobot_genai/core/utils/__init__.py +3 -0
  17. datarobot_genai/core/utils/auth.py +234 -0
  18. datarobot_genai/core/utils/urls.py +64 -0
  19. datarobot_genai/crewai/__init__.py +24 -0
  20. datarobot_genai/crewai/agent.py +42 -0
  21. datarobot_genai/crewai/base.py +159 -0
  22. datarobot_genai/crewai/events.py +117 -0
  23. datarobot_genai/crewai/mcp.py +59 -0
  24. datarobot_genai/drmcp/__init__.py +78 -0
  25. datarobot_genai/drmcp/core/__init__.py +13 -0
  26. datarobot_genai/drmcp/core/auth.py +165 -0
  27. datarobot_genai/drmcp/core/clients.py +180 -0
  28. datarobot_genai/drmcp/core/config.py +250 -0
  29. datarobot_genai/drmcp/core/config_utils.py +174 -0
  30. datarobot_genai/drmcp/core/constants.py +18 -0
  31. datarobot_genai/drmcp/core/credentials.py +190 -0
  32. datarobot_genai/drmcp/core/dr_mcp_server.py +316 -0
  33. datarobot_genai/drmcp/core/dr_mcp_server_logo.py +136 -0
  34. datarobot_genai/drmcp/core/dynamic_prompts/__init__.py +13 -0
  35. datarobot_genai/drmcp/core/dynamic_prompts/controllers.py +130 -0
  36. datarobot_genai/drmcp/core/dynamic_prompts/dr_lib.py +128 -0
  37. datarobot_genai/drmcp/core/dynamic_prompts/register.py +206 -0
  38. datarobot_genai/drmcp/core/dynamic_prompts/utils.py +33 -0
  39. datarobot_genai/drmcp/core/dynamic_tools/__init__.py +14 -0
  40. datarobot_genai/drmcp/core/dynamic_tools/deployment/__init__.py +0 -0
  41. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/__init__.py +14 -0
  42. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/base.py +72 -0
  43. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/default.py +82 -0
  44. datarobot_genai/drmcp/core/dynamic_tools/deployment/adapters/drum.py +238 -0
  45. datarobot_genai/drmcp/core/dynamic_tools/deployment/config.py +228 -0
  46. datarobot_genai/drmcp/core/dynamic_tools/deployment/controllers.py +63 -0
  47. datarobot_genai/drmcp/core/dynamic_tools/deployment/metadata.py +162 -0
  48. datarobot_genai/drmcp/core/dynamic_tools/deployment/register.py +87 -0
  49. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_agentic_fallback_schema.json +36 -0
  50. datarobot_genai/drmcp/core/dynamic_tools/deployment/schemas/drum_prediction_fallback_schema.json +10 -0
  51. datarobot_genai/drmcp/core/dynamic_tools/register.py +254 -0
  52. datarobot_genai/drmcp/core/dynamic_tools/schema.py +532 -0
  53. datarobot_genai/drmcp/core/exceptions.py +25 -0
  54. datarobot_genai/drmcp/core/logging.py +98 -0
  55. datarobot_genai/drmcp/core/mcp_instance.py +542 -0
  56. datarobot_genai/drmcp/core/mcp_server_tools.py +129 -0
  57. datarobot_genai/drmcp/core/memory_management/__init__.py +13 -0
  58. datarobot_genai/drmcp/core/memory_management/manager.py +820 -0
  59. datarobot_genai/drmcp/core/memory_management/memory_tools.py +201 -0
  60. datarobot_genai/drmcp/core/routes.py +436 -0
  61. datarobot_genai/drmcp/core/routes_utils.py +30 -0
  62. datarobot_genai/drmcp/core/server_life_cycle.py +107 -0
  63. datarobot_genai/drmcp/core/telemetry.py +424 -0
  64. datarobot_genai/drmcp/core/tool_filter.py +108 -0
  65. datarobot_genai/drmcp/core/utils.py +131 -0
  66. datarobot_genai/drmcp/server.py +19 -0
  67. datarobot_genai/drmcp/test_utils/__init__.py +13 -0
  68. datarobot_genai/drmcp/test_utils/integration_mcp_server.py +102 -0
  69. datarobot_genai/drmcp/test_utils/mcp_utils_ete.py +96 -0
  70. datarobot_genai/drmcp/test_utils/mcp_utils_integration.py +94 -0
  71. datarobot_genai/drmcp/test_utils/openai_llm_mcp_client.py +234 -0
  72. datarobot_genai/drmcp/test_utils/tool_base_ete.py +151 -0
  73. datarobot_genai/drmcp/test_utils/utils.py +91 -0
  74. datarobot_genai/drmcp/tools/__init__.py +14 -0
  75. datarobot_genai/drmcp/tools/predictive/__init__.py +27 -0
  76. datarobot_genai/drmcp/tools/predictive/data.py +97 -0
  77. datarobot_genai/drmcp/tools/predictive/deployment.py +91 -0
  78. datarobot_genai/drmcp/tools/predictive/deployment_info.py +392 -0
  79. datarobot_genai/drmcp/tools/predictive/model.py +148 -0
  80. datarobot_genai/drmcp/tools/predictive/predict.py +254 -0
  81. datarobot_genai/drmcp/tools/predictive/predict_realtime.py +307 -0
  82. datarobot_genai/drmcp/tools/predictive/project.py +72 -0
  83. datarobot_genai/drmcp/tools/predictive/training.py +651 -0
  84. datarobot_genai/langgraph/__init__.py +0 -0
  85. datarobot_genai/langgraph/agent.py +341 -0
  86. datarobot_genai/langgraph/mcp.py +73 -0
  87. datarobot_genai/llama_index/__init__.py +16 -0
  88. datarobot_genai/llama_index/agent.py +50 -0
  89. datarobot_genai/llama_index/base.py +299 -0
  90. datarobot_genai/llama_index/mcp.py +79 -0
  91. datarobot_genai/nat/__init__.py +0 -0
  92. datarobot_genai/nat/agent.py +258 -0
  93. datarobot_genai/nat/datarobot_llm_clients.py +249 -0
  94. datarobot_genai/nat/datarobot_llm_providers.py +130 -0
  95. datarobot_genai/py.typed +0 -0
  96. datarobot_genai-0.2.0.dist-info/METADATA +139 -0
  97. datarobot_genai-0.2.0.dist-info/RECORD +101 -0
  98. datarobot_genai-0.2.0.dist-info/WHEEL +4 -0
  99. datarobot_genai-0.2.0.dist-info/entry_points.txt +3 -0
  100. datarobot_genai-0.2.0.dist-info/licenses/AUTHORS +2 -0
  101. datarobot_genai-0.2.0.dist-info/licenses/LICENSE +201 -0
datarobot_genai/drmcp/tools/predictive/training.py (new file)
@@ -0,0 +1,651 @@
+ # Copyright 2025 DataRobot, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tools for analyzing datasets and suggesting ML use cases."""
+
+ import json
+ import logging
+ from dataclasses import asdict
+ from dataclasses import dataclass
+
+ import pandas as pd
+
+ from datarobot_genai.drmcp.core.clients import get_sdk_client
+ from datarobot_genai.drmcp.core.mcp_instance import dr_mcp_tool
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class UseCaseSuggestion:
+     """Represents a suggested use case based on dataset analysis."""
+
+     name: str
+     description: str
+     suggested_target: str
+     problem_type: str
+     confidence: float
+     reasoning: str
+
+
+ @dataclass
+ class DatasetInsight:
+     """Contains insights about a dataset for use case discovery."""
+
+     total_columns: int
+     total_rows: int
+     numerical_columns: list[str]
+     categorical_columns: list[str]
+     datetime_columns: list[str]
+     text_columns: list[str]
+     potential_targets: list[str]
+     missing_data_summary: dict[str, float]
+
+
+ @dr_mcp_tool(tags={"training", "analysis", "dataset"})
+ async def analyze_dataset(dataset_id: str) -> str:
+     """
+     Analyze a dataset to understand its structure and potential use cases.
+
+     Args:
+         dataset_id: The ID of the DataRobot dataset to analyze
+
+     Returns
+     -------
+     JSON string containing dataset insights including:
+     - Basic statistics (rows, columns)
+     - Column types (numerical, categorical, datetime, text)
+     - Potential target columns
+     - Missing data summary
+     """
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Analyze dataset structure
+     numerical_cols = df.select_dtypes(include=["int64", "float64"]).columns.tolist()
+     categorical_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
+     datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
+
+     # Identify potential text columns (categorical columns with long average values).
+     # Iterate over a copy: removing items from the list being iterated skips elements.
+     text_cols = []
+     for col in list(categorical_cols):
+         # astype(str) after dropna() guards against non-string object values
+         if df[col].dropna().astype(str).str.len().mean() > 20:  # Text detection
+             text_cols.append(col)
+             categorical_cols.remove(col)  # Remove from categorical columns
+
+     # Calculate missing data
+     missing_data = {}
+     for col in df.columns:
+         missing_pct = (df[col].isnull().sum() / len(df)) * 100
+         if missing_pct > 0:
+             missing_data[col] = missing_pct
+
+     # Identify potential target columns
+     potential_targets = _identify_potential_targets(df, numerical_cols, categorical_cols)
+
+     insights = DatasetInsight(
+         total_columns=len(df.columns),
+         total_rows=len(df),
+         numerical_columns=numerical_cols,
+         categorical_columns=categorical_cols,
+         datetime_columns=datetime_cols,
+         text_columns=text_cols,
+         potential_targets=potential_targets,
+         missing_data_summary=missing_data,
+     )
+
+     return json.dumps(asdict(insights), indent=2)
+
+
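A minimal sketch of exercising analyze_dataset directly (the module itself awaits it this
way from suggest_use_cases below); it assumes DataRobot credentials are already configured
for get_sdk_client, and the dataset ID is a placeholder:

    import asyncio
    import json

    async def main() -> None:
        insights = json.loads(await analyze_dataset("<dataset-id>"))  # placeholder ID
        print(insights["total_rows"], "rows,", insights["total_columns"], "columns")
        print("Candidate targets:", insights["potential_targets"])

    asyncio.run(main())
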
+ @dr_mcp_tool(tags={"training", "analysis", "usecase"})
+ async def suggest_use_cases(dataset_id: str) -> str:
+     """
+     Analyze a dataset and suggest potential machine learning use cases.
+
+     Args:
+         dataset_id: The ID of the DataRobot dataset to analyze
+
+     Returns
+     -------
+     JSON string containing suggested use cases with:
+     - Use case name and description
+     - Suggested target column
+     - Problem type
+     - Confidence score
+     - Reasoning for the suggestion
+     """
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Get dataset insights first
+     insights_json = await analyze_dataset(dataset_id)
+     insights = json.loads(insights_json)
+
+     suggestions = []
+     for target_col in insights["potential_targets"]:
+         target_suggestions = _analyze_target_for_use_cases(df, target_col)
+         suggestions.extend([asdict(s) for s in target_suggestions])
+
+     # Sort by confidence score
+     suggestions.sort(key=lambda x: x["confidence"], reverse=True)
+     return json.dumps(suggestions, indent=2)
+
+
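A minimal sketch of consuming the ranked output in the same direct-call style (placeholder
dataset ID; highest-confidence suggestions come first, per the sort above):

    import asyncio
    import json

    suggestions = json.loads(asyncio.run(suggest_use_cases("<dataset-id>")))
    for s in suggestions[:3]:
        print(round(s["confidence"], 2), s["name"], "->", s["suggested_target"])
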
+ @dr_mcp_tool(tags={"training", "analysis", "eda"})
+ async def get_exploratory_insights(dataset_id: str, target_col: str | None = None) -> str:
+     """
+     Generate exploratory data insights for a dataset.
+
+     Args:
+         dataset_id: The ID of the DataRobot dataset to analyze
+         target_col: Optional target column to focus EDA insights on
+
+     Returns
+     -------
+     JSON string containing EDA insights including:
+     - Dataset summary statistics
+     - Target variable analysis (if specified)
+     - Feature correlations with target
+     - Missing data analysis
+     - Data type distribution
+     """
+     client = get_sdk_client()
+     dataset = client.Dataset.get(dataset_id)
+     df = dataset.get_as_dataframe()
+
+     # Get dataset insights first
+     insights_json = await analyze_dataset(dataset_id)
+     insights = json.loads(insights_json)
+
+     eda_insights = {
+         "dataset_summary": {
+             "total_rows": int(insights["total_rows"]),  # Convert to native Python int
+             "total_columns": int(insights["total_columns"]),  # Convert to native Python int
+             "memory_usage": int(df.memory_usage().sum()),  # Convert to native Python int
+         },
+         "target_analysis": {},
+         "feature_correlations": {},
+         "missing_data": insights["missing_data_summary"],
+         "data_types": {
+             "numerical": insights["numerical_columns"],
+             "categorical": insights["categorical_columns"],
+             "datetime": insights["datetime_columns"],
+             "text": insights["text_columns"],
+         },
+     }
+
+     # Target-specific analysis
+     if target_col and target_col in df.columns:
+         target_data = df[target_col]
+         target_analysis = {
+             "column_name": target_col,
+             "data_type": str(target_data.dtype),
+             "unique_values": int(target_data.nunique()),  # Convert to native Python int
+             "missing_values": int(target_data.isnull().sum()),  # Convert to native Python int
+             "missing_percentage": float(target_data.isnull().sum() / len(df) * 100),
+         }
+
+         if pd.api.types.is_numeric_dtype(target_data):
+             target_analysis.update(
+                 {
+                     "min_value": float(target_data.min()),  # Convert to native Python float
+                     "max_value": float(target_data.max()),  # Convert to native Python float
+                     "mean": float(target_data.mean()),  # Convert to native Python float
+                     "median": float(target_data.median()),  # Convert to native Python float
+                     "std_dev": float(target_data.std()),  # Convert to native Python float
+                 }
+             )
+         else:
+             value_counts = target_data.value_counts()
+             target_analysis.update(
+                 {
+                     "value_counts": {
+                         str(k): int(v) for k, v in value_counts.items()
+                     },  # Convert both key and value
+                     "most_common": str(value_counts.index[0]) if len(value_counts) > 0 else None,
+                 }
+             )
+
+         eda_insights["target_analysis"] = target_analysis
+
+         # Feature correlations with target (for numerical features)
+         if pd.api.types.is_numeric_dtype(target_data):
+             numerical_features = [col for col in insights["numerical_columns"] if col != target_col]
+             if numerical_features:
+                 correlations = {}
+                 for feature in numerical_features:
+                     corr = df[feature].corr(target_data)
+                     if not pd.isna(corr):
+                         correlations[feature] = float(corr)  # Convert to native Python float
+
+                 # Sort by absolute correlation
+                 eda_insights["feature_correlations"] = dict(
+                     sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True)
+                 )
+
+     eda_insights["ui_panel"] = ["eda"]
+     return json.dumps(eda_insights, indent=2)
+
+
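The correlation ranking above is plain Pearson correlation from pandas, ordered by absolute
value; a standalone sketch of the same ordering on toy data:

    import pandas as pd

    df = pd.DataFrame(
        {
            "target": [1.0, 2.0, 3.0, 4.0],
            "a": [2.0, 4.0, 6.0, 8.0],  # perfectly correlated with target
            "b": [4.0, 3.0, 2.5, 1.0],  # strong negative correlation
        }
    )
    correlations = {}
    for feature in ["a", "b"]:
        corr = df[feature].corr(df["target"])
        if not pd.isna(corr):
            correlations[feature] = float(corr)
    ranked = dict(sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True))
    print(ranked)  # "a" ranks first; "b" still ranks high because ordering uses abs()
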
+ def _identify_potential_targets(
+     df: pd.DataFrame, numerical_cols: list[str], categorical_cols: list[str]
+ ) -> list[str]:
+     """Identify columns that could potentially be targets."""
+     potential_targets = []
+
+     # Look for common target column names
+     target_keywords = [
+         "target",
+         "label",
+         "class",
+         "outcome",
+         "result",
+         "prediction",
+         "predict",
+         "sales",
+         "revenue",
+         "price",
+         "amount",
+         "value",
+         "score",
+         "rating",
+         "churn",
+         "conversion",
+         "fraud",
+         "default",
+         "failure",
+         "success",
+         "risk",
+         "probability",
+         "likelihood",
+         "status",
+         "category",
+         "type",
+     ]
+
+     for col in df.columns:
+         col_lower = col.lower()
+
+         # Check if column name contains target keywords
+         if any(keyword in col_lower for keyword in target_keywords):
+             potential_targets.append(col)
+             continue
+
+         # For numerical columns, check if they might be targets
+         if col in numerical_cols:
+             # Skip ID-like columns
+             if "id" in col_lower or col_lower.endswith("_id"):
+                 continue
+
+             # Check for bounded values (might be scores/ratings)
+             if df[col].min() >= 0 and df[col].max() <= 100:
+                 potential_targets.append(col)
+                 continue
+
+             # Check for binary-like numerical values
+             unique_vals = df[col].nunique()
+             if unique_vals == 2:
+                 potential_targets.append(col)
+                 continue
+
+         # For categorical columns, check cardinality
+         if col in categorical_cols:
+             unique_vals = df[col].nunique()
+             # Good targets have reasonable cardinality (2-20 classes typically)
+             if 2 <= unique_vals <= 20:
+                 potential_targets.append(col)
+
+     return potential_targets
+
+
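A worked example of these heuristics on a toy frame (only this helper is involved):

    import pandas as pd

    df = pd.DataFrame(
        {
            "customer_id": [1, 2, 3],    # skipped: ID-like numerical column
            "age": [34, 51, 29],         # kept: numeric and bounded within 0-100
            "churn": [0, 1, 0],          # kept: matches the "churn" keyword
            "segment": ["a", "b", "a"],  # kept: categorical with 2 classes
        }
    )
    print(
        _identify_potential_targets(
            df,
            numerical_cols=["customer_id", "age", "churn"],
            categorical_cols=["segment"],
        )
    )
    # ["age", "churn", "segment"] under the rules above
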
+ def _analyze_target_for_use_cases(df: pd.DataFrame, target_col: str) -> list[UseCaseSuggestion]:
+     """Analyze a specific target column and suggest use cases."""
+     suggestions: list[UseCaseSuggestion] = []
+
+     target_data = df[target_col].dropna()
+     if len(target_data) == 0:
+         return suggestions
+
+     # Determine if it's numerical or categorical
+     if pd.api.types.is_numeric_dtype(target_data):
+         unique_count = target_data.nunique()
+
+         if unique_count == 2:
+             # Binary classification
+             values = sorted(target_data.unique())
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Binary Classification",
+                     description=f"Predict whether {target_col} will be {values[0]} or {values[1]}",
+                     suggested_target=target_col,
+                     problem_type="Binary Classification",
+                     confidence=0.8,
+                     reasoning=f"Column {target_col} has exactly 2 unique values, suggesting "
+                     "binary classification",
+                 )
+             )
+         elif unique_count <= 10:
+             # Multiclass classification
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Multiclass Classification",
+                     description=f"Classify {target_col} into {unique_count} categories",
+                     suggested_target=target_col,
+                     problem_type="Multiclass Classification",
+                     confidence=0.7,
+                     reasoning=f"Column {target_col} has {unique_count} discrete values, "
+                     "suggesting classification",
+                 )
+             )
+
+         # Always suggest regression for numeric columns with more than 2 unique values
+         if unique_count > 2:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Regression Modeling",
+                     description=f"Predict the value of {target_col}",
+                     suggested_target=target_col,
+                     problem_type="Regression",
+                     confidence=0.6
+                     + (
+                         0.1 if unique_count > 10 else 0
+                     ),  # higher confidence for columns with more unique values
+                     reasoning=(
+                         f"Column {target_col} is numerical with {unique_count} unique values, "
+                         "suggesting regression"
+                     ),
+                 )
+             )
+
+     else:
+         # Categorical target
+         unique_count = target_data.nunique()
+
+         if unique_count == 2:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Binary Classification",
+                     description=f"Predict the category of {target_col}",
+                     suggested_target=target_col,
+                     problem_type="Binary Classification",
+                     confidence=0.9,
+                     reasoning=f"Column {target_col} is categorical with 2 classes",
+                 )
+             )
+         elif unique_count <= 20:
+             suggestions.append(
+                 UseCaseSuggestion(
+                     name="Multiclass Classification",
+                     description=f"Classify {target_col} into {unique_count} categories",
+                     suggested_target=target_col,
+                     problem_type="Multiclass Classification",
+                     confidence=0.8,
+                     reasoning=f"Column {target_col} is categorical with {unique_count} "
+                     "manageable classes",
+                 )
+             )
+
+     # Add specific use case suggestions based on column names
+     col_lower = target_col.lower()
+     if "sales" in col_lower or "revenue" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Sales Forecasting",
+                 description=f"Forecast future {target_col} values",
+                 suggested_target=target_col,
+                 problem_type="Regression",
+                 confidence=0.9,
+                 reasoning="Sales/revenue data is ideal for forecasting models",
+             )
+         )
+     elif "churn" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Customer Churn Prediction",
+                 description="Predict which customers are likely to churn",
+                 suggested_target=target_col,
+                 problem_type="Binary Classification",
+                 confidence=0.95,
+                 reasoning="Churn prediction is a classic binary classification problem",
+             )
+         )
+     elif "fraud" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Fraud Detection",
+                 description="Detect fraudulent transactions or activities",
+                 suggested_target=target_col,
+                 problem_type="Binary Classification",
+                 confidence=0.95,
+                 reasoning="Fraud detection is typically a binary classification problem",
+             )
+         )
+     elif "price" in col_lower or "cost" in col_lower:
+         suggestions.append(
+             UseCaseSuggestion(
+                 name="Price Prediction",
+                 description=f"Predict optimal {target_col}",
+                 suggested_target=target_col,
+                 problem_type="Regression",
+                 confidence=0.85,
+                 reasoning="Price prediction is a common regression use case",
+             )
+         )
+
+     return suggestions
+
+
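For instance, a two-class categorical column named "churn" yields both the generic and the
name-based suggestion:

    import pandas as pd

    df = pd.DataFrame({"churn": ["yes", "no", "yes", "no"]})
    for s in _analyze_target_for_use_cases(df, "churn"):
        print(s.name, s.problem_type, s.confidence)
    # Binary Classification at 0.9, then Customer Churn Prediction at 0.95
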
+ @dr_mcp_tool(tags={"training", "autopilot", "model"})
+ async def start_autopilot(
+     target: str,
+     project_id: str | None = None,
+     mode: str | None = "quick",
+     dataset_url: str | None = None,
+     dataset_id: str | None = None,
+     project_name: str | None = "MCP Project",
+     use_case_id: str | None = None,
+ ) -> str:
+     """
+     Start automated model training (Autopilot) for a project.
+
+     Args:
+         target: Name of the target column for modeling.
+         project_id: Optional. ID of an existing DataRobot project; if omitted, a new
+             project is created.
+         mode: Optional. Autopilot mode ('quick', 'comprehensive', or 'manual').
+         dataset_url: Optional. URL of the dataset to upload for a new project (mutually
+             exclusive with dataset_id).
+         dataset_id: Optional. ID of an existing AI Catalog dataset for a new project
+             (mutually exclusive with dataset_url).
+         project_name: Optional. Name for the project when a new one is created.
+         use_case_id: Optional. ID of the use case to associate this project with
+             (required for the next-gen platform).
+
+     Returns
+     -------
+     JSON string containing:
+     - project_id: Project ID
+     - target: Target column name
+     - mode: Selected Autopilot mode
+     - status: Current project status
+     - ui_panel: List of recommended UI panels for visualization
+     """
+     client = get_sdk_client()
+
+     # Validate the target up front, before any dataset upload or project creation.
+     if not target:
+         return "Error: Target variable must be specified"
+
+     if not project_id:
+         if not dataset_url and not dataset_id:
+             return "Error: Either dataset_url or dataset_id must be provided"
+         if dataset_url and dataset_id:
+             return "Error: Please provide either dataset_url or dataset_id, not both"
+
+         if dataset_url:
+             dataset = client.Dataset.create_from_url(dataset_url)
+         else:
+             dataset = client.Dataset.get(dataset_id)
+
+         project = client.Project.create_from_dataset(
+             dataset.id, project_name=project_name, use_case=use_case_id
+         )
+     else:
+         project = client.Project.get(project_id)
+
+     try:
+         # Start modeling
+         project.analyze_and_model(target=target, mode=mode)
+
+         result = {
+             "project_id": project.id,
+             "target": target,
+             "mode": mode,
+             "status": project.get_status(),
+             "ui_panel": ["eda", "model-training", "leaderboard"],
+             "use_case_id": project.use_case_id,
+         }
+
+         return json.dumps(result, indent=2)
+     except Exception as e:
+         return json.dumps(
+             {
+                 "error": f"Failed to start Autopilot: {str(e)}",
+                 "project_id": project.id,
+                 "target": target,
+                 "mode": mode,
+             },
+             indent=2,
+         )
+
+
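A hedged usage sketch, creating a new project from an existing AI Catalog dataset (IDs are
placeholders; pass dataset_url instead of dataset_id to upload from a URL, never both):

    import asyncio

    result = asyncio.run(
        start_autopilot(
            target="churn",
            dataset_id="<dataset-id>",
            project_name="Churn modeling",
            mode="quick",
        )
    )
    print(result)  # JSON with project_id, status, and suggested UI panels
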
+ @dr_mcp_tool(tags={"training", "model", "evaluation"})
+ async def get_model_roc_curve(project_id: str, model_id: str, source: str = "validation") -> str:
+     """
+     Get detailed ROC curve for a specific model.
+
+     Args:
+         project_id: The ID of the DataRobot project
+         model_id: The ID of the model to analyze
+         source: The source of the data to use for the ROC curve ('validation', 'holdout',
+             or 'crossValidation')
+
+     Returns
+     -------
+     JSON string containing:
+     - roc_curve: ROC curve data
+     - ui_panel: List of recommended UI panels for visualization
+     """
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+
+     try:
+         roc_curve = model.get_roc_curve(source=source)
+         roc_data = {
+             "roc_points": [
+                 {
+                     "accuracy": point.get("accuracy", 0),
+                     "f1_score": point.get("f1_score", 0),
+                     "false_negative_score": point.get("false_negative_score", 0),
+                     "true_negative_score": point.get("true_negative_score", 0),
+                     "true_negative_rate": point.get("true_negative_rate", 0),
+                     "matthews_correlation_coefficient": point.get(
+                         "matthews_correlation_coefficient", 0
+                     ),
+                     "true_positive_score": point.get("true_positive_score", 0),
+                     "positive_predictive_value": point.get("positive_predictive_value", 0),
+                     "false_positive_score": point.get("false_positive_score", 0),
+                     "false_positive_rate": point.get("false_positive_rate", 0),
+                     "negative_predictive_value": point.get("negative_predictive_value", 0),
+                     "true_positive_rate": point.get("true_positive_rate", 0),
+                     "threshold": point.get("threshold", 0),
+                 }
+                 for point in roc_curve.roc_points
+             ],
+             "negative_class_predictions": roc_curve.negative_class_predictions,
+             "positive_class_predictions": roc_curve.positive_class_predictions,
+             "source": source,
+         }
+
+         return json.dumps({"data": roc_data, "ui_panel": ["roc-curve"]}, indent=2)
+     except Exception as e:
+         return json.dumps({"error": f"Failed to get ROC curve: {str(e)}"}, indent=2)
+
+
+ @dr_mcp_tool(tags={"training", "model", "evaluation"})
+ async def get_model_feature_impact(project_id: str, model_id: str) -> str:
+     """
+     Get detailed feature impact for a specific model.
+
+     Args:
+         project_id: The ID of the DataRobot project
+         model_id: The ID of the model to analyze
+
+     Returns
+     -------
+     JSON string containing:
+     - feature_impact: Feature importance scores
+     - ui_panel: List of recommended UI panels for visualization
+     """
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+
+     # Get feature impact; get_or_request_feature_impact computes it when needed, so a
+     # separate request_feature_impact call would be redundant and can raise if the job
+     # was already requested.
+     feature_impact = model.get_or_request_feature_impact()
+
+     return json.dumps({"data": feature_impact, "ui_panel": ["feature-impact"]}, indent=2)
+
+
+ @dr_mcp_tool(tags={"training", "model", "evaluation"})
+ async def get_model_lift_chart(project_id: str, model_id: str, source: str = "validation") -> str:
+     """
+     Get detailed lift chart for a specific model.
+
+     Args:
+         project_id: The ID of the DataRobot project
+         model_id: The ID of the model to analyze
+         source: The source of the data to use for the lift chart ('validation', 'holdout',
+             or 'crossValidation')
+
+     Returns
+     -------
+     JSON string containing:
+     - lift_chart: Lift chart data
+     - ui_panel: List of recommended UI panels for visualization
+     """
+     client = get_sdk_client()
+     project = client.Project.get(project_id)
+     model = client.Model.get(project=project, model_id=model_id)
+
+     # Get lift chart
+     lift_chart = model.get_lift_chart(source=source)
+
+     lift_chart_data = {
+         "bins": [
+             {
+                 "actual": chart_bin["actual"],
+                 "predicted": chart_bin["predicted"],
+                 "bin_weight": chart_bin["bin_weight"],
+             }
+             for chart_bin in lift_chart.bins  # renamed to avoid shadowing builtin bin()
+         ],
+         "source_model_id": lift_chart.source_model_id,
+         "target_class": lift_chart.target_class,
+     }
+
+     return json.dumps({"data": lift_chart_data, "ui_panel": ["lift-chart"]}, indent=2)
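The three evaluation tools share one calling pattern; a minimal sketch fetching all three
artifacts for a single model (placeholder IDs, direct-call style as above):

    import asyncio

    async def evaluate(project_id: str, model_id: str) -> None:
        print(await get_model_roc_curve(project_id, model_id, source="validation"))
        print(await get_model_feature_impact(project_id, model_id))
        print(await get_model_lift_chart(project_id, model_id, source="validation"))

    asyncio.run(evaluate("<project-id>", "<model-id>"))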