balancr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- balancr/__init__.py +13 -0
- balancr/base.py +14 -0
- balancr/classifier_registry.py +300 -0
- balancr/cli/__init__.py +0 -0
- balancr/cli/commands.py +1838 -0
- balancr/cli/config.py +165 -0
- balancr/cli/main.py +778 -0
- balancr/cli/utils.py +101 -0
- balancr/data/__init__.py +5 -0
- balancr/data/loader.py +59 -0
- balancr/data/preprocessor.py +556 -0
- balancr/evaluation/__init__.py +19 -0
- balancr/evaluation/metrics.py +442 -0
- balancr/evaluation/visualisation.py +660 -0
- balancr/imbalance_analyser.py +677 -0
- balancr/technique_registry.py +284 -0
- balancr/techniques/__init__.py +4 -0
- balancr/techniques/custom/__init__.py +0 -0
- balancr/techniques/custom/example_custom_technique.py +27 -0
- balancr-0.1.0.dist-info/LICENSE +21 -0
- balancr-0.1.0.dist-info/METADATA +536 -0
- balancr-0.1.0.dist-info/RECORD +25 -0
- balancr-0.1.0.dist-info/WHEEL +5 -0
- balancr-0.1.0.dist-info/entry_points.txt +2 -0
- balancr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,677 @@
|
|
1
|
+
import time
|
2
|
+
from typing import Dict, List, Optional, Union, Any
|
3
|
+
import os
|
4
|
+
import logging
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
from pathlib import Path
|
8
|
+
from sklearn.model_selection import train_test_split
|
9
|
+
|
10
|
+
from .technique_registry import TechniqueRegistry
|
11
|
+
from .data import DataLoader
|
12
|
+
from .classifier_registry import ClassifierRegistry
|
13
|
+
from .data import DataPreprocessor
|
14
|
+
from .evaluation import (
|
15
|
+
get_metrics,
|
16
|
+
get_cv_scores,
|
17
|
+
get_learning_curve_data_multiple_techniques,
|
18
|
+
)
|
19
|
+
from .evaluation import (
|
20
|
+
plot_class_distribution,
|
21
|
+
plot_class_distributions_comparison,
|
22
|
+
plot_comparison_results,
|
23
|
+
plot_learning_curves,
|
24
|
+
)
|
25
|
+
|
26
|
+
|
27
|
+
def format_time(seconds):
    """Render a duration in seconds as ``"<M>mins, <S.SS>secs"``.

    Whole minutes come from floor division by 60; the leftover seconds keep
    their fractional part and are shown with two decimal places.
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}mins, {leftover:.2f}secs"
|
32
|
+
|
33
|
+
|
34
|
+
class BalancingFramework:
    """
    A unified framework for analysing and comparing different techniques
    for handling imbalanced data.

    Typical workflow:
        1. ``load_data`` (also runs data-quality checks)
        2. ``preprocess_data`` (optional)
        3. ``apply_balancing_techniques`` (train/test split + balancing)
        4. ``train_classifiers`` (fit on balanced data, evaluate on the test split)
        5. save / plot helpers
    """

    def __init__(self):
        """Initialise the framework with core components and empty state."""
        self.technique_registry = TechniqueRegistry()
        self.preprocessor = DataPreprocessor()
        self.classifier_registry = ClassifierRegistry()
        self.X = None
        self.y = None
        self.X_test = None
        self.y_test = None
        self.results = {}
        self.current_data_info = {}
        self.current_balanced_datasets = {}
        self.quality_report = {}
        # Classifier configs from the latest train_classifiers() call; reused by
        # generate_learning_curves() so curves use the same hyper-parameters.
        self.classifier_configs: Dict[str, Dict[str, Any]] = {}

    def load_data(
        self,
        file_path: Union[str, Path],
        target_column: str,
        feature_columns: Optional[List[str]] = None,
        auto_preprocess: bool = False,
        correlation_threshold: float = 0.95,
    ) -> None:
        """
        Load data from a file, run data-quality checks and optionally preprocess it.

        Args:
            file_path: Path to the data file
            target_column: Name of the target column
            feature_columns: List of feature columns to use (optional). When
                omitted, the columns actually returned by the loader are used.
            auto_preprocess: Whether to call ``preprocess_data()`` with its
                default options after loading
            correlation_threshold: Threshold passed to the preprocessor's
                quality check for flagging highly correlated feature pairs
        """
        self.X, self.y = DataLoader.load_data(file_path, target_column, feature_columns)

        if feature_columns is None:
            # Recover the feature names the loader actually used. Prefer the
            # loaded frame's own columns: no second full read of the file, and
            # it works for any format the loader supports. Only fall back to
            # reading the CSV header when X carries no column labels.
            loaded_columns = getattr(self.X, "columns", None)
            if loaded_columns is not None:
                feature_columns = list(loaded_columns)
            else:
                header = pd.read_csv(file_path, nrows=0)
                feature_columns = [
                    col for col in header.columns if col != target_column
                ]

        # Store data info
        self.current_data_info = {
            "file_path": file_path,
            "target_column": target_column,
            "feature_columns": feature_columns,
            "original_shape": self.X.shape,
            "class_distribution": self._get_class_distribution(),
        }

        # Check data quality and report any issues to the user
        quality_report = self.preprocessor.check_data_quality(
            self.X, feature_columns, correlation_threshold=correlation_threshold
        )
        self.quality_report = quality_report
        self._handle_quality_issues(
            quality_report, correlation_threshold=correlation_threshold
        )

        if auto_preprocess:
            self.preprocess_data()

    def preprocess_data(
        self,
        handle_missing: str = "mean",
        scale: str = "standard",
        handle_constant_features: str = "none",
        handle_correlations: str = "none",
        categorical_features: Optional[List[str]] = None,
        hash_components_dict: Optional[Dict[str, int]] = None,
    ) -> None:
        """
        Preprocess the loaded data in place (``self.X`` / ``self.y``).

        Args:
            handle_missing: Strategy to handle missing values
                ("drop", "mean", "median", "mode", "none")
            scale: Scaling method
                ("standard", "minmax", "robust", "none")
            handle_constant_features: Strategy to handle constant features
                ("drop", "none")
            handle_correlations: Strategy to handle highly correlated features
                ("drop_first", "drop_lowest", "pca", "none")
            categorical_features: List of column names for categorical features
            hash_components_dict: Dictionary mapping feature names to number of hash components

        Raises:
            ValueError: If no data has been loaded.
        """
        if self.X is None or self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        report = self.quality_report or {}
        # The quality report stores constant features as tuples whose first
        # element is the feature name.
        constant_features = [
            feature[0] for feature in (report.get("constant_features") or [])
        ]
        # Correlation entries are passed through unchanged.
        correlated_features = report.get("feature_correlations") or []

        self.X, self.y = self.preprocessor.preprocess(
            self.X,
            self.y,
            handle_missing=handle_missing,
            scale=scale,
            handle_constant_features=handle_constant_features,
            handle_correlations=handle_correlations,
            constant_features=constant_features,
            correlated_features=correlated_features,
            all_features=self.current_data_info.get("feature_columns"),
            categorical_features=categorical_features,
            hash_components_dict=hash_components_dict,
        )

        # Preprocessing may drop/encode columns; keep the recorded feature
        # names in sync. A missing or None value must not wipe out the
        # existing list, which generate_balanced_data() relies on.
        new_feature_names = getattr(self.preprocessor, "feature_names", None)
        if new_feature_names is not None:
            self.current_data_info["feature_columns"] = new_feature_names

    def inspect_class_distribution(
        self, save_path: Optional[str] = None, display: bool = False
    ) -> Dict[Any, int]:
        """
        Inspect (and plot) the distribution of classes in the target variable.

        Args:
            save_path: Passed through to ``plot_class_distribution`` as the
                location for the plot (optional)
            display: Passed through to ``plot_class_distribution``; whether to
                show the plot

        Returns:
            Dictionary mapping class labels to their counts

        Raises:
            ValueError: If no data has been loaded.
        """
        if self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        distribution = self._get_class_distribution()

        plot_class_distribution(
            distribution,
            title="Imbalanced Dataset Class Comparison",
            save_path=save_path,
            display=display,
        )

        return distribution

    def list_available_techniques(self) -> Dict[str, List[str]]:
        """List all available balancing techniques."""
        return self.technique_registry.list_available_techniques()

    def apply_balancing_techniques(
        self,
        technique_names: List[str],
        test_size: float = 0.2,
        random_state: int = 42,
        technique_params: Optional[Dict[str, Dict[str, Any]]] = None,
        include_original: bool = False,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Split the data and apply multiple balancing techniques to the training part.

        The held-out test split is stored on ``self.X_test`` / ``self.y_test``
        and is never balanced.

        Args:
            technique_names: List of technique names to apply
            test_size: Proportion of dataset to use for testing
            random_state: Random seed for reproducibility
            technique_params: Dictionary mapping technique names to their
                constructor parameters (a missing or None entry means defaults)
            include_original: Whether to also keep the unbalanced training
                split under the key "Original" for comparison

        Returns:
            Dictionary mapping technique name to
            ``{"X_balanced": ..., "y_balanced": ...}``

        Raises:
            ValueError: If no data is loaded or a technique is unknown.
        """
        if self.X is None or self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        # NOTE(review): the split is not stratified, so with severe imbalance
        # the test set may contain very few (or no) minority samples.
        X_train, X_test, y_train, y_test = train_test_split(
            self.X, self.y, test_size=test_size, random_state=random_state
        )

        # Store test data for later evaluation
        self.X_test = X_test
        self.y_test = y_test

        balanced_datasets = {}

        if include_original:
            # Store imbalanced dataset to compare with balanced later
            balanced_datasets["Original"] = {
                "X_balanced": X_train,
                "y_balanced": y_train,
            }

        all_params = technique_params or {}
        for technique_name in technique_names:
            technique_class = self.technique_registry.get_technique_class(
                technique_name
            )
            if technique_class is None:
                raise ValueError(
                    f"Technique '{technique_name}' not found. "
                    f"Available techniques: {self.list_available_techniques()}"
                )

            # A technique listed without parameters (None) uses its defaults.
            params = all_params.get(technique_name) or {}

            technique = technique_class(**params)
            X_balanced, y_balanced = technique.balance(X_train, y_train)

            balanced_datasets[technique_name] = {
                "X_balanced": X_balanced,
                "y_balanced": y_balanced,
            }

        # Update current balanced datasets
        self.current_balanced_datasets = balanced_datasets

        return balanced_datasets

    def train_classifiers(
        self,
        classifier_configs: Optional[Dict[str, Dict[str, Any]]] = None,
        enable_cv: bool = False,
        cv_folds: int = 5,
    ) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """
        Train classifiers on balanced datasets and evaluate their performance.

        Each (classifier, technique) pair is trained on the balanced training
        data and evaluated on the shared held-out test split. Failures of a
        single pair are logged and skipped.

        Args:
            classifier_configs: Dictionary mapping classifier names to their
                parameters. If None, uses a default RandomForestClassifier.
            enable_cv: Whether to also compute cross-validation metrics on the
                balanced training data
            cv_folds: Number of cross-validation folds (if enabled)

        Returns:
            ``{classifier_name: {technique_name: {"standard_metrics": ...,
            "cv_metrics": ... (only if enable_cv)}}}``. Classifiers for which
            every technique failed are omitted.

        Raises:
            ValueError: If balancing has not been run yet.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        if self.X_test is None or self.y_test is None:
            raise ValueError(
                "Test data not found. Run apply_balancing_techniques first."
            )

        # Default classifier if none provided
        if classifier_configs is None:
            classifier_configs = {"RandomForestClassifier": {"random_state": 42}}

        # A classifier listed without parameters (None) means "use defaults".
        # Remember the configs so generate_learning_curves() rebuilds the same
        # models that were evaluated here.
        classifier_configs = {
            name: dict(params or {}) for name, params in classifier_configs.items()
        }
        self.classifier_configs = classifier_configs

        results = {}

        for clf_name, clf_params in classifier_configs.items():
            clf_class = self.classifier_registry.get_classifier_class(clf_name)

            if clf_class is None:
                logging.warning(
                    f"Classifier '{clf_name}' not found in registry. Skipping."
                )
                continue

            classifier_results = {}

            for technique_name, balanced_data in self.current_balanced_datasets.items():
                X_balanced = balanced_data["X_balanced"]
                y_balanced = balanced_data["y_balanced"]

                try:
                    # Fresh instance per technique so no fitted state is shared
                    clf_instance = clf_class(**clf_params)

                    start_time = time.time()
                    logging.info(
                        f"Training {clf_name} with dataset balanced with {technique_name}..."
                    )
                    clf_instance.fit(X_balanced, y_balanced)
                    train_time = time.time() - start_time
                    logging.info(
                        f"Training {clf_name} with dataset balanced with {technique_name} complete "
                        f"(Time Taken: {format_time(train_time)})"
                    )

                    # Standard metrics are computed on the held-out test split
                    start_time = time.time()
                    logging.info(
                        f"Getting standard metrics of {technique_name} after training {clf_name}..."
                    )
                    technique_metrics = {
                        "standard_metrics": get_metrics(
                            clf_instance,
                            self.X_test,
                            self.y_test,
                        )
                    }
                    std_metrics_time = time.time() - start_time
                    logging.info(
                        f"Getting standard metrics of {technique_name} after training {clf_name} complete "
                        f"(Time Taken: {format_time(std_metrics_time)})"
                    )

                    if enable_cv:
                        start_time = time.time()
                        logging.info(
                            f"Getting cv metrics of {technique_name} after training {clf_name}..."
                        )
                        # CV gets its own unfitted estimator built from the same params
                        technique_metrics["cv_metrics"] = get_cv_scores(
                            clf_class(**clf_params),
                            X_balanced,
                            y_balanced,
                            n_folds=cv_folds,
                        )
                        cv_metrics_time = time.time() - start_time
                        logging.info(
                            f"Getting cv metrics of {technique_name} after training {clf_name} complete "
                            f"(Time Taken: {format_time(cv_metrics_time)})"
                        )

                    classifier_results[technique_name] = technique_metrics

                except Exception as e:
                    # Best-effort: one failing (classifier, technique) pair must
                    # not abort the remaining comparisons.
                    logging.error(
                        f"Error training classifier '{clf_name}' with technique '{technique_name}': {str(e)}"
                    )
                    continue

            # Only add classifier results if at least one technique was successful
            if classifier_results:
                results[clf_name] = classifier_results

        self.results = results

        return results

    def save_results(
        self,
        file_path: Union[str, Path],
        file_type: str = "csv",
        include_plots: bool = True,
    ) -> None:
        """
        Save comparison results to a file.

        Args:
            file_path: Path to save the results
            file_type: Type of file ('csv' or 'json')
            include_plots: Whether to also save a comparison plot next to the
                results as ``<stem>_plots.png``

        Raises:
            ValueError: If there are no results or the file type is unsupported.
        """
        if not self.results:
            raise ValueError("No results to save. Run train_classifiers() first.")

        file_path = Path(file_path)

        if file_type == "csv":
            pd.DataFrame(self.results).to_csv(file_path)
        elif file_type == "json":
            pd.DataFrame(self.results).to_json(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        if include_plots:
            plot_path = file_path.parent / f"{file_path.stem}_plots.png"
            plot_comparison_results(self.results, save_path=plot_path)

    def save_classifier_results(
        self,
        file_path: Union[str, Path],
        classifier_name: str,
        metric_type: str = "standard_metrics",
        file_type: str = "csv",
    ) -> None:
        """
        Save results for a specific classifier and metric type to a file.

        Args:
            file_path: Path to save the results
            classifier_name: Name of the classifier to extract results for
            metric_type: Type of metrics to save ('standard_metrics' or 'cv_metrics')
            file_type: Type of file ('csv' or 'json')

        Raises:
            ValueError: If there are no results, the classifier is unknown, or
                the file type is unsupported.
        """
        if not self.results:
            raise ValueError("No results to save. Run train_classifiers() first.")

        if classifier_name not in self.results:
            raise ValueError(f"Classifier '{classifier_name}' not found in results.")

        file_path = Path(file_path)

        # One column per technique that has this metric type
        extracted_results = {
            technique_name: technique_data[metric_type]
            for technique_name, technique_data in self.results[classifier_name].items()
            if metric_type in technique_data
        }
        if not extracted_results:
            # e.g. 'cv_metrics' requested but cross-validation was not enabled
            logging.warning(
                f"No '{metric_type}' found for {classifier_name}; writing an empty file."
            )

        results_df = pd.DataFrame(extracted_results)

        if file_type == "csv":
            results_df.to_csv(file_path)
        elif file_type == "json":
            results_df.to_json(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        logging.info(f"Saved {classifier_name} {metric_type} results to {file_path}")

    def generate_balanced_data(
        self,
        folder_path: str,
        techniques: Optional[List[str]] = None,
        file_format: str = "csv",
    ) -> None:
        """
        Save balanced datasets to files for specified techniques.

        Each dataset is written to ``<folder_path>/balanced_<technique>.<format>``
        with the recorded feature columns plus the target column. JSON files
        are written as a list of row records.

        Args:
            folder_path: Directory to save the datasets (created if missing).
            techniques: List of techniques to save. Saves all if None.
            file_format: Format for saving the data ('csv' or 'json').

        Raises:
            ValueError: If no balanced datasets are available, the format is
                unsupported, column names are missing, or a specified technique
                is unknown. All checks run before any file is written.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        if file_format not in ["csv", "json"]:
            raise ValueError("Invalid file format. Supported formats: 'csv', 'json'.")

        if techniques is None:
            techniques = list(self.current_balanced_datasets.keys())

        # Validate every requested technique up front so a bad name does not
        # leave a partially written output folder behind.
        unknown = [t for t in techniques if t not in self.current_balanced_datasets]
        if unknown:
            raise ValueError(f"Technique '{unknown[0]}' not found in current datasets.")

        feature_columns = self.current_data_info.get("feature_columns")
        target_column = self.current_data_info.get("target_column")
        if feature_columns is None or target_column is None:
            raise ValueError(
                "Original column names are missing in 'current_data_info'."
            )

        os.makedirs(folder_path, exist_ok=True)

        for technique in techniques:
            dataset = self.current_balanced_datasets[technique]
            X_balanced = dataset["X_balanced"]
            y_balanced = dataset["y_balanced"]

            balanced_df = pd.DataFrame(X_balanced, columns=feature_columns)
            # Pair rows positionally (as fit() does). Assigning a Series would
            # align on index labels and silently yield NaNs when X has a fresh
            # RangeIndex but y keeps the shuffled train-split index.
            balanced_df[target_column] = np.asarray(y_balanced)

            file_path = os.path.join(folder_path, f"balanced_{technique}.{file_format}")

            if file_format == "csv":
                balanced_df.to_csv(file_path, index=False)
            else:
                # orient="records" omits the index; index=False with the
                # default orient is rejected by pandas.
                balanced_df.to_json(file_path, orient="records")

            logging.info(f"Saved balanced dataset for '{technique}' to {file_path}")

    def compare_balanced_class_distributions(
        self, save_path: Optional[str] = None, display: bool = False
    ) -> None:
        """
        Compare class distributions of balanced datasets for all techniques.

        Args:
            save_path: Path to save the visualisation (optional).
            display: Whether to show the plot.

        Raises:
            ValueError: If no balanced datasets are available.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        distributions = {
            technique: self.preprocessor.inspect_class_distribution(dataset["y_balanced"])
            for technique, dataset in self.current_balanced_datasets.items()
        }

        plot_class_distributions_comparison(
            distributions,
            title="Class Distribution Comparison After Balancing",
            save_path=save_path,
            display=display,
        )

    def generate_learning_curves(
        self,
        classifier_name: str,
        train_sizes: Optional[np.ndarray] = None,
        n_folds: int = 5,
        save_path: Optional[str] = None,
        display: bool = False,
    ) -> None:
        """
        Generate and plot learning curves for multiple balancing techniques.

        Failures (unknown classifier, errors while computing or plotting) are
        logged as warnings rather than raised, so other visualisations can
        continue.

        Args:
            classifier_name: Name of the classifier to generate curves for
            train_sizes: Training set sizes to evaluate; defaults to
                ``np.linspace(0.1, 1.0, 10)``
            n_folds: Number of cross-validation folds
            save_path: Path to save the plot (optional)
            display: Whether to display the plot

        Raises:
            ValueError: If no balanced datasets are available.
        """
        if train_sizes is None:
            train_sizes = np.linspace(0.1, 1.0, 10)

        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        try:
            clf_class = self.classifier_registry.get_classifier_class(classifier_name)
            if clf_class is None:
                logging.warning(
                    f"Classifier '{classifier_name}' not found. Skipping learning curves."
                )
                return

            # Reuse the parameters from the last train_classifiers() call so the
            # curves describe the same model that was evaluated.
            stored_configs = getattr(self, "classifier_configs", None) or {}
            clf_params = dict(stored_configs.get(classifier_name) or {})

            classifier = clf_class(**clf_params)

            learning_curve_data = get_learning_curve_data_multiple_techniques(
                classifier_name=classifier_name,
                classifier=classifier,
                techniques_data=self.current_balanced_datasets,
                train_sizes=train_sizes,
                n_folds=n_folds,
            )

            title = f"{classifier_name} - Learning Curves"

            plot_learning_curves(
                learning_curve_data, title=title, save_path=save_path, display=display
            )

        except Exception as e:
            # Deliberately best-effort: a failed plot must not stop the run.
            logging.warning(
                f"Failed to generate learning curves for classifier '{classifier_name}': {str(e)}"
            )
            logging.warning("Continuing with other visualisations...")

    def _get_class_distribution(self) -> Dict[Any, int]:
        """Get the distribution of classes in the target variable."""
        return self.preprocessor.inspect_class_distribution(self.y)

    def _handle_quality_issues(
        self, quality_report: Dict[str, Any], correlation_threshold: float = 0.95
    ) -> None:
        """
        Print a human-readable summary of data-quality issues.

        Args:
            quality_report: Report from ``check_data_quality``. Expected keys
                (each optional): "missing_values" as (name, count) tuples,
                "constant_features" as tuples whose first item is the name,
                and "feature_correlations" as (col1, col2, corr) tuples.
            correlation_threshold: Threshold shown in the correlation message.
        """
        messages = []

        missing_values = quality_report.get("missing_values") or []
        if missing_values:
            missing_value_info = ", ".join(
                f"{name}: {count}" for name, count in missing_values
            )
            noun = "feature" if len(missing_values) == 1 else "features"
            messages.append(
                f"Data contains missing values in {noun}: {missing_value_info}"
            )

        constant_entries = quality_report.get("constant_features") or []
        if constant_entries:
            constant_feature_names = [name for name, _ in constant_entries]
            verb = "has" if len(constant_feature_names) == 1 else "have"
            messages.append(
                f"Constant Features: {constant_feature_names} {verb} constant values"
            )

        correlations = quality_report.get("feature_correlations") or []
        if correlations:
            correlation_info = ", ".join(
                f"{col1} & {col2} ({corr:.2f})" for col1, col2, corr in correlations
            )
            messages.append(
                f"Found highly correlated features (Threshold={correlation_threshold}): {correlation_info}"
            )

        if messages:
            print("Data Quality Warnings:")
            for message in messages:
                print(f"- {message}")
|