ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
|
@@ -0,0 +1,549 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Advanced Preprocessing Tools
|
|
3
|
+
Tools for handling imbalanced data, feature scaling, and strategic data splitting.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import polars as pl
|
|
7
|
+
import numpy as np
|
|
8
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import sys
|
|
11
|
+
import os
|
|
12
|
+
import joblib
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
warnings.filterwarnings('ignore')
|
|
16
|
+
|
|
17
|
+
# Add parent directory to path for imports
|
|
18
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
19
|
+
|
|
20
|
+
from sklearn.model_selection import train_test_split
|
|
21
|
+
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, LabelEncoder
|
|
22
|
+
from imblearn.over_sampling import SMOTE, ADASYN, BorderlineSMOTE
|
|
23
|
+
from imblearn.under_sampling import RandomUnderSampler, TomekLinks, EditedNearestNeighbours
|
|
24
|
+
from imblearn.combine import SMOTETomek, SMOTEENN
|
|
25
|
+
from collections import Counter
|
|
26
|
+
|
|
27
|
+
from ds_agent.utils.polars_helpers import (
|
|
28
|
+
load_dataframe, save_dataframe, get_numeric_columns,
|
|
29
|
+
get_categorical_columns, split_features_target
|
|
30
|
+
)
|
|
31
|
+
from ds_agent.utils.validation import (
|
|
32
|
+
validate_file_exists, validate_file_format, validate_dataframe,
|
|
33
|
+
validate_column_exists
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def handle_imbalanced_data(
    file_path: str,
    target_col: str,
    strategy: str = "smote",
    sampling_ratio: float = 1.0,
    output_path: str = None,
    random_state: int = 42
) -> Dict[str, Any]:
    """
    Handle imbalanced datasets using various resampling techniques.

    Args:
        file_path: Path to dataset
        target_col: Target column name
        strategy: Resampling strategy:
            - 'smote': Synthetic Minority Over-sampling (SMOTE)
            - 'adasyn': Adaptive Synthetic Sampling
            - 'borderline_smote': Borderline SMOTE variant
            - 'random_undersample': Random undersampling
            - 'tomek': Tomek Links undersampling
            - 'smote_tomek': Combined SMOTE + Tomek Links
            - 'smote_enn': Combined SMOTE + Edited Nearest Neighbours
            - 'class_weights': Return class weights (no resampling)
        sampling_ratio: Ratio of minority to majority class (0.5 = 50%, 1.0 = 100%)
        output_path: Path to save balanced dataset
        random_state: Random seed

    Returns:
        Dictionary with balancing results and class distributions.
        Returns a 'skipped' status without resampling when the dataset is
        already close to balanced (imbalance ratio < 1.5).

    Raises:
        ValueError: If `strategy` is not one of the supported options.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)
    validate_column_exists(df, target_col)

    # Get original class distribution.
    # BUG FIX: polars Series.value_counts() returns a DataFrame, so calling
    # .to_dict() on it yields {column_name: Series}, not {class: count} —
    # the subsequent max()/min() over those values then fails. Counting the
    # raw values with Counter produces a real class -> count mapping.
    original_counts = dict(sorted(Counter(df[target_col].to_list()).items()))

    print(f"📊 Original class distribution: {original_counts}")

    # Calculate imbalance ratio (majority count / minority count)
    class_counts = list(original_counts.values())
    imbalance_ratio = max(class_counts) / min(class_counts)

    if imbalance_ratio < 1.5:
        return {
            'status': 'skipped',
            'message': 'Dataset is already balanced (ratio < 1.5)',
            'original_distribution': original_counts,
            'imbalance_ratio': float(imbalance_ratio)
        }

    # Prepare data
    X, y = split_features_target(df, target_col)

    # Handle class weights strategy (no resampling)
    if strategy == "class_weights":
        from sklearn.utils.class_weight import compute_class_weight
        classes = np.unique(y)
        weights = compute_class_weight('balanced', classes=classes, y=y)
        class_weights = dict(zip(classes, weights))

        return {
            'status': 'success',
            'strategy': 'class_weights',
            'class_weights': {str(k): float(v) for k, v in class_weights.items()},
            'original_distribution': original_counts,
            'imbalance_ratio': float(imbalance_ratio),
            'recommendation': 'Use class_weight parameter in your model training'
        }

    # Create resampler based on strategy; a ratio >= 1.0 means "fully balance"
    sampling_strategy = sampling_ratio if sampling_ratio < 1.0 else 'auto'

    if strategy == "smote":
        resampler = SMOTE(sampling_strategy=sampling_strategy, random_state=random_state)
    elif strategy == "adasyn":
        resampler = ADASYN(sampling_strategy=sampling_strategy, random_state=random_state)
    elif strategy == "borderline_smote":
        resampler = BorderlineSMOTE(sampling_strategy=sampling_strategy, random_state=random_state)
    elif strategy == "random_undersample":
        resampler = RandomUnderSampler(sampling_strategy=sampling_strategy, random_state=random_state)
    elif strategy == "tomek":
        # TomekLinks is a cleaning method — it does not accept a float ratio
        resampler = TomekLinks(sampling_strategy='auto')
    elif strategy == "smote_tomek":
        resampler = SMOTETomek(sampling_strategy=sampling_strategy, random_state=random_state)
    elif strategy == "smote_enn":
        resampler = SMOTEENN(sampling_strategy=sampling_strategy, random_state=random_state)
    else:
        raise ValueError(f"Unsupported strategy: {strategy}")

    # Perform resampling
    print(f"⚖️ Applying {strategy} resampling...")
    X_resampled, y_resampled = resampler.fit_resample(X, y)

    # Get new class distribution
    new_counts = dict(sorted(Counter(y_resampled).items()))

    print(f"✅ New class distribution: {new_counts}")

    # Calculate changes
    total_original = sum(original_counts.values())
    total_new = sum(new_counts.values())

    changes = {
        str(cls): {
            'original': original_counts.get(cls, 0),
            'new': new_counts.get(cls, 0),
            'change': new_counts.get(cls, 0) - original_counts.get(cls, 0)
        }
        for cls in set(list(original_counts.keys()) + list(new_counts.keys()))
    }

    # Create balanced dataframe (features first, target last, as in the input)
    feature_cols = [col for col in df.columns if col != target_col]
    balanced_data = {col: X_resampled[:, i] for i, col in enumerate(feature_cols)}
    balanced_data[target_col] = y_resampled

    balanced_df = pl.DataFrame(balanced_data)

    # Save if output path provided
    if output_path:
        save_dataframe(balanced_df, output_path)
        print(f"💾 Balanced dataset saved to: {output_path}")

    return {
        'status': 'success',
        'strategy': strategy,
        'original_distribution': original_counts,
        'new_distribution': new_counts,
        'changes_by_class': changes,
        'total_samples_before': total_original,
        'total_samples_after': total_new,
        'sample_change': f"{'+' if total_new > total_original else ''}{total_new - total_original}",
        'new_imbalance_ratio': float(max(new_counts.values()) / min(new_counts.values())),
        'output_path': output_path
    }
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def perform_feature_scaling(
    file_path: str,
    scaler_type: str = "standard",
    columns: Optional[List[str]] = None,
    output_path: Optional[str] = None,
    scaler_save_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Scale features using various normalization techniques.

    Args:
        file_path: Path to dataset
        scaler_type: Scaling method:
            - 'standard': StandardScaler (mean=0, std=1)
            - 'minmax': MinMaxScaler (range 0-1)
            - 'robust': RobustScaler (median, IQR - robust to outliers)
            - 'power': PowerTransformer (Yeo-Johnson, makes data more Gaussian)
            - 'quantile': QuantileTransformer (uniform or normal output distribution)
        columns: List of columns to scale (None = all numeric columns)
        output_path: Path to save scaled dataset
        scaler_save_path: Path to save fitted scaler for future use

    Returns:
        Dictionary with scaling statistics (per-column before/after
        mean, std, min, max, median).

    Raises:
        ValueError: If `scaler_type` is not a supported method.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    # Get numeric columns if not specified
    if columns is None:
        columns = get_numeric_columns(df)
        print(f"🔢 Auto-detected {len(columns)} numeric columns for scaling")
    else:
        for col in columns:
            validate_column_exists(df, col)

    if not columns:
        return {
            'status': 'skipped',
            'message': 'No numeric columns found to scale'
        }

    # Create scaler
    if scaler_type == "standard":
        scaler = StandardScaler()
    elif scaler_type == "minmax":
        scaler = MinMaxScaler()
    elif scaler_type == "robust":
        scaler = RobustScaler()
    elif scaler_type == "power":
        from sklearn.preprocessing import PowerTransformer
        scaler = PowerTransformer(method='yeo-johnson', standardize=True)
        print("  📐 Using Yeo-Johnson PowerTransformer (makes data more Gaussian)")
    elif scaler_type == "quantile":
        from sklearn.preprocessing import QuantileTransformer
        # n_quantiles must not exceed the number of samples
        scaler = QuantileTransformer(output_distribution='normal', random_state=42, n_quantiles=min(1000, len(df)))
        print("  📐 Using QuantileTransformer (maps to normal distribution)")
    else:
        raise ValueError(f"Unsupported scaler_type: {scaler_type}. Use 'standard', 'minmax', 'robust', 'power', or 'quantile'.")

    # Get original statistics for the before/after report
    original_stats = {}
    for col in columns:
        col_data = df[col].to_numpy()
        original_stats[col] = {
            'mean': float(np.mean(col_data)),
            'std': float(np.std(col_data)),
            'min': float(np.min(col_data)),
            'max': float(np.max(col_data)),
            'median': float(np.median(col_data))
        }

    # Fit and transform
    print(f"📏 Applying {scaler_type} scaling to {len(columns)} columns...")
    scaled_data = scaler.fit_transform(df[columns].to_numpy())

    # Create scaled dataframe (non-selected columns are left untouched)
    df_scaled = df.clone()
    for i, col in enumerate(columns):
        df_scaled = df_scaled.with_columns(
            pl.Series(col, scaled_data[:, i])
        )

    # Get new statistics
    new_stats = {}
    for i, col in enumerate(columns):
        new_stats[col] = {
            'mean': float(np.mean(scaled_data[:, i])),
            'std': float(np.std(scaled_data[:, i])),
            'min': float(np.min(scaled_data[:, i])),
            'max': float(np.max(scaled_data[:, i])),
            'median': float(np.median(scaled_data[:, i]))
        }

    # Save scaled data
    if output_path:
        save_dataframe(df_scaled, output_path)
        print(f"💾 Scaled dataset saved to: {output_path}")

    # Save scaler
    if scaler_save_path:
        # BUG FIX: os.path.dirname() returns "" for a bare filename, and
        # os.makedirs("") raises FileNotFoundError — only create the
        # directory when the path actually contains one.
        scaler_dir = os.path.dirname(scaler_save_path)
        if scaler_dir:
            os.makedirs(scaler_dir, exist_ok=True)
        joblib.dump(scaler, scaler_save_path)
        print(f"💾 Scaler saved to: {scaler_save_path}")

    return {
        'status': 'success',
        'scaler_type': scaler_type,
        'columns_scaled': columns,
        'n_columns': len(columns),
        'original_stats': original_stats,
        'scaled_stats': new_stats,
        'output_path': output_path,
        'scaler_path': scaler_save_path
    }
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def split_data_strategically(
    file_path: str,
    target_col: Optional[str] = None,
    split_type: str = "train_test",
    test_size: float = 0.2,
    val_size: float = 0.1,
    stratify: bool = True,
    time_col: Optional[str] = None,
    group_col: Optional[str] = None,
    random_state: int = 42,
    output_dir: Optional[str] = None
) -> Dict[str, Any]:
    """
    Perform strategic data splitting with multiple options.

    Args:
        file_path: Path to dataset
        target_col: Target column (for stratification)
        split_type: Split strategy:
            - 'train_test': Train/test split
            - 'train_val_test': Train/validation/test split
            - 'time_based': Time-based split (requires time_col)
            - 'group_based': Group-based split (requires group_col, prevents leakage)
        test_size: Test set proportion
        val_size: Validation set proportion (for train_val_test)
        stratify: Whether to stratify by target
        time_col: Column to use for time-based splitting
        group_col: Column to use for group-based splitting
        random_state: Random seed
        output_dir: Directory to save split datasets

    Returns:
        Dictionary with split information and file paths. Path entries are
        None when output_dir is not provided (splits are still computed).

    Raises:
        ValueError: If split_type is unknown, or a required column argument
            (time_col / group_col) is missing for the chosen split.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    if target_col:
        validate_column_exists(df, target_col)

    n_samples = len(df)

    # BUG FIX: previously the split file paths (and, for time_based, even the
    # split dataframes themselves) were only created inside `if output_dir:`,
    # so calling any split without output_dir raised NameError (or silently
    # returned None). Paths default to None and are filled only on save.
    train_path: Optional[str] = None
    val_path: Optional[str] = None
    test_path: Optional[str] = None

    # Time-based split
    if split_type == "time_based":
        if not time_col:
            raise ValueError("time_col is required for time_based split")
        validate_column_exists(df, time_col)

        # Sort chronologically so the test set is strictly "the future"
        df = df.sort(time_col)

        # Calculate split point: first (1 - test_size) fraction is train
        test_idx = int(n_samples * (1 - test_size))
        train_df = df[:test_idx]
        test_df = df[test_idx:]

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            train_path = os.path.join(output_dir, "train.csv")
            test_path = os.path.join(output_dir, "test.csv")

            save_dataframe(train_df, train_path)
            save_dataframe(test_df, test_path)

        print(f"✅ Time-based split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'time_based',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'train_path': train_path,
            'test_path': test_path,
            'time_column': time_col
        }

    # Group-based split: every row of a group lands in the same partition,
    # preventing group leakage between train and test
    elif split_type == "group_based":
        if not group_col:
            raise ValueError("group_col is required for group_based split")
        validate_column_exists(df, group_col)

        # Get unique groups
        unique_groups = df[group_col].unique().to_list()
        n_groups = len(unique_groups)

        # Shuffle groups deterministically, then carve off whole groups
        np.random.seed(random_state)
        np.random.shuffle(unique_groups)

        test_n_groups = max(1, int(n_groups * test_size))
        test_groups = unique_groups[:test_n_groups]
        train_groups = unique_groups[test_n_groups:]

        train_df = df.filter(pl.col(group_col).is_in(train_groups))
        test_df = df.filter(pl.col(group_col).is_in(test_groups))

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            train_path = os.path.join(output_dir, "train.csv")
            test_path = os.path.join(output_dir, "test.csv")

            save_dataframe(train_df, train_path)
            save_dataframe(test_df, test_path)

        print(f"✅ Group-based split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'group_based',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'train_groups': len(train_groups),
            'test_groups': len(test_groups),
            'train_path': train_path,
            'test_path': test_path,
            'group_column': group_col
        }

    # Standard train/test split
    elif split_type == "train_test":
        X, y = split_features_target(df, target_col) if target_col else (df.to_numpy(), None)

        # Only stratify classification-like targets (< 20 distinct values)
        stratify_y = y if (stratify and target_col and len(np.unique(y)) < 20) else None

        if target_col:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state, stratify=stratify_y
            )

            # Reconstruct dataframes (features first, target last)
            feature_cols = [col for col in df.columns if col != target_col]
            train_data = {col: X_train[:, i] for i, col in enumerate(feature_cols)}
            train_data[target_col] = y_train
            train_df = pl.DataFrame(train_data)

            test_data = {col: X_test[:, i] for i, col in enumerate(feature_cols)}
            test_data[target_col] = y_test
            test_df = pl.DataFrame(test_data)
        else:
            # No target: split row indices so the dataframe keeps its dtypes
            indices = np.arange(len(df))
            train_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=random_state
            )
            train_df = df[train_idx]
            test_df = df[test_idx]

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            train_path = os.path.join(output_dir, "train.csv")
            test_path = os.path.join(output_dir, "test.csv")

            save_dataframe(train_df, train_path)
            save_dataframe(test_df, test_path)

        print(f"✅ Train/test split: train={len(train_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'train_test',
            'train_size': len(train_df),
            'test_size': len(test_df),
            'stratified': bool(stratify_y is not None),
            'train_path': train_path,
            'test_path': test_path
        }

    # Train/val/test split
    elif split_type == "train_val_test":
        X, y = split_features_target(df, target_col) if target_col else (df.to_numpy(), None)

        stratify_y = y if (stratify and target_col and len(np.unique(y)) < 20) else None

        # First split: train+val vs test
        if target_col:
            X_temp, X_test, y_temp, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state, stratify=stratify_y
            )

            # Second split: train vs val (val_size rescaled to the remainder)
            val_ratio = val_size / (1 - test_size)
            stratify_temp = y_temp if stratify_y is not None else None
            X_train, X_val, y_train, y_val = train_test_split(
                X_temp, y_temp, test_size=val_ratio, random_state=random_state, stratify=stratify_temp
            )

            # Reconstruct dataframes
            feature_cols = [col for col in df.columns if col != target_col]

            train_data = {col: X_train[:, i] for i, col in enumerate(feature_cols)}
            train_data[target_col] = y_train
            train_df = pl.DataFrame(train_data)

            val_data = {col: X_val[:, i] for i, col in enumerate(feature_cols)}
            val_data[target_col] = y_val
            val_df = pl.DataFrame(val_data)

            test_data = {col: X_test[:, i] for i, col in enumerate(feature_cols)}
            test_data[target_col] = y_test
            test_df = pl.DataFrame(test_data)
        else:
            indices = np.arange(len(df))
            temp_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=random_state
            )
            val_ratio = val_size / (1 - test_size)
            train_idx, val_idx = train_test_split(
                temp_idx, test_size=val_ratio, random_state=random_state
            )

            train_df = df[train_idx]
            val_df = df[val_idx]
            test_df = df[test_idx]

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            train_path = os.path.join(output_dir, "train.csv")
            val_path = os.path.join(output_dir, "val.csv")
            test_path = os.path.join(output_dir, "test.csv")

            save_dataframe(train_df, train_path)
            save_dataframe(val_df, val_path)
            save_dataframe(test_df, test_path)

        print(f"✅ Train/val/test split: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")

        return {
            'status': 'success',
            'split_type': 'train_val_test',
            'train_size': len(train_df),
            'val_size': len(val_df),
            'test_size': len(test_df),
            'stratified': bool(stratify_y is not None),
            'train_path': train_path,
            'val_path': val_path,
            'test_path': test_path
        }

    else:
        raise ValueError(f"Unsupported split_type: {split_type}")
|