balancr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,677 @@
1
+ import time
2
+ from typing import Dict, List, Optional, Union, Any
3
+ import os
4
+ import logging
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pathlib import Path
8
+ from sklearn.model_selection import train_test_split
9
+
10
+ from .technique_registry import TechniqueRegistry
11
+ from .data import DataLoader
12
+ from .classifier_registry import ClassifierRegistry
13
+ from .data import DataPreprocessor
14
+ from .evaluation import (
15
+ get_metrics,
16
+ get_cv_scores,
17
+ get_learning_curve_data_multiple_techniques,
18
+ )
19
+ from .evaluation import (
20
+ plot_class_distribution,
21
+ plot_class_distributions_comparison,
22
+ plot_comparison_results,
23
+ plot_learning_curves,
24
+ )
25
+
26
+
27
def format_time(seconds):
    """
    Format a duration in seconds as ``"<m>mins, <s.ss>secs"``.

    The value is rounded to hundredths *before* it is split into minutes and
    seconds. Otherwise a value such as 59.999 would render as the inconsistent
    ``"0mins, 60.00secs"`` instead of ``"1mins, 0.00secs"``.

    Args:
        seconds: Duration in seconds (int or float; expected non-negative).

    Returns:
        The formatted duration string.
    """
    minutes, remaining_seconds = divmod(round(seconds, 2), 60)
    return f"{int(minutes)}mins, {remaining_seconds:.2f}secs"
32
+
33
+
34
class BalancingFramework:
    """
    A unified framework for analysing and comparing different techniques
    for handling imbalanced data.

    Typical workflow: ``load_data()`` -> optional ``preprocess_data()`` ->
    ``apply_balancing_techniques()`` -> ``train_classifiers()``, followed by
    the save/plot helpers.
    """

    def __init__(self):
        """Initialise the framework with core components and empty state."""
        self.technique_registry = TechniqueRegistry()
        self.preprocessor = DataPreprocessor()
        self.classifier_registry = ClassifierRegistry()
        self.X = None
        self.y = None
        self.X_test = None
        self.y_test = None
        self.results = {}
        self.current_data_info = {}
        self.current_balanced_datasets = {}
        self.quality_report = {}
        # Classifier name -> constructor kwargs used by the most recent call to
        # train_classifiers(). generate_learning_curves() reuses them so the
        # curves are built with the same hyper-parameters as the evaluation.
        self.classifier_configs = {}

    def load_data(
        self,
        file_path: Union[str, Path],
        target_column: str,
        feature_columns: Optional[List[str]] = None,
        auto_preprocess: bool = False,
        correlation_threshold: float = 0.95,
    ) -> None:
        """
        Load data from a file, check its quality and optionally preprocess it.

        Args:
            file_path: Path to the data file
            target_column: Name of the target column
            feature_columns: List of feature columns to use (optional). When
                omitted, every column other than the target is treated as a
                feature.
            auto_preprocess: Whether to call ``preprocess_data()`` with its
                default options after loading
            correlation_threshold: Threshold passed to the quality check for
                flagging highly correlated feature pairs
        """
        self.X, self.y = DataLoader.load_data(file_path, target_column, feature_columns)

        if feature_columns is None:
            feature_columns = self._infer_feature_columns(file_path, target_column)

        # Store data info
        self.current_data_info = {
            "file_path": file_path,
            "target_column": target_column,
            "feature_columns": feature_columns,
            "original_shape": self.X.shape,
            "class_distribution": self._get_class_distribution(),
        }

        # Check data quality and report any issues to the user
        quality_report = self.preprocessor.check_data_quality(
            self.X, feature_columns, correlation_threshold=correlation_threshold
        )
        self.quality_report = quality_report
        self._handle_quality_issues(
            quality_report, correlation_threshold=correlation_threshold
        )

        if auto_preprocess:
            self.preprocess_data()

    def _infer_feature_columns(
        self, file_path: Union[str, Path], target_column: str
    ) -> List[str]:
        """Return the feature column names when the caller did not specify them."""
        # Prefer the columns of the loaded feature matrix when it exposes them.
        # This needs no extra I/O and does not assume a particular file format.
        columns = getattr(self.X, "columns", None)
        if columns is not None:
            return [col for col in columns if col != target_column]
        # Fallback for array-like X: read only the CSV header, not the whole file.
        header = pd.read_csv(file_path, nrows=0)
        return [col for col in header.columns if col != target_column]

    def preprocess_data(
        self,
        handle_missing: str = "mean",
        scale: str = "standard",
        handle_constant_features: str = "none",
        handle_correlations: str = "none",
        categorical_features: Optional[List[str]] = None,
        hash_components_dict: Optional[Dict[str, int]] = None,
    ) -> None:
        """
        Preprocess the loaded data with enhanced options.

        Constant and correlated features found by the quality check in
        ``load_data()`` are passed to the preprocessor. ``self.X``/``self.y``
        are replaced with the processed result.

        Args:
            handle_missing: Strategy to handle missing values
                ("drop", "mean", "median", "mode", "none")
            scale: Scaling method
                ("standard", "minmax", "robust", "none")
            handle_constant_features: Strategy to handle constant features
                ("drop", "none")
            handle_correlations: Strategy to handle highly correlated features
                ("drop_first", "drop_lowest", "pca", "none")
            categorical_features: List of column names for categorical features
            hash_components_dict: Dictionary mapping feature names to number of hash components

        Raises:
            ValueError: If no data has been loaded.
        """
        if self.X is None or self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        report = self.quality_report or {}
        # constant_features is a list of tuples whose first item is the feature name
        constant_features = [
            feature[0] for feature in (report.get("constant_features") or [])
        ]
        # feature_correlations is passed through unchanged (list of pair tuples)
        correlated_features = report.get("feature_correlations") or []

        self.X, self.y = self.preprocessor.preprocess(
            self.X,
            self.y,
            handle_missing=handle_missing,
            scale=scale,
            handle_constant_features=handle_constant_features,
            handle_correlations=handle_correlations,
            constant_features=constant_features,
            correlated_features=correlated_features,
            all_features=self.current_data_info.get("feature_columns"),
            categorical_features=categorical_features,
            hash_components_dict=hash_components_dict,
        )

        # Preprocessing (dropping, PCA, hashing) may rename/remove columns
        if hasattr(self.preprocessor, "feature_names"):
            self.current_data_info["feature_columns"] = self.preprocessor.feature_names

    def inspect_class_distribution(
        self, save_path: Optional[str] = None, display: bool = False
    ) -> Dict[Any, int]:
        """
        Inspect and plot the distribution of classes in the target variable.

        Args:
            save_path: Path to save the plot to (optional)
            display: Whether to display the plot

        Returns:
            Dictionary mapping class labels to their counts

        Raises:
            ValueError: If no data has been loaded.
        """
        if self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        distribution = self._get_class_distribution()

        plot_class_distribution(
            distribution,
            title="Imbalanced Dataset Class Comparison",
            save_path=save_path,
            display=display,
        )

        return distribution

    def list_available_techniques(self) -> Dict[str, List[str]]:
        """List all available balancing techniques."""
        return self.technique_registry.list_available_techniques()

    def apply_balancing_techniques(
        self,
        technique_names: List[str],
        test_size: float = 0.2,
        random_state: int = 42,
        technique_params: Optional[Dict[str, Dict[str, Any]]] = None,
        include_original: bool = False,
        stratify: bool = False,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Split the data and apply multiple balancing techniques to the training part.

        The held-out test split is stored on ``self.X_test``/``self.y_test``
        and is never balanced.

        Args:
            technique_names: List of technique names to apply
            test_size: Proportion of dataset to use for testing
            random_state: Random seed for reproducibility
            technique_params: Dictionary mapping technique names to their parameters
            include_original: Also store the unbalanced training split under
                the key ``"Original"`` for comparison
            stratify: Preserve class proportions in the train/test split
                (recommended for imbalanced data; off by default for
                backward compatibility)

        Returns:
            Dictionary mapping technique name to
            ``{"X_balanced": ..., "y_balanced": ...}``

        Raises:
            ValueError: If no data is loaded or a technique is unknown.
        """
        if self.X is None or self.y is None:
            raise ValueError("No data loaded. Call load_data() first.")

        X_train, X_test, y_train, y_test = train_test_split(
            self.X,
            self.y,
            test_size=test_size,
            random_state=random_state,
            stratify=self.y if stratify else None,
        )

        # Store test data for later evaluation
        self.X_test = X_test
        self.y_test = y_test

        balanced_datasets = {}

        if include_original:
            # Store imbalanced dataset to compare with balanced later
            balanced_datasets["Original"] = {
                "X_balanced": X_train,
                "y_balanced": y_train,
            }

        for technique_name in technique_names:
            technique_class = self.technique_registry.get_technique_class(
                technique_name
            )
            if technique_class is None:
                raise ValueError(
                    f"Technique '{technique_name}' not found. "
                    f"Available techniques: {self.list_available_techniques()}"
                )

            params = (technique_params or {}).get(technique_name) or {}

            technique = technique_class(**params)
            X_balanced, y_balanced = technique.balance(X_train, y_train)

            balanced_datasets[technique_name] = {
                "X_balanced": X_balanced,
                "y_balanced": y_balanced,
            }

        self.current_balanced_datasets = balanced_datasets

        return balanced_datasets

    def train_classifiers(
        self,
        classifier_configs: Optional[Dict[str, Dict[str, Any]]] = None,
        enable_cv: bool = False,
        cv_folds: int = 5,
    ) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """
        Train classifiers on balanced datasets and evaluate their performance.

        Each classifier is fitted on every balanced dataset and scored on the
        held-out test split. Failures for a single classifier/technique pair
        are logged and skipped.

        Args:
            classifier_configs: Dictionary mapping classifier names to their parameters.
                If None, uses default RandomForestClassifier
            enable_cv: Whether to perform cross-validation evaluation
            cv_folds: Number of cross-validation folds (if enabled)

        Returns:
            ``{classifier: {technique: {"standard_metrics": ...,
            "cv_metrics": ... (only if enable_cv)}}}``

        Raises:
            ValueError: If no balanced datasets or test data are available.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        if self.X_test is None or self.y_test is None:
            raise ValueError(
                "Test data not found. Run apply_balancing_techniques first."
            )

        if classifier_configs is None:
            classifier_configs = {"RandomForestClassifier": {"random_state": 42}}

        # Remember the configuration so generate_learning_curves() can rebuild
        # the classifiers with identical hyper-parameters.
        self.classifier_configs = dict(classifier_configs)

        results = {}

        for clf_name, clf_params in classifier_configs.items():
            clf_class = self.classifier_registry.get_classifier_class(clf_name)
            if clf_class is None:
                logging.warning(
                    f"Classifier '{clf_name}' not found in registry. Skipping."
                )
                continue

            # A config value of None means "use the classifier's defaults"
            clf_params = clf_params or {}
            classifier_results = {}

            for technique_name, balanced_data in self.current_balanced_datasets.items():
                try:
                    classifier_results[technique_name] = self._fit_and_score(
                        clf_name,
                        clf_class,
                        clf_params,
                        technique_name,
                        balanced_data["X_balanced"],
                        balanced_data["y_balanced"],
                        enable_cv,
                        cv_folds,
                    )
                except Exception as e:
                    # Best-effort: one failing pair must not abort the comparison
                    logging.error(
                        f"Error training classifier '{clf_name}' with technique '{technique_name}': {str(e)}"
                    )

            # Only add classifier results if at least one technique was successful
            if classifier_results:
                results[clf_name] = classifier_results

        self.results = results

        return results

    def _fit_and_score(
        self,
        clf_name: str,
        clf_class: Any,
        clf_params: Dict[str, Any],
        technique_name: str,
        X_balanced: Any,
        y_balanced: Any,
        enable_cv: bool,
        cv_folds: int,
    ) -> Dict[str, Any]:
        """Fit one classifier on one balanced dataset and collect its metrics."""
        clf_instance = clf_class(**clf_params)

        start_time = time.perf_counter()
        logging.info(f"Training {clf_name} with dataset balanced with {technique_name}...")
        clf_instance.fit(X_balanced, y_balanced)
        train_time = time.perf_counter() - start_time
        logging.info(
            f"Training {clf_name} with dataset balanced with {technique_name} complete "
            f"(Time Taken: {format_time(train_time)})"
        )

        start_time = time.perf_counter()
        logging.info(f"Getting standard metrics of {technique_name} after training {clf_name}...")
        technique_metrics = {
            "standard_metrics": get_metrics(clf_instance, self.X_test, self.y_test)
        }
        std_metrics_time = time.perf_counter() - start_time
        logging.info(
            f"Getting standard metrics of {technique_name} after training {clf_name} complete "
            f"(Time Taken: {format_time(std_metrics_time)})"
        )

        if enable_cv:
            start_time = time.perf_counter()
            logging.info(f"Getting cv metrics of {technique_name} after training {clf_name}...")
            # A fresh, unfitted estimator is passed so CV does not reuse the fit above
            technique_metrics["cv_metrics"] = get_cv_scores(
                clf_class(**clf_params),
                X_balanced,
                y_balanced,
                n_folds=cv_folds,
            )
            cv_metrics_time = time.perf_counter() - start_time
            logging.info(
                f"Getting cv metrics of {technique_name} after training {clf_name} complete "
                f"(Time Taken: {format_time(cv_metrics_time)})"
            )

        return technique_metrics

    def save_results(
        self,
        file_path: Union[str, Path],
        file_type: str = "csv",
        include_plots: bool = True,
    ) -> None:
        """
        Save comparison results to a file.

        Args:
            file_path: Path to save the results
            file_type: Type of file ('csv' or 'json')
            include_plots: Whether to also save a comparison plot next to the
                results file as ``<stem>_plots.png``

        Raises:
            ValueError: If there are no results or the file type is unsupported.
        """
        if not self.results:
            raise ValueError("No results to save. Run train_classifiers() first.")

        file_path = Path(file_path)
        results_df = pd.DataFrame(self.results)

        if file_type == "csv":
            results_df.to_csv(file_path)
        elif file_type == "json":
            results_df.to_json(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        if include_plots:
            plot_path = file_path.parent / f"{file_path.stem}_plots.png"
            plot_comparison_results(self.results, save_path=plot_path)

    def save_classifier_results(
        self,
        file_path: Union[str, Path],
        classifier_name: str,
        metric_type: str = "standard_metrics",
        file_type: str = "csv",
    ) -> None:
        """
        Save results for a specific classifier and metric type to a file.

        The saved table has one column per technique and one row per metric.

        Args:
            file_path: Path to save the results
            classifier_name: Name of the classifier to extract results for
            metric_type: Type of metrics to save ('standard_metrics' or 'cv_metrics')
            file_type: Type of file ('csv' or 'json')

        Raises:
            ValueError: If there are no results, the classifier is unknown or
                the file type is unsupported.
        """
        if not self.results:
            raise ValueError("No results to save. Run train_classifiers() first.")

        if classifier_name not in self.results:
            raise ValueError(f"Classifier '{classifier_name}' not found in results.")

        file_path = Path(file_path)

        extracted_results = {
            technique_name: technique_data[metric_type]
            for technique_name, technique_data in self.results[classifier_name].items()
            if metric_type in technique_data
        }
        if not extracted_results:
            # e.g. asking for 'cv_metrics' when training ran without enable_cv
            logging.warning(
                f"No '{metric_type}' found for {classifier_name}; saving an empty table."
            )

        results_df = pd.DataFrame(extracted_results)

        if file_type == "csv":
            results_df.to_csv(file_path)
        elif file_type == "json":
            results_df.to_json(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        logging.info(f"Saved {classifier_name} {metric_type} results to {file_path}")

    def generate_balanced_data(
        self,
        folder_path: str,
        techniques: Optional[List[str]] = None,
        file_format: str = "csv",
    ) -> None:
        """
        Save balanced datasets to files for specified techniques.

        Each dataset is written to ``<folder_path>/balanced_<technique>.<format>``
        with the feature columns followed by the target column. JSON files use
        the "records" layout (a list of row objects).

        Args:
            folder_path: Directory to save the datasets (created if missing).
            techniques: List of techniques to save. Saves all if None.
            file_format: Format for saving the data ('csv' or 'json').

        Raises:
            ValueError: If no balanced datasets are available, the format is
                unsupported, column names are missing, or a requested technique
                is unknown. All checks run before any file is written.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        if file_format not in ("csv", "json"):
            raise ValueError("Invalid file format. Supported formats: 'csv', 'json'.")

        if techniques is None:
            techniques = list(self.current_balanced_datasets)

        # Validate every requested technique up front so a typo does not leave
        # a partially written set of files behind.
        for technique in techniques:
            if technique not in self.current_balanced_datasets:
                raise ValueError(
                    f"Technique '{technique}' not found in current datasets."
                )

        feature_columns = self.current_data_info.get("feature_columns")
        target_column = self.current_data_info.get("target_column")
        if feature_columns is None or target_column is None:
            raise ValueError(
                "Original column names are missing in 'current_data_info'."
            )

        os.makedirs(folder_path, exist_ok=True)

        for technique in techniques:
            dataset = self.current_balanced_datasets[technique]

            balanced_df = pd.DataFrame(dataset["X_balanced"], columns=feature_columns)
            # Assign positionally: resampled X and y may carry indices that do
            # not align, and label alignment would silently inject NaNs.
            balanced_df[target_column] = np.asarray(dataset["y_balanced"])

            file_path = os.path.join(folder_path, f"balanced_{technique}.{file_format}")

            if file_format == "csv":
                balanced_df.to_csv(file_path, index=False)
            else:
                # orient="columns" (the default) rejects index=False in pandas;
                # "records" omits the index and is valid on pandas 1.x and 2.x.
                balanced_df.to_json(file_path, orient="records")

            logging.info(f"Saved balanced dataset for '{technique}' to {file_path}")

    def compare_balanced_class_distributions(
        self, save_path: Optional[str] = None, display: bool = False
    ) -> None:
        """
        Compare class distributions of balanced datasets for all techniques.

        Args:
            save_path: Path to save the visualisation (optional).
            display: Whether to display the plot.

        Raises:
            ValueError: If no balanced datasets are available.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        distributions = {
            technique: self.preprocessor.inspect_class_distribution(dataset["y_balanced"])
            for technique, dataset in self.current_balanced_datasets.items()
        }

        plot_class_distributions_comparison(
            distributions,
            title="Class Distribution Comparison After Balancing",
            save_path=save_path,
            display=display,
        )

    def generate_learning_curves(
        self,
        classifier_name: str,
        train_sizes: Optional[np.ndarray] = None,
        n_folds: int = 5,
        save_path: Optional[str] = None,
        display: bool = False,
    ) -> None:
        """
        Generate and plot learning curves for multiple balancing techniques.

        Failures are logged as warnings rather than raised, so callers can
        continue producing other visualisations.

        Args:
            classifier_name: Name of the classifier to generate curves for
            train_sizes: Training set sizes to evaluate; defaults to
                ``np.linspace(0.1, 1.0, 10)``
            n_folds: Number of cross-validation folds
            save_path: Path to save the plot (optional)
            display: Whether to display the plot

        Raises:
            ValueError: If no balanced datasets are available.
        """
        if not self.current_balanced_datasets:
            raise ValueError(
                "No balanced datasets available. Run apply_balancing_techniques first."
            )

        # Built per call: an array default argument would be shared across calls
        if train_sizes is None:
            train_sizes = np.linspace(0.1, 1.0, 10)

        try:
            clf_class = self.classifier_registry.get_classifier_class(classifier_name)
            if clf_class is None:
                logging.warning(
                    f"Classifier '{classifier_name}' not found. Skipping learning curves."
                )
                return

            # Reuse the parameters recorded by train_classifiers(), if any
            configs = getattr(self, "classifier_configs", None) or {}
            clf_params = configs.get(classifier_name) or {}

            classifier = clf_class(**clf_params)

            learning_curve_data = get_learning_curve_data_multiple_techniques(
                classifier_name=classifier_name,
                classifier=classifier,
                techniques_data=self.current_balanced_datasets,
                train_sizes=train_sizes,
                n_folds=n_folds,
            )

            title = f"{classifier_name} - Learning Curves"

            plot_learning_curves(
                learning_curve_data, title=title, save_path=save_path, display=display
            )

        except Exception as e:
            logging.warning(
                f"Failed to generate learning curves for classifier '{classifier_name}': {str(e)}"
            )
            logging.warning("Continuing with other visualisations...")

    def _get_class_distribution(self) -> Dict[Any, int]:
        """Get the distribution of classes in the target variable."""
        return self.preprocessor.inspect_class_distribution(self.y)

    def _handle_quality_issues(
        self, quality_report: Dict[str, Any], correlation_threshold: float = 0.95
    ) -> None:
        """
        Print a human-readable summary of data quality issues.

        Expects ``missing_values`` as ``(name, count)`` tuples,
        ``constant_features`` as tuples whose first item is the name, and
        ``feature_correlations`` as ``(col1, col2, corr)`` tuples. Absent keys
        are treated as "no issues".
        """
        messages = []

        missing_values = quality_report.get("missing_values") or []
        if missing_values:
            missing_value_info = ", ".join(
                f"{name}: {count}" for name, count in missing_values
            )
            noun = "feature" if len(missing_values) == 1 else "features"
            messages.append(
                f"Data contains missing values in {noun}: {missing_value_info}"
            )

        constant_features = quality_report.get("constant_features") or []
        if constant_features:
            constant_feature_names = [name for name, _ in constant_features]
            verb = "has" if len(constant_feature_names) == 1 else "have"
            messages.append(
                f"Constant Features: {constant_feature_names} {verb} constant values"
            )

        feature_correlations = quality_report.get("feature_correlations") or []
        if feature_correlations:
            correlation_info = ", ".join(
                f"{col1} & {col2} ({corr:.2f})"
                for col1, col2, corr in feature_correlations
            )
            messages.append(
                f"Found highly correlated features (Threshold={correlation_threshold}): {correlation_info}"
            )

        if messages:
            print("Data Quality Warnings:")
            for message in messages:
                print(f"- {message}")