ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,549 @@
1
+ """
2
+ Advanced Preprocessing Tools
3
+ Tools for handling imbalanced data, feature scaling, and strategic data splitting.
4
+ """
5
+
6
+ import polars as pl
7
+ import numpy as np
8
+ from typing import Dict, Any, List, Optional, Tuple
9
+ from pathlib import Path
10
+ import sys
11
+ import os
12
+ import joblib
13
+ import warnings
14
+
15
+ warnings.filterwarnings('ignore')
16
+
17
+ # Add parent directory to path for imports
18
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+
20
+ from sklearn.model_selection import train_test_split
21
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, LabelEncoder
22
+ from imblearn.over_sampling import SMOTE, ADASYN, BorderlineSMOTE
23
+ from imblearn.under_sampling import RandomUnderSampler, TomekLinks, EditedNearestNeighbours
24
+ from imblearn.combine import SMOTETomek, SMOTEENN
25
+ from collections import Counter
26
+
27
+ from ds_agent.utils.polars_helpers import (
28
+ load_dataframe, save_dataframe, get_numeric_columns,
29
+ get_categorical_columns, split_features_target
30
+ )
31
+ from ds_agent.utils.validation import (
32
+ validate_file_exists, validate_file_format, validate_dataframe,
33
+ validate_column_exists
34
+ )
35
+
36
+
37
def handle_imbalanced_data(
    file_path: str,
    target_col: str,
    strategy: str = "smote",
    sampling_ratio: float = 1.0,
    output_path: str = None,
    random_state: int = 42
) -> Dict[str, Any]:
    """
    Handle imbalanced datasets using various resampling techniques.

    Args:
        file_path: Path to dataset
        target_col: Target column name
        strategy: Resampling strategy:
            - 'smote': Synthetic Minority Over-sampling (SMOTE)
            - 'adasyn': Adaptive Synthetic Sampling
            - 'borderline_smote': Borderline SMOTE variant
            - 'random_undersample': Random undersampling
            - 'tomek': Tomek Links undersampling
            - 'smote_tomek': Combined SMOTE + Tomek Links
            - 'smote_enn': Combined SMOTE + Edited Nearest Neighbours
            - 'class_weights': Return class weights (no resampling)
        sampling_ratio: Ratio of minority to majority class (0.5 = 50%, 1.0 = 100%)
        output_path: Path to save balanced dataset (optional)
        random_state: Random seed

    Returns:
        Dictionary with balancing results and class distributions.
        Status is 'skipped' when the data is already near-balanced
        (majority/minority ratio < 1.5).

    Raises:
        ValueError: If `strategy` is not one of the supported values.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)
    validate_column_exists(df, target_col)

    # BUG FIX: in polars, Series.value_counts().to_dict() yields a mapping of
    # COLUMN NAMES to Series ({'<target>': ..., 'count': ...}), not a
    # class -> count mapping, which broke the max()/min() ratio computation
    # below. Build the class -> count mapping explicitly instead.
    original_counts = dict(sorted(Counter(df[target_col].to_list()).items()))

    print(f"📊 Original class distribution: {original_counts}")

    # Ratio of the largest class to the smallest one.
    class_counts = list(original_counts.values())
    imbalance_ratio = max(class_counts) / min(class_counts)

    if imbalance_ratio < 1.5:
        return {
            'status': 'skipped',
            'message': 'Dataset is already balanced (ratio < 1.5)',
            'original_distribution': original_counts,
            'imbalance_ratio': float(imbalance_ratio)
        }

    # Prepare data
    X, y = split_features_target(df, target_col)

    # Class-weights strategy: no resampling, just report balanced weights.
    if strategy == "class_weights":
        from sklearn.utils.class_weight import compute_class_weight
        classes = np.unique(y)
        weights = compute_class_weight('balanced', classes=classes, y=y)
        class_weights = dict(zip(classes, weights))

        return {
            'status': 'success',
            'strategy': 'class_weights',
            'class_weights': {str(k): float(v) for k, v in class_weights.items()},
            'original_distribution': original_counts,
            'imbalance_ratio': float(imbalance_ratio),
            'recommendation': 'Use class_weight parameter in your model training'
        }

    # imblearn accepts a float ratio only for some samplers/binary problems;
    # 'auto' is the safe default when a full 1:1 balance is requested.
    sampling_strategy = sampling_ratio if sampling_ratio < 1.0 else 'auto'

    # Lazily-constructed resampler dispatch keeps the branching flat.
    resampler_factories = {
        "smote": lambda: SMOTE(sampling_strategy=sampling_strategy, random_state=random_state),
        "adasyn": lambda: ADASYN(sampling_strategy=sampling_strategy, random_state=random_state),
        "borderline_smote": lambda: BorderlineSMOTE(sampling_strategy=sampling_strategy, random_state=random_state),
        "random_undersample": lambda: RandomUnderSampler(sampling_strategy=sampling_strategy, random_state=random_state),
        # Tomek Links removes boundary samples; it does not take a ratio.
        "tomek": lambda: TomekLinks(sampling_strategy='auto'),
        "smote_tomek": lambda: SMOTETomek(sampling_strategy=sampling_strategy, random_state=random_state),
        "smote_enn": lambda: SMOTEENN(sampling_strategy=sampling_strategy, random_state=random_state),
    }
    if strategy not in resampler_factories:
        raise ValueError(f"Unsupported strategy: {strategy}")
    resampler = resampler_factories[strategy]()

    # Perform resampling
    print(f"⚖️ Applying {strategy} resampling...")
    X_resampled, y_resampled = resampler.fit_resample(X, y)

    # New class distribution after resampling.
    new_counts = dict(sorted(Counter(y_resampled).items()))

    print(f"✅ New class distribution: {new_counts}")

    # Per-class delta between original and resampled distributions.
    total_original = sum(original_counts.values())
    total_new = sum(new_counts.values())

    changes = {
        str(cls): {
            'original': original_counts.get(cls, 0),
            'new': new_counts.get(cls, 0),
            'change': new_counts.get(cls, 0) - original_counts.get(cls, 0)
        }
        for cls in set(original_counts) | set(new_counts)
    }

    # Reassemble a dataframe with the original column layout.
    feature_cols = [col for col in df.columns if col != target_col]
    balanced_data = {col: X_resampled[:, i] for i, col in enumerate(feature_cols)}
    balanced_data[target_col] = y_resampled

    balanced_df = pl.DataFrame(balanced_data)

    # Save if output path provided
    if output_path:
        save_dataframe(balanced_df, output_path)
        print(f"💾 Balanced dataset saved to: {output_path}")

    return {
        'status': 'success',
        'strategy': strategy,
        'original_distribution': original_counts,
        'new_distribution': new_counts,
        'changes_by_class': changes,
        'total_samples_before': total_original,
        'total_samples_after': total_new,
        'sample_change': f"{'+' if total_new > total_original else ''}{total_new - total_original}",
        'new_imbalance_ratio': float(max(new_counts.values()) / min(new_counts.values())),
        'output_path': output_path
    }
180
+
181
+
182
def perform_feature_scaling(
    file_path: str,
    scaler_type: str = "standard",
    columns: Optional[List[str]] = None,
    output_path: Optional[str] = None,
    scaler_save_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Scale features using various normalization techniques.

    Args:
        file_path: Path to dataset
        scaler_type: Scaling method:
            - 'standard': StandardScaler (mean=0, std=1)
            - 'minmax': MinMaxScaler (range 0-1)
            - 'robust': RobustScaler (median, IQR - robust to outliers)
            - 'power': PowerTransformer (Yeo-Johnson, makes data more Gaussian)
            - 'quantile': QuantileTransformer (uniform or normal output distribution)
        columns: List of columns to scale (None = all numeric columns)
        output_path: Path to save scaled dataset
        scaler_save_path: Path to save fitted scaler for future use

    Returns:
        Dictionary with before/after scaling statistics per column.

    Raises:
        ValueError: If `scaler_type` is not one of the supported values.
    """

    def _describe(values: np.ndarray) -> Dict[str, float]:
        # Summary statistics reported before/after scaling.
        return {
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'min': float(np.min(values)),
            'max': float(np.max(values)),
            'median': float(np.median(values))
        }

    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    # Get numeric columns if not specified
    if columns is None:
        columns = get_numeric_columns(df)
        print(f"🔢 Auto-detected {len(columns)} numeric columns for scaling")
    else:
        for col in columns:
            validate_column_exists(df, col)

    if not columns:
        return {
            'status': 'skipped',
            'message': 'No numeric columns found to scale'
        }

    # Create scaler
    if scaler_type == "standard":
        scaler = StandardScaler()
    elif scaler_type == "minmax":
        scaler = MinMaxScaler()
    elif scaler_type == "robust":
        scaler = RobustScaler()
    elif scaler_type == "power":
        from sklearn.preprocessing import PowerTransformer
        scaler = PowerTransformer(method='yeo-johnson', standardize=True)
        print("  📐 Using Yeo-Johnson PowerTransformer (makes data more Gaussian)")
    elif scaler_type == "quantile":
        from sklearn.preprocessing import QuantileTransformer
        # n_quantiles must not exceed the number of samples.
        scaler = QuantileTransformer(output_distribution='normal', random_state=42, n_quantiles=min(1000, len(df)))
        print("  📐 Using QuantileTransformer (maps to normal distribution)")
    else:
        raise ValueError(f"Unsupported scaler_type: {scaler_type}. Use 'standard', 'minmax', 'robust', 'power', or 'quantile'.")

    # Statistics of the raw columns, before fitting the scaler.
    original_stats = {col: _describe(df[col].to_numpy()) for col in columns}

    # Fit and transform
    print(f"📏 Applying {scaler_type} scaling to {len(columns)} columns...")
    scaled_data = scaler.fit_transform(df[columns].to_numpy())

    # Create scaled dataframe (replace each column in place).
    df_scaled = df.clone()
    for i, col in enumerate(columns):
        df_scaled = df_scaled.with_columns(
            pl.Series(col, scaled_data[:, i])
        )

    # Statistics of the transformed columns.
    new_stats = {col: _describe(scaled_data[:, i]) for i, col in enumerate(columns)}

    # Save scaled data
    if output_path:
        save_dataframe(df_scaled, output_path)
        print(f"💾 Scaled dataset saved to: {output_path}")

    # Save scaler
    if scaler_save_path:
        # BUG FIX: os.makedirs('') raises FileNotFoundError when the save
        # path has no directory component; only create the parent directory
        # when one actually exists.
        scaler_dir = os.path.dirname(scaler_save_path)
        if scaler_dir:
            os.makedirs(scaler_dir, exist_ok=True)
        joblib.dump(scaler, scaler_save_path)
        print(f"💾 Scaler saved to: {scaler_save_path}")

    return {
        'status': 'success',
        'scaler_type': scaler_type,
        'columns_scaled': columns,
        'n_columns': len(columns),
        'original_stats': original_stats,
        'scaled_stats': new_stats,
        'output_path': output_path,
        'scaler_path': scaler_save_path
    }
302
+
303
+
304
def split_data_strategically(
    file_path: str,
    target_col: Optional[str] = None,
    split_type: str = "train_test",
    test_size: float = 0.2,
    val_size: float = 0.1,
    stratify: bool = True,
    time_col: Optional[str] = None,
    group_col: Optional[str] = None,
    random_state: int = 42,
    output_dir: Optional[str] = None
) -> Dict[str, Any]:
    """
    Perform strategic data splitting with multiple options.

    Args:
        file_path: Path to dataset
        target_col: Target column (for stratification)
        split_type: Split strategy:
            - 'train_test': Train/test split
            - 'train_val_test': Train/validation/test split
            - 'time_based': Time-based split (requires time_col)
            - 'group_based': Group-based split (requires group_col, prevents leakage)
        test_size: Test set proportion
        val_size: Validation set proportion (for train_val_test)
        stratify: Whether to stratify by target
        time_col: Column to use for time-based splitting
        group_col: Column to use for group-based splitting
        random_state: Random seed
        output_dir: Directory to save split datasets; when omitted, the split
            summary is still returned but the *_path entries are None

    Returns:
        Dictionary with split information and (optional) file paths.

    Raises:
        ValueError: If split_type is unknown, or the required column argument
            (time_col / group_col) is missing for the chosen strategy.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    if target_col:
        validate_column_exists(df, target_col)

    n_samples = len(df)

    # BUG FIX: previously these names were only assigned inside
    # `if output_dir:` blocks, so calling without output_dir either raised
    # NameError in the return dict or silently returned None (time_based).
    # They now default to None and are filled in only when files are written.
    train_path: Optional[str] = None
    val_path: Optional[str] = None
    test_path: Optional[str] = None

    def _save_splits(**named_frames) -> Dict[str, str]:
        # Persist each split as <name>.csv under output_dir; returns paths.
        os.makedirs(output_dir, exist_ok=True)
        paths = {}
        for split_name, frame in named_frames.items():
            path = os.path.join(output_dir, f"{split_name}.csv")
            save_dataframe(frame, path)
            paths[split_name] = path
        return paths

    def _rebuild_frame(feature_cols: List[str], X_arr, tgt: str, y_arr) -> pl.DataFrame:
        # Rebuild a polars frame from a split feature matrix + target vector.
        data = {col: X_arr[:, i] for i, col in enumerate(feature_cols)}
        data[tgt] = y_arr
        return pl.DataFrame(data)

    # Time-based split
    if split_type == "time_based":
        if not time_col:
            raise ValueError("time_col is required for time_based split")
        validate_column_exists(df, time_col)

        # Chronological order, then the trailing fraction becomes the test set.
        df = df.sort(time_col)
        cut = int(n_samples * (1 - test_size))
        train_df, test_df = df[:cut], df[cut:]

        if output_dir:
            paths = _save_splits(train=train_df, test=test_df)
            train_path, test_path = paths['train'], paths['test']

        print(f"✅ Time-based split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'time_based',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'train_path': train_path,
            'test_path': test_path,
            'time_column': time_col
        }

    # Group-based split
    elif split_type == "group_based":
        if not group_col:
            raise ValueError("group_col is required for group_based split")
        validate_column_exists(df, group_col)

        # Assign whole groups to one side so no group leaks across the split.
        unique_groups = df[group_col].unique().to_list()
        n_groups = len(unique_groups)

        np.random.seed(random_state)
        np.random.shuffle(unique_groups)

        # At least one group always lands in the test set.
        test_n_groups = max(1, int(n_groups * test_size))
        test_groups = unique_groups[:test_n_groups]
        train_groups = unique_groups[test_n_groups:]

        train_df = df.filter(pl.col(group_col).is_in(train_groups))
        test_df = df.filter(pl.col(group_col).is_in(test_groups))

        if output_dir:
            paths = _save_splits(train=train_df, test=test_df)
            train_path, test_path = paths['train'], paths['test']

        print(f"✅ Group-based split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'group_based',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'train_groups': len(train_groups),
            'test_groups': len(test_groups),
            'train_path': train_path,
            'test_path': test_path,
            'group_column': group_col
        }

    # Standard train/test split
    elif split_type == "train_test":
        X, y = split_features_target(df, target_col) if target_col else (df.to_numpy(), None)

        # Stratify only for classification-like targets (< 20 distinct values).
        stratify_y = y if (stratify and target_col and len(np.unique(y)) < 20) else None

        if target_col:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state, stratify=stratify_y
            )
            feature_cols = [col for col in df.columns if col != target_col]
            train_df = _rebuild_frame(feature_cols, X_train, target_col, y_train)
            test_df = _rebuild_frame(feature_cols, X_test, target_col, y_test)
        else:
            # No target: split row indices and slice the frame directly.
            indices = np.arange(len(df))
            train_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=random_state
            )
            train_df, test_df = df[train_idx], df[test_idx]

        if output_dir:
            paths = _save_splits(train=train_df, test=test_df)
            train_path, test_path = paths['train'], paths['test']

        print(f"✅ Train/test split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'train_test',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'stratified': bool(stratify_y is not None),
            'train_path': train_path,
            'test_path': test_path
        }

    # Train/val/test split
    elif split_type == "train_val_test":
        X, y = split_features_target(df, target_col) if target_col else (df.to_numpy(), None)

        stratify_y = y if (stratify and target_col and len(np.unique(y)) < 20) else None

        # Second-stage ratio so val_size stays a fraction of the FULL dataset.
        val_ratio = val_size / (1 - test_size)

        if target_col:
            # First split: train+val vs test
            X_temp, X_test, y_temp, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state, stratify=stratify_y
            )

            # Second split: train vs val
            stratify_temp = y_temp if stratify_y is not None else None
            X_train, X_val, y_train, y_val = train_test_split(
                X_temp, y_temp, test_size=val_ratio, random_state=random_state, stratify=stratify_temp
            )

            feature_cols = [col for col in df.columns if col != target_col]
            train_df = _rebuild_frame(feature_cols, X_train, target_col, y_train)
            val_df = _rebuild_frame(feature_cols, X_val, target_col, y_val)
            test_df = _rebuild_frame(feature_cols, X_test, target_col, y_test)
        else:
            indices = np.arange(len(df))
            temp_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=random_state
            )
            train_idx, val_idx = train_test_split(
                temp_idx, test_size=val_ratio, random_state=random_state
            )
            train_df, val_df, test_df = df[train_idx], df[val_idx], df[test_idx]

        if output_dir:
            paths = _save_splits(train=train_df, val=val_df, test=test_df)
            train_path, val_path, test_path = paths['train'], paths['val'], paths['test']

        print(f"✅ Train/val/test split: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'train_val_test',
            'train_size': len(train_df),
            'val_size': len(val_df),
            'test_size': len(test_df),
            'stratified': bool(stratify_y is not None),
            'train_path': train_path,
            'val_path': val_path,
            'test_path': test_path
        }

    else:
        raise ValueError(f"Unsupported split_type: {split_type}")