balancr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1838 @@
+ """
+ commands.py - Command handlers for the balancr CLI.
+
+ This module contains the implementation of all command functions that are
+ registered in main.py.
+ """
+
+ import shutil
+ import time
+ from datetime import datetime
+ import importlib.util
+ import logging
+ import json
+ import inspect
+ from balancr.data import DataPreprocessor
+ import numpy as np
+ from pathlib import Path
+
+ from balancr.evaluation import (
+     plot_comparison_results,
+     plot_radar_chart,
+     plot_3d_scatter,
+ )
+ import pandas as pd
+
+ from . import config
+ from balancr import BaseBalancer
+
+ # Will be used to interact with the core balancing framework
+ try:
+     from balancr import BalancingFramework
+     from balancr import TechniqueRegistry
+     from balancr import ClassifierRegistry
+ except ImportError as e:
+     logging.error(
+         f"Could not import balancing framework: {e}. "
+         "Ensure it's installed correctly."
+     )
+     BalancingFramework = None
+     TechniqueRegistry = None
+     ClassifierRegistry = None
+
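+ # Typical flow across these handlers (illustrative invocations; the exact
+ # flag spellings are defined by the argparse setup in main.py):
+ #   balancr load-data data.csv --target-column label
+ #   balancr select-techniques SMOTE
+ #   balancr select-classifier RandomForestClassifier
+ #   balancr run --output-dir results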
+
+ def load_data(args):
+     """
+     Handle the load-data command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     logging.info(f"Loading data from {args.file_path}")
+
+     # Validate file exists
+     if not Path(args.file_path).exists():
+         logging.error(f"File not found: {args.file_path}")
+         return 1
+
+     # Update configuration with data file settings
+     settings = {
+         "data_file": args.file_path,
+         "target_column": args.target_column,
+     }
+
+     if args.feature_columns:
+         settings["feature_columns"] = args.feature_columns
+
+     try:
+         # Validate that the file can be loaded
+         if BalancingFramework is not None:
+             # This is just a validation check, not storing the framework instance
+             framework = BalancingFramework()
+
+             # Try to get correlation_threshold from config
+             try:
+                 current_config = config.load_config(args.config_path)
+                 correlation_threshold = current_config["preprocessing"].get(
+                     "correlation_threshold", 0.95
+                 )
+             except Exception:
+                 # Fall back to the default if the config can't be loaded or lacks the value
+                 correlation_threshold = 0.95
+
+             print(f"Correlation Threshold: {correlation_threshold}")
+             framework.load_data(
+                 args.file_path,
+                 args.target_column,
+                 args.feature_columns,
+                 correlation_threshold=correlation_threshold,
+             )
+
+             # Get and display class distribution
+             distribution = framework.inspect_class_distribution(display=False)
+             print("\nClass Distribution:")
+             for cls, count in distribution.items():
+                 print(f" Class {cls}: {count} samples")
+
+             total = sum(distribution.values())
+             for cls, count in distribution.items():
+                 pct = (count / total) * 100
+                 print(f" Class {cls}: {pct:.2f}%")
+
+         # Update config with new settings
+         config.update_config(args.config_path, settings)
+         logging.info(
+             f"Data configuration saved: {args.file_path}, target: {args.target_column}"
+         )
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to load data: {e}")
+         return 1
+
+
+ def preprocess(args):
+     """
+     Handle the preprocess command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     logging.info("Configuring preprocessing options")
+
+     # Initialise basic preprocessing settings
+     settings = {
+         "preprocessing": {
+             "handle_missing": args.handle_missing,
+             "handle_constant_features": args.handle_constant_features,
+             "handle_correlations": args.handle_correlations,
+             "correlation_threshold": args.correlation_threshold,
+             "scale": args.scale,
+             "encode": args.encode,
+             "save_preprocessed": args.save_preprocessed,
+         }
+     }
+
+     try:
+         current_config = config.load_config(args.config_path)
+
+         # Check if categorical features are specified
+         if args.categorical_features:
+             # Process categorical features if a dataset is available
+             if "data_file" in current_config:
+                 data_file = current_config["data_file"]
+                 target_column = current_config.get("target_column")
+
+                 logging.info(f"Loading dataset from {data_file} for encoding analysis")
+
+                 try:
+                     # Read the dataset directly
+                     df = pd.read_csv(data_file)
+
+                     # Validate target column exists
+                     if target_column and target_column not in df.columns:
+                         logging.warning(
+                             f"Target column '{target_column}' not found in dataset"
+                         )
+
+                     # Create preprocessor and determine encoding types
+                     preprocessor = DataPreprocessor()
+                     categorical_encodings = preprocessor.assign_encoding_types(
+                         df=df,
+                         categorical_columns=args.categorical_features,
+                         encoding_type=args.encode,
+                         hash_components=args.hash_components,
+                         ordinal_columns=args.ordinal_features,
+                     )
+
+                     # Add categorical feature encodings to settings
+                     settings["preprocessing"][
+                         "categorical_features"
+                     ] = categorical_encodings
+
+                     # Display encoding recommendations
+                     print("\nCategorical feature encoding assignments:")
+                     for column, encoding in categorical_encodings.items():
+                         print(f" {column}: {encoding}")
+                 except Exception as e:
+                     logging.error(f"Error analysing dataset for encoding: {e}")
+                     logging.info(
+                         "Storing categorical features without encoding recommendations"
+                     )
+                     # If analysis fails, just store the categorical features with the default encoding
+                     settings["preprocessing"]["categorical_features"] = {
+                         col: args.encode for col in args.categorical_features
+                     }
+             else:
+                 logging.warning(
+                     "No dataset configured. Cannot analyse categorical features."
+                 )
+                 logging.info("Storing categorical features with default encoding")
+                 # Store categorical features with the specified encoding type
+                 settings["preprocessing"]["categorical_features"] = {
+                     col: args.encode for col in args.categorical_features
+                 }
+
+         # Update config
+         config.update_config(args.config_path, settings)
+         logging.info("Preprocessing configuration saved")
+
+         # Display the preprocessing settings
+         print("\nPreprocessing Configuration:")
+         print(f" Handle Missing Values: {args.handle_missing}")
+         print(f" Handle Constant Features: {args.handle_constant_features}")
+         print(f" Handle Feature Correlations: {args.handle_correlations}")
+         print(f" Correlation Threshold: {args.correlation_threshold}")
+         print(f" Feature Scaling: {args.scale}")
+         print(f" Default Categorical Encoding: {args.encode}")
+         print(f" Save Preprocessed Data to File: {args.save_preprocessed}")
+
+         if args.categorical_features:
+             print(f" Categorical Features: {', '.join(args.categorical_features)}")
+
+         if args.ordinal_features:
+             print(f" Ordinal Features: {', '.join(args.ordinal_features)}")
+
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to configure preprocessing: {e}")
+         return 1
+
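+ # For reference, the "preprocessing" block saved above lands in the JSON config
+ # shaped like this (values and encoding names here are illustrative; the
+ # ["hash", n] list form is the one run_comparison later looks for):
+ #   "preprocessing": {
+ #       "handle_missing": "mean",
+ #       "handle_correlations": true,
+ #       "correlation_threshold": 0.95,
+ #       "scale": "standard",
+ #       "save_preprocessed": true,
+ #       "categorical_features": {"city": "onehot", "user_id": ["hash", 16]}
+ #   }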
+
+ def select_techniques(args):
+     """
+     Handle the select-techniques command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     # List available techniques if requested
+     if args.list_available and BalancingFramework is not None:
+         print("Listing available balancing techniques...")
+         try:
+             framework = BalancingFramework()
+             techniques = framework.list_available_techniques()
+
+             print("\nAvailable Techniques:")
+
+             # Print custom techniques
+             if techniques.get("custom"):
+                 print("\nCustom Techniques:")
+                 for technique in techniques["custom"]:
+                     print(f" - {technique}")
+
+             # Print imblearn techniques
+             if techniques.get("imblearn"):
+                 print("\nImbalanced-Learn Techniques:")
+                 for technique in sorted(techniques["imblearn"]):
+                     print(f" - {technique}")
+
+             return 0
+
+         except Exception as e:
+             logging.error(f"Failed to list techniques: {e}")
+             return 1
+
+     # When not listing but selecting techniques
+     logging.info(f"Selecting balancing techniques: {', '.join(args.techniques)}")
+
+     try:
+         # Validate techniques exist if framework is available
+         if BalancingFramework is not None:
+             framework = BalancingFramework()
+             available = framework.list_available_techniques()
+             all_techniques = available.get("custom", []) + available.get("imblearn", [])
+
+             invalid_techniques = [t for t in args.techniques if t not in all_techniques]
+
+             if invalid_techniques:
+                 logging.error(f"Invalid techniques: {', '.join(invalid_techniques)}")
+                 logging.info(
+                     "Use 'balancr select-techniques --list-available' to see available techniques"
+                 )
+                 return 1
+
+             # Create technique configurations with default parameters
+             balancing_techniques = {}
+             for technique_name in args.techniques:
+                 # Get default parameters for this technique
+                 params = framework.technique_registry.get_technique_default_params(
+                     technique_name
+                 )
+                 balancing_techniques[technique_name] = params
+
+         # Read existing config
+         current_config = config.load_config(args.config_path)
+
+         if args.append and "balancing_techniques" in current_config:
+             # Append mode: Update existing techniques
+             existing_techniques = current_config.get("balancing_techniques", {})
+
+             # Add new techniques to the existing ones
+             if BalancingFramework is not None:
+                 existing_techniques.update(balancing_techniques)
+
+             # Update config with merged values
+             settings = {
+                 "balancing_techniques": existing_techniques,
+                 "include_original_data": args.include_original_data,
+             }
+
+             config.update_config(args.config_path, settings)
+
+             print(f"\nAdded balancing techniques: {', '.join(args.techniques)}")
+             print(f"Total techniques: {', '.join(existing_techniques.keys())}")
+         else:
+             # Replace mode: Create a completely new config entry
+             # Create a copy of the current config and set the techniques
+             new_config = dict(current_config)  # shallow copy
+             new_config["balancing_techniques"] = (
+                 balancing_techniques if BalancingFramework is not None else {}
+             )
+             new_config["include_original_data"] = args.include_original_data
+
+             # Directly write the entire config to replace the file
+             config_path = Path(args.config_path)
+             with open(config_path, "w") as f:
+                 json.dump(new_config, f, indent=2)
+
+             print(f"\nReplaced balancing techniques with: {', '.join(args.techniques)}")
+
+             # Return early since we've manually written the config
+             print("Default parameters have been added to the configuration file.")
+             print("You can modify them by editing the configuration file.")
+             return 0
+
+         print("Default parameters have been added to the configuration file.")
+         print("You can modify them by editing the configuration file.")
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to select techniques: {e}")
+         return 1
+
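+ # Illustrative shape of the saved entry (parameter names and values depend on
+ # what get_technique_default_params reports for each technique):
+ #   "balancing_techniques": {
+ #       "SMOTE": {"k_neighbors": 5, "random_state": null}
+ #   },
+ #   "include_original_data": false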
+
+ def register_techniques(args):
+     """
+     Handle the register-techniques command.
+
+     This command allows users to register custom balancing techniques from
+     Python files or folders for use in comparisons.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     try:
+         # Handle removal operations
+         if args.remove or args.remove_all:
+             return _remove_techniques(args)
+
+         # Ensure framework is available
+         if TechniqueRegistry is None:
+             logging.error(
+                 "Technique registry not available. Please check installation."
+             )
+             return 1
+
+         registry = TechniqueRegistry()
+
+         # Track successfully registered techniques
+         registered_techniques = []
+
+         # Process file path
+         if args.file_path:
+             file_path = Path(args.file_path)
+
+             if not file_path.exists():
+                 logging.error(f"File not found: {file_path}")
+                 return 1
+
+             if not file_path.is_file() or file_path.suffix.lower() != ".py":
+                 logging.error(f"Not a Python file: {file_path}")
+                 return 1
+
+             # Register techniques from the file
+             registered = _register_from_file(
+                 registry, file_path, args.name, args.class_name, args.overwrite
+             )
+             registered_techniques.extend(registered)
+
+         # Process folder path
+         elif args.folder_path:
+             folder_path = Path(args.folder_path)
+
+             if not folder_path.exists():
+                 logging.error(f"Folder not found: {folder_path}")
+                 return 1
+
+             if not folder_path.is_dir():
+                 logging.error(f"Not a directory: {folder_path}")
+                 return 1
+
+             # Register techniques from all Python files in the folder
+             for py_file in folder_path.glob("*.py"):
+                 registered = _register_from_file(
+                     registry,
+                     py_file,
+                     None,  # Don't use custom name for folder scanning
+                     None,  # Don't use class name for folder scanning
+                     args.overwrite,
+                 )
+                 registered_techniques.extend(registered)
+
+         # Print summary
+         if registered_techniques:
+             print("\nSuccessfully registered techniques:")
+             for technique in registered_techniques:
+                 print(f" - {technique}")
+
+             # Suggestion for next steps
+             print("\nYou can now use these techniques in comparisons. For example:")
+             print(f" balancr select-techniques {registered_techniques[0]}")
+             return 0
+         else:
+             logging.warning("No valid balancing techniques found to register.")
+             return 1
+
+     except Exception as e:
+         logging.error(f"Error registering techniques: {e}")
+         if args.verbose:
+             import traceback
+
+             traceback.print_exc()
+         return 1
+
+
+ def _register_from_file(
+     registry, file_path, custom_name=None, class_name=None, overwrite=False
+ ):
+     """
+     Register technique classes from a Python file and copy to custom techniques directory.
+
+     Args:
+         registry: TechniqueRegistry instance
+         file_path: Path to the Python file
+         custom_name: Custom name to register the technique under
+         class_name: Name of specific class to register
+         overwrite: Whether to overwrite existing techniques
+
+     Returns:
+         list: Names of successfully registered techniques
+     """
+     registered_techniques = []
+
+     try:
+         # Create custom techniques directory if it doesn't exist
+         custom_dir = Path.home() / ".balancr" / "custom_techniques"
+         custom_dir.mkdir(parents=True, exist_ok=True)
+
+         # Import the module dynamically
+         module_name = file_path.stem
+         spec = importlib.util.spec_from_file_location(module_name, file_path)
+         if spec is None or spec.loader is None:
+             logging.error(f"Could not load module from {file_path}")
+             return registered_techniques
+
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+
+         # Find all classes that inherit from BaseBalancer
+         technique_classes = []
+         for name, obj in inspect.getmembers(module, inspect.isclass):
+             if (
+                 obj.__module__
+                 == module_name  # Only consider classes defined in this module
+                 and issubclass(obj, BaseBalancer)
+                 and obj is not BaseBalancer
+             ):  # Exclude the base class itself
+                 technique_classes.append((name, obj))
+
+         # If no valid classes found
+         if not technique_classes:
+             logging.warning(f"No valid technique classes found in {file_path}")
+             logging.info(
+                 "Classes must inherit from balancr.base.BaseBalancer"
+             )
+             return registered_techniques
+
+         # If class_name is specified, filter to just that class
+         if class_name:
+             technique_classes = [
+                 (name, cls) for name, cls in technique_classes if name == class_name
+             ]
+             if not technique_classes:
+                 logging.error(
+                     f"Class '{class_name}' not found in {file_path} or doesn't inherit from BaseBalancer"
+                 )
+                 return registered_techniques
+
+         # If requesting a custom name but multiple classes found and no class_name specified
+         if custom_name and len(technique_classes) > 1 and not class_name:
+             logging.error(
+                 f"Multiple technique classes found in {file_path}, but custom name provided. "
+                 "Please specify which class to register with --class-name."
+             )
+             return registered_techniques
+
+         # Register techniques
+         for name, cls in technique_classes:
+             # If this is a specifically requested class with a custom name
+             if class_name and name == class_name and custom_name:
+                 register_name = custom_name
+             else:
+                 register_name = name
+
+             try:
+                 # Check if technique already exists
+                 existing_techniques = registry.list_available_techniques()
+                 if (
+                     register_name in existing_techniques.get("custom", [])
+                     and not overwrite
+                 ):
+                     logging.warning(
+                         f"Technique '{register_name}' already exists. "
+                         "Use --overwrite to replace it."
+                     )
+                     continue
+
+                 # Register the technique
+                 registry.register_custom_technique(register_name, cls)
+                 registered_techniques.append(register_name)
+                 logging.info(f"Successfully registered technique: {register_name}")
+
+             except Exception as e:
+                 logging.error(f"Error registering technique '{register_name}': {e}")
+
+         # For successfully registered techniques, copy the file
+         if registered_techniques:
+             # Generate a unique filename (in case multiple files have same name)
+             dest_file = custom_dir / f"{file_path.stem}_{hash(str(file_path))}.py"
+             shutil.copy2(file_path, dest_file)
+             logging.debug(f"Copied {file_path} to {dest_file}")
+
+             # Create a metadata file to map technique names to files
+             metadata_file = custom_dir / "techniques_metadata.json"
+             metadata = {}
+             if metadata_file.exists():
+                 with open(metadata_file, "r") as f:
+                     metadata = json.load(f)
+
+             # Update metadata with new techniques
+             class_mapping = {cls_name: cls_name for cls_name, _ in technique_classes}
+             if class_name and custom_name:
+                 class_mapping[custom_name] = class_name
+
+             for technique_name in registered_techniques:
+                 original_class = class_mapping.get(technique_name, technique_name)
+                 metadata[technique_name] = {
+                     "file": str(dest_file),
+                     "class_name": original_class,
+                     "registered_at": datetime.now().isoformat(),
+                 }
+
+             with open(metadata_file, "w") as f:
+                 json.dump(metadata, f, indent=2)
+
+     except Exception as e:
+         logging.error(f"Error processing file {file_path}: {e}")
+
+     return registered_techniques
+
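+ # Minimal sketch of a file this path accepts. The exact abstract method that
+ # BaseBalancer requires is defined by balancr itself; `balance(X, y)` below is
+ # an assumption for illustration only:
+ #
+ #   from balancr import BaseBalancer
+ #
+ #   class NoOpBalancer(BaseBalancer):
+ #       def balance(self, X, y):
+ #           return X, y  # hands the data back unchanged
+ #
+ # techniques_metadata.json then maps each registered name to its copied file:
+ #   {"NoOpBalancer": {"file": ".../custom_techniques/noop_12345.py",
+ #                     "class_name": "NoOpBalancer",
+ #                     "registered_at": "2025-01-01T12:00:00"}}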
+
+ def _remove_techniques(args):
+     """
+     Remove custom techniques as specified in the args.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     # Get path to custom techniques directory
+     custom_dir = Path.home() / ".balancr" / "custom_techniques"
+     metadata_file = custom_dir / "techniques_metadata.json"
+
+     # Check if metadata file exists
+     if not metadata_file.exists():
+         logging.error("No custom techniques have been registered.")
+         return 1
+
+     # Load metadata
+     with open(metadata_file, "r") as f:
+         metadata = json.load(f)
+
+     # If no custom techniques
+     if not metadata:
+         logging.error("No custom techniques have been registered.")
+         return 1
+
+     # Remove all custom techniques
+     if args.remove_all:
+         logging.info("Removing all custom techniques...")
+
+         # Remove all technique files
+         file_paths = set(info["file"] for info in metadata.values())
+         for file_path in file_paths:
+             try:
+                 Path(file_path).unlink(missing_ok=True)
+             except Exception as e:
+                 logging.warning(f"Error removing file {file_path}: {e}")
+
+         # Clear metadata
+         metadata = {}
+         with open(metadata_file, "w") as f:
+             json.dump(metadata, f, indent=2)
+
+         print("All custom techniques have been removed.")
+         return 0
+
+     # Remove specific techniques
+     removed_techniques = []
+     for technique_name in args.remove:
+         if technique_name in metadata:
+             # Note the file path (we'll check if it's used by other techniques)
+             file_path = metadata[technique_name]["file"]
+
+             # Remove from metadata
+             del metadata[technique_name]
+             removed_techniques.append(technique_name)
+
+             # Check if the file is still used by other techniques
+             file_still_used = any(
+                 info["file"] == file_path for info in metadata.values()
+             )
+
+             # If not used, remove the file
+             if not file_still_used:
+                 try:
+                     Path(file_path).unlink(missing_ok=True)
+                 except Exception as e:
+                     logging.warning(f"Error removing file {file_path}: {e}")
+         else:
+             logging.warning(f"Technique '{technique_name}' not found.")
+
+     # Save updated metadata
+     with open(metadata_file, "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     if removed_techniques:
+         print("\nRemoved techniques:")
+         for technique in removed_techniques:
+             print(f" - {technique}")
+         return 0
+     else:
+         logging.error("No matching techniques were found.")
+         return 1
+
+
+ def select_classifier(args):
+     """
+     Handle the select-classifier command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     # Check if we should list available classifiers
+     if args.list_available:
+         return list_available_classifiers(args)
+
+     logging.info(f"Selecting classifiers: {', '.join(args.classifiers)}")
+
+     # Create classifier registry
+     if ClassifierRegistry is None:
+         logging.error("Classifier registry not available. Please check installation.")
+         return 1
+
+     registry = ClassifierRegistry()
+
+     # Get classifier configurations
+     classifier_configs = {}
+
+     for classifier_name in args.classifiers:
+         # Get the classifier class
+         classifier_class = registry.get_classifier_class(classifier_name)
+
+         if classifier_class is None:
+             logging.error(f"Classifier '{classifier_name}' not found.")
+             logging.info(
+                 "Use 'balancr select-classifier --list-available' to see available classifiers."
+             )
+             continue
+
+         # Get default parameters
+         params = get_classifier_default_params(classifier_class)
+         classifier_configs[classifier_name] = params
+
+     # If no valid classifiers were found
+     if not classifier_configs:
+         logging.error("No valid classifiers selected.")
+         return 1
+
+     try:
+         # Read existing config (we need this regardless of append mode)
+         current_config = config.load_config(args.config_path)
+
+         if args.append:
+             # Append mode: Update existing classifiers
+             existing_classifiers = current_config.get("classifiers", {})
+             existing_classifiers.update(classifier_configs)
+             settings = {"classifiers": existing_classifiers}
+
+             print(f"\nAdded classifiers: {', '.join(classifier_configs.keys())}")
+             print(f"Total classifiers: {', '.join(existing_classifiers.keys())}")
+         else:
+             # Replace mode: Create a completely new config entry
+             # We'll create a copy of the current config and explicitly set the classifiers
+             new_config = dict(current_config)  # shallow copy is sufficient
+             new_config["classifiers"] = classifier_configs
+
+             # Write the whole file directly so the previous classifiers are replaced
+             config_path = Path(args.config_path)
+             with open(config_path, "w") as f:
+                 json.dump(new_config, f, indent=2)
+
+             print(
+                 f"\nReplaced classifiers with: {', '.join(classifier_configs.keys())}"
+             )
+
+             # Return early since we've manually written the config
+             print("Default parameters have been added to the configuration file.")
+             print("You can modify them by editing the configuration or using the CLI.")
+             return 0
+
+         # Only reach here in append mode
+         config.update_config(args.config_path, settings)
+
+         print("Default parameters have been added to the configuration file.")
+         print("You can modify them by editing the configuration or using the CLI.")
+
+         return 0
+     except Exception as e:
+         logging.error(f"Failed to select classifiers: {e}")
+         return 1
+
+
+ def list_available_classifiers(args):
+     """
+     List all available classifiers.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     if ClassifierRegistry is None:
+         logging.error("Classifier registry not available. Please check installation.")
+         return 1
+
+     registry = ClassifierRegistry()
+     classifiers = registry.list_available_classifiers()
+
+     print("\nAvailable Classifiers:")
+
+     # Print custom classifiers if any
+     if "custom" in classifiers and classifiers["custom"]:
+         print("\nCustom Classifiers:")
+         for module_name, clf_list in classifiers["custom"].items():
+             if clf_list:
+                 print(f"\n {module_name.capitalize()}:")
+                 for clf in sorted(clf_list):
+                     print(f" - {clf}")
+
+     # Print sklearn classifiers by module
+     if "sklearn" in classifiers:
+         print("\nScikit-learn Classifiers:")
+         for module_name, clf_list in classifiers["sklearn"].items():
+             print(f"\n {module_name.capitalize()}:")
+             for clf in sorted(clf_list):
+                 print(f" - {clf}")
+
+     return 0
+
+
+ def get_classifier_default_params(classifier_class):
+     """
+     Extract default parameters from a classifier class.
+
+     Args:
+         classifier_class: The classifier class to inspect
+
+     Returns:
+         Dictionary of parameter names and their default values
+     """
+     params = {}
+
+     try:
+         # Get the signature of the __init__ method
+         sig = inspect.signature(classifier_class.__init__)
+
+         # Process each parameter
+         for name, param in sig.parameters.items():
+             # Skip 'self' parameter
+             if name == "self":
+                 continue
+
+             # Get default value if it exists
+             if param.default is not inspect.Parameter.empty:
+                 # Handle special case for None (JSON uses null)
+                 if param.default is None:
+                     params[name] = None
+                 # Handle other types that can be serialised to JSON
+                 elif isinstance(param.default, (int, float, str, bool, list, dict)):
+                     params[name] = param.default
+                 else:
+                     # Convert non-JSON-serialisable defaults to string representation
+                     params[name] = str(param.default)
+             else:
+                 # For parameters without defaults, use None
+                 params[name] = None
+
+     except Exception as e:
+         logging.warning(
+             f"Error extracting parameters from {classifier_class.__name__}: {e}"
+         )
+
+     return params
+
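+ # For example, with sklearn's RandomForestClassifier this yields roughly
+ # (subset shown; exact defaults depend on the installed scikit-learn version):
+ #   {"n_estimators": 100, "criterion": "gini", "max_depth": None, ...}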
+
+ def register_classifiers(args):
+     """
+     Handle the register-classifiers command.
+
+     This command allows users to register custom classifiers from
+     Python files or folders, or remove existing custom classifiers.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     try:
+         # Ensure framework is available
+         if ClassifierRegistry is None:
+             logging.error(
+                 "Classifier registry not available. Please check installation."
+             )
+             return 1
+
+         # Handle removal operations
+         if args.remove or args.remove_all:
+             return _remove_classifiers(args)
+
+         registry = ClassifierRegistry()
+
+         # Track successfully registered classifiers
+         registered_classifiers = []
+
+         # Process file path
+         if args.file_path:
+             file_path = Path(args.file_path)
+
+             if not file_path.exists():
+                 logging.error(f"File not found: {file_path}")
+                 return 1
+
+             if not file_path.is_file() or file_path.suffix.lower() != ".py":
+                 logging.error(f"Not a Python file: {file_path}")
+                 return 1
+
+             # Register classifiers from the file
+             registered = _register_classifier_from_file(
+                 registry, file_path, args.name, args.class_name, args.overwrite
+             )
+             registered_classifiers.extend(registered)
+
+         # Process folder path
+         elif args.folder_path:
+             folder_path = Path(args.folder_path)
+
+             if not folder_path.exists():
+                 logging.error(f"Folder not found: {folder_path}")
+                 return 1
+
+             if not folder_path.is_dir():
+                 logging.error(f"Not a directory: {folder_path}")
+                 return 1
+
+             # Register classifiers from all Python files in the folder
+             for py_file in folder_path.glob("*.py"):
+                 registered = _register_classifier_from_file(
+                     registry,
+                     py_file,
+                     None,  # Don't use custom name for folder scanning
+                     None,  # Don't use class name for folder scanning
+                     args.overwrite,
+                 )
+                 registered_classifiers.extend(registered)
+
+         # Print summary
+         if registered_classifiers:
+             print("\nSuccessfully registered classifiers:")
+             for classifier in registered_classifiers:
+                 print(f" - {classifier}")
+
+             print("\nYou can now use these classifiers in comparisons. For example:")
+             print(f" balancr select-classifier {registered_classifiers[0]}")
+             return 0
+         else:
+             logging.warning("No valid classifiers found to register.")
+             return 1
+
+     except Exception as e:
+         logging.error(f"Error registering classifiers: {e}")
+         if args.verbose:
+             import traceback
+
+             traceback.print_exc()
+         return 1
+
+
+ def _register_classifier_from_file(
+     registry, file_path, custom_name=None, class_name=None, overwrite=False
+ ):
+     """
+     Register classifier classes from a Python file.
+
+     Args:
+         registry: ClassifierRegistry instance
+         file_path: Path to the Python file
+         custom_name: Custom name to register the classifier under
+         class_name: Name of specific class to register
+         overwrite: Whether to overwrite existing classifiers
+
+     Returns:
+         list: Names of successfully registered classifiers
+     """
+     registered_classifiers = []
+
+     try:
+         # Create custom classifiers directory if it doesn't exist
+         custom_dir = Path.home() / ".balancr" / "custom_classifiers"
+         custom_dir.mkdir(parents=True, exist_ok=True)
+
+         # Import the module dynamically
+         module_name = file_path.stem
+         spec = importlib.util.spec_from_file_location(module_name, file_path)
+         if spec is None or spec.loader is None:
+             logging.error(f"Could not load module from {file_path}")
+             return registered_classifiers
+
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+
+         # Ensure sklearn BaseEstimator is available
+         try:
+             from sklearn.base import BaseEstimator
+         except ImportError:
+             logging.error(
+                 "scikit-learn is not available. Please install it using: pip install scikit-learn"
+             )
+             return registered_classifiers
+
+         # Find all classes that inherit from BaseEstimator and have fit/predict methods
+         classifier_classes = []
+         for name, obj in inspect.getmembers(module, inspect.isclass):
+             if (
+                 obj.__module__
+                 == module_name  # Only consider classes defined in this module
+                 and issubclass(obj, BaseEstimator)
+                 and hasattr(obj, "fit")
+                 and hasattr(obj, "predict")
+             ):
+                 classifier_classes.append((name, obj))
+
+         # If no valid classes found
+         if not classifier_classes:
+             logging.warning(f"No valid classifier classes found in {file_path}")
+             logging.info(
+                 "Classes must inherit from sklearn.base.BaseEstimator and implement fit and predict methods"
+             )
+             return registered_classifiers
+
+         # If class_name is specified, filter to just that class
+         if class_name:
+             classifier_classes = [
+                 (name, cls) for name, cls in classifier_classes if name == class_name
+             ]
+             if not classifier_classes:
+                 logging.error(
+                     f"Class '{class_name}' not found in {file_path} or doesn't meet classifier requirements"
+                 )
+                 return registered_classifiers
+
+         # If requesting a custom name but multiple classes found and no class_name specified
+         if custom_name and len(classifier_classes) > 1 and not class_name:
+             logging.error(
+                 f"Multiple classifier classes found in {file_path}, but custom name provided. "
+                 "Please specify which class to register with --class-name."
+             )
+             return registered_classifiers
+
+         # Register the classifiers
+         for name, cls in classifier_classes:
+             # If this is a specifically requested class with a custom name
+             if class_name and name == class_name and custom_name:
+                 register_name = custom_name
+             else:
+                 register_name = name
+
+             try:
+                 # Check if classifier already exists
+                 existing_classifiers = registry.list_available_classifiers()
+                 flat_existing = []
+                 for module_classifiers in existing_classifiers.get(
+                     "custom", {}
+                 ).values():
+                     flat_existing.extend(module_classifiers)
+
+                 if register_name in flat_existing and not overwrite:
+                     logging.warning(
+                         f"Classifier '{register_name}' already exists. "
+                         "Use --overwrite to replace it."
+                     )
+                     continue
+
+                 # Register the classifier
+                 registry.register_custom_classifier(register_name, cls)
+                 registered_classifiers.append(register_name)
+                 logging.info(f"Successfully registered classifier: {register_name}")
+
+             except Exception as e:
+                 logging.error(f"Error registering classifier '{register_name}': {e}")
+
+         # For successfully registered classifiers, copy the file
+         if registered_classifiers:
+             # Generate a unique filename (in case multiple files have same name)
+             dest_file = custom_dir / f"{file_path.stem}_{hash(str(file_path))}.py"
+             shutil.copy2(file_path, dest_file)
+             logging.debug(f"Copied {file_path} to {dest_file}")
+
+             # Create a metadata file to map classifier names to files
+             metadata_file = custom_dir / "classifiers_metadata.json"
+             metadata = {}
+             if metadata_file.exists():
+                 with open(metadata_file, "r") as f:
+                     metadata = json.load(f)
+
+             # Update metadata with new classifiers
+             class_mapping = {cls_name: cls_name for cls_name, _ in classifier_classes}
+             if class_name and custom_name:
+                 class_mapping[custom_name] = class_name
+
+             for classifier_name in registered_classifiers:
+                 original_class = class_mapping.get(classifier_name, classifier_name)
+                 metadata[classifier_name] = {
+                     "file": str(dest_file),
+                     "class_name": original_class,
+                     "registered_at": datetime.now().isoformat(),
+                 }
+
+             with open(metadata_file, "w") as f:
+                 json.dump(metadata, f, indent=2)
+
+     except Exception as e:
+         logging.error(f"Error processing file {file_path}: {e}")
+
+     return registered_classifiers
+
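+ # Minimal sketch of a classifier file this path accepts; any BaseEstimator
+ # subclass defined in the file with fit and predict methods qualifies:
+ #
+ #   import numpy as np
+ #   from sklearn.base import BaseEstimator, ClassifierMixin
+ #
+ #   class MajorityClassClassifier(BaseEstimator, ClassifierMixin):
+ #       def fit(self, X, y):
+ #           values, counts = np.unique(y, return_counts=True)
+ #           self.majority_ = values[np.argmax(counts)]
+ #           return self
+ #
+ #       def predict(self, X):
+ #           return np.full(len(X), self.majority_)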
+
+ def _remove_classifiers(args):
+     """
+     Remove custom classifiers as specified in the args.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     # Get path to custom classifiers directory
+     custom_dir = Path.home() / ".balancr" / "custom_classifiers"
+     metadata_file = custom_dir / "classifiers_metadata.json"
+
+     # Check if metadata file exists
+     if not metadata_file.exists():
+         logging.error("No custom classifiers have been registered.")
+         return 1
+
+     # Load metadata
+     with open(metadata_file, "r") as f:
+         metadata = json.load(f)
+
+     # If no custom classifiers
+     if not metadata:
+         logging.error("No custom classifiers have been registered.")
+         return 1
+
+     # Remove all custom classifiers
+     if args.remove_all:
+         logging.info("Removing all custom classifiers...")
+
+         # Remove all classifier files
+         file_paths = set(info["file"] for info in metadata.values())
+         for file_path in file_paths:
+             try:
+                 Path(file_path).unlink(missing_ok=True)
+             except Exception as e:
+                 logging.warning(f"Error removing file {file_path}: {e}")
+
+         # Clear metadata
+         metadata = {}
+         with open(metadata_file, "w") as f:
+             json.dump(metadata, f, indent=2)
+
+         print("All custom classifiers have been removed.")
+         return 0
+
+     # Remove specific classifiers
+     removed_classifiers = []
+     for classifier_name in args.remove:
+         if classifier_name in metadata:
+             # Note the file path (we'll check if it's used by other classifiers)
+             file_path = metadata[classifier_name]["file"]
+
+             # Remove from metadata
+             del metadata[classifier_name]
+             removed_classifiers.append(classifier_name)
+
+             # Check if the file is still used by other classifiers
+             file_still_used = any(
+                 info["file"] == file_path for info in metadata.values()
+             )
+
+             # If not used, remove the file
+             if not file_still_used:
+                 try:
+                     Path(file_path).unlink(missing_ok=True)
+                 except Exception as e:
+                     logging.warning(f"Error removing file {file_path}: {e}")
+         else:
+             logging.warning(f"Classifier '{classifier_name}' not found.")
+
+     # Save updated metadata
+     with open(metadata_file, "w") as f:
+         json.dump(metadata, f, indent=2)
+
+     if removed_classifiers:
+         print("\nRemoved classifiers:")
+         for classifier in removed_classifiers:
+             print(f" - {classifier}")
+         return 0
+     else:
+         logging.error("No matching classifiers were found.")
+         return 1
+
+
+ def configure_metrics(args):
+     """
+     Handle the configure-metrics command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     # Define all available metrics
+     all_metrics = [
+         "accuracy",
+         "precision",
+         "recall",
+         "f1",
+         "roc_auc",
+         "specificity",
+         "g_mean",
+         "average_precision",
+     ]
+
+     # If 'all' is specified, use all available metrics
+     if "all" in args.metrics:
+         metrics_to_use = all_metrics
+         metrics_str = "all available metrics"
+     else:
+         metrics_to_use = args.metrics
+         metrics_str = ", ".join(args.metrics)
+
+     logging.info(f"Configuring metrics: {metrics_str}")
+
+     # Update configuration with metrics settings
+     settings = {
+         "output": {"metrics": metrics_to_use, "save_metrics_formats": args.save_formats}
+     }
+
+     try:
+         # Update existing output settings if they exist
+         current_config = config.load_config(args.config_path)
+         if "output" in current_config:
+             current_output = current_config["output"]
+             # Merge with existing output settings without overwriting other output options
+             settings["output"] = {**current_output, **settings["output"]}
+
+         config.update_config(args.config_path, settings)
+
+         # Display confirmation
+         print("\nMetrics Configuration:")
+         print(f" Metrics: {metrics_str}")
+         print(f" Save Formats: {', '.join(args.save_formats)}")
+
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to configure metrics: {e}")
+         return 1
+
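+ # The dict merge above keeps unrelated output options intact, e.g.
+ #   {"visualisations": ["all"], "metrics": ["f1"]} merged with
+ #   {"metrics": ["recall"], "save_metrics_formats": ["csv"]}
+ # yields {"visualisations": ["all"], "metrics": ["recall"],
+ #         "save_metrics_formats": ["csv"]}, with the new "metrics" winning.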
+
+ def configure_visualisations(args):
+     """
+     Handle the configure-visualisations command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     types_str = "all visualisations" if "all" in args.types else ", ".join(args.types)
+     logging.info(f"Configuring visualisations: {types_str}")
+
+     # Update configuration with visualisation settings
+     settings = {
+         "output": {
+             "visualisations": args.types,
+             "display_visualisations": args.display,
+             "save_vis_formats": args.save_formats,
+         }
+     }
+
+     try:
+         # Update existing output settings if they exist
+         current_config = config.load_config(args.config_path)
+         if "output" in current_config:
+             current_output = current_config["output"]
+             # Merge with existing output settings without overwriting other output options
+             settings["output"] = {**current_output, **settings["output"]}
+
+         config.update_config(args.config_path, settings)
+
+         # Display confirmation
+         print("\nVisualisation Configuration:")
+         print(f" Types: {types_str}")
+         print(f" Display During Execution: {'Yes' if args.display else 'No'}")
+         print(f" Save Formats: {', '.join(args.save_formats)}")
+
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to configure visualisations: {e}")
+         return 1
+
+
+ def configure_evaluation(args):
+     """
+     Handle the configure-evaluation command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     logging.info("Configuring evaluation settings")
+
+     # Update configuration with evaluation settings
+     settings = {
+         "evaluation": {
+             "test_size": args.test_size,
+             "cross_validation": args.cross_validation,
+             "random_state": args.random_state,
+             "learning_curve_folds": args.learning_curve_folds,
+             "learning_curve_points": args.learning_curve_points,
+         }
+     }
+
+     try:
+         config.update_config(args.config_path, settings)
+
+         # Display confirmation
+         print("\nEvaluation Configuration:")
+         print(f" Test Size: {args.test_size}")
+         print(f" Cross-Validation Folds: {args.cross_validation}")
+         print(f" Random State: {args.random_state}")
+         print(f" Learning Curve Folds: {args.learning_curve_folds}")
+         print(f" Learning Curve Points: {args.learning_curve_points}")
+
+         return 0
+
+     except Exception as e:
+         logging.error(f"Failed to configure evaluation: {e}")
+         return 1
+
+
+ def format_time(seconds):
+     """Format a duration in seconds as minutes and seconds."""
+     minutes = int(seconds // 60)
+     remaining_seconds = seconds % 60
+     return f"{minutes}mins, {remaining_seconds:.2f}secs"
+
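+ # e.g. format_time(125.5) -> "2mins, 5.50secs"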
+
+ def run_comparison(args):
+     """
+     Handle the run command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     start_time_total = time.time()
+
+     # Load current configuration
+     try:
+         current_config = config.load_config(args.config_path)
+     except Exception as e:
+         logging.error(f"Failed to load configuration: {e}")
+         return 1
+
+     # Check if all required settings are configured
+     required_settings = ["data_file", "target_column", "balancing_techniques"]
+     missing_settings = [s for s in required_settings if s not in current_config]
+
+     if missing_settings:
+         logging.error(f"Missing required configuration: {', '.join(missing_settings)}")
+         logging.info("Please configure all required settings before running comparison")
+         return 1
+
+     # Ensure balancing framework is available
+     if BalancingFramework is None:
+         logging.error("Balancing framework not available. Please check installation")
+         return 1
+
+     # Prepare output directory
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Get output and evaluation settings with defaults
+     output_config = current_config.get("output", {})
+     metrics = output_config.get("metrics", ["precision", "recall", "f1", "roc_auc"])
+     visualisations = output_config.get("visualisations", ["all"])
+     display_visualisations = output_config.get("display_visualisations", False)
+     save_metrics_formats = output_config.get("save_metrics_formats", ["csv"])
+     save_vis_formats = output_config.get("save_vis_formats", ["png"])
+
+     eval_config = current_config.get("evaluation", {})
+     test_size = eval_config.get("test_size", 0.2)
+     cv_enabled = eval_config.get("cross_validation", 0) > 0
+     cv_folds = eval_config.get("cross_validation", 5)
+     random_state = eval_config.get("random_state", 42)
+     include_original = current_config.get("include_original_data", False)
+
+     balancing_techniques = current_config.get("balancing_techniques", {})
+     technique_names = list(balancing_techniques.keys())
+     logging.info(f"Running comparison with techniques: {', '.join(technique_names)}")
+     logging.info(f"Results will be saved to: {output_dir}")
+
+     try:
+         # Initialise the framework
+         framework = BalancingFramework()
+
+         # Load data
+         start_time = time.time()
+         logging.info(f"Loading data from {current_config['data_file']}")
+         feature_columns = current_config.get("feature_columns", None)
+         framework.load_data(
+             current_config["data_file"],
+             current_config["target_column"],
+             feature_columns,
+         )
+         load_time = time.time() - start_time
+         logging.info(f"Data loading completed (Time Taken: {format_time(load_time)})")
+
+         # Apply preprocessing if configured
+         if "preprocessing" in current_config:
+             logging.info("Applying preprocessing...")
+             preproc = current_config["preprocessing"]
+
+             handle_missing = preproc.get("handle_missing", "mean")
+             scale = preproc.get("scale", "standard")
+             categorical_features = preproc.get("categorical_features", {})
+             handle_constant_features = preproc.get("handle_constant_features", None)
+             handle_correlations = preproc.get("handle_correlations", None)
+             save_preprocessed_file = preproc.get("save_preprocessed", True)
+
+             # Extract hash components information
+             hash_components_dict = {}
+             for feature, encoding in categorical_features.items():
+                 if isinstance(encoding, list) and encoding[0] == "hash":
+                     # Store the n_components value
+                     hash_components_dict[feature] = encoding[1]
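+             # e.g. {"city": ["hash", 16]} yields hash_components_dict == {"city": 16}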
+
+             framework.preprocess_data(
+                 handle_missing=handle_missing,
+                 scale=scale,
+                 categorical_features=categorical_features,
+                 hash_components_dict=hash_components_dict,
+                 handle_constant_features=handle_constant_features,
+                 handle_correlations=handle_correlations,
+             )
+             logging.info("Data preprocessing applied")
+
+             # Save preprocessed dataset to a new file
+             if save_preprocessed_file:
+                 try:
+                     if "data_file" in current_config and current_config["data_file"]:
+                         original_path = Path(current_config["data_file"])
+                         preprocessed_path = (
+                             original_path.parent
+                             / f"{original_path.stem}_preprocessed{original_path.suffix}"
+                         )
+
+                         # Copy DataFrame with the preprocessed data
+                         preprocessed_df = framework.X.copy()
+
+                         # Add the target column
+                         target_column = current_config.get("target_column")
+                         if target_column:
+                             preprocessed_df[target_column] = framework.y
+
+                         # Save to the new file
+                         if original_path.suffix.lower() == ".csv":
+                             preprocessed_df.to_csv(preprocessed_path, index=False)
+                         elif original_path.suffix.lower() in [".xlsx", ".xls"]:
+                             preprocessed_df.to_excel(preprocessed_path, index=False)
+
+                         logging.info(
+                             f"Saved preprocessed dataset to: {preprocessed_path}"
+                         )
+                 except Exception as e:
+                     logging.warning(f"Could not save preprocessed dataset: {e}")
+
+         # Apply balancing techniques
+         start_time = time.time()
+         logging.info("Applying balancing techniques...")
+         framework.apply_balancing_techniques(
+             technique_names,
+             test_size=test_size,
+             random_state=random_state,
+             technique_params=balancing_techniques,
+             include_original=include_original,
+         )
+         balancing_time = time.time() - start_time
+         logging.info(
+             f"Balancing techniques applied successfully (Time Taken: {format_time(balancing_time)})"
+         )
+
+         # Save balanced datasets at the root level
+         balanced_dir = output_dir / "balanced_datasets"
+         balanced_dir.mkdir(exist_ok=True)
+         logging.info(f"Saving balanced datasets to {balanced_dir}")
+         framework.generate_balanced_data(
+             folder_path=str(balanced_dir),
+             techniques=technique_names,
+             file_format="csv",
+         )
+
+         # Determine which visualisation types to generate
+         if "all" in visualisations:
+             vis_types_to_generate = ["metrics", "distribution", "learning_curves"]
+         else:
+             vis_types_to_generate = visualisations
+
+         # Save class distribution visualisations at the root level
+         for format_type in save_vis_formats:
+             if format_type == "none":
+                 continue
+
+             if "distribution" in vis_types_to_generate or "all" in visualisations:
+                 # Original (imbalanced) class distribution
+                 logging.info(
+                     f"Generating imbalanced class distribution in {format_type} format..."
+                 )
+                 imbalanced_plot_path = (
+                     output_dir / f"imbalanced_class_distribution.{format_type}"
+                 )
+                 framework.inspect_class_distribution(
+                     save_path=str(imbalanced_plot_path), display=display_visualisations
+                 )
+                 logging.info(
+                     f"Imbalanced class distribution saved to {imbalanced_plot_path}"
+                 )
+
+                 # Balanced class distributions comparison
+                 logging.info(
+                     f"Generating balanced class distribution comparison in {format_type} format..."
+                 )
+                 balanced_plot_path = (
+                     output_dir / f"balanced_class_distribution.{format_type}"
+                 )
+                 framework.compare_balanced_class_distributions(
+                     save_path=str(balanced_plot_path),
+                     display=display_visualisations,
+                 )
+                 logging.info(
+                     f"Balanced class distribution comparison saved to {balanced_plot_path}"
+                 )
+
+         # Train classifiers
+         start_time = time.time()
+         logging.info("Training classifiers on balanced datasets...")
+         classifiers = current_config.get("classifiers", {})
+         if not classifiers:
+             logging.warning(
+                 "No classifiers configured. Using default RandomForestClassifier."
+             )
+
+         # Train classifiers with the balanced datasets
+         results = framework.train_classifiers(
+             classifier_configs=classifiers, enable_cv=cv_enabled, cv_folds=cv_folds
+         )
+
+         training_time = time.time() - start_time
+         logging.info(
+             f"Training classifiers complete (Time Taken: {format_time(training_time)})"
+         )
+
+         # Metrics to plot in the visualisations below (defined once here so the
+         # radar and 3d branches don't depend on the metrics branch running first)
+         metrics_to_plot = current_config.get("output", {}).get(
+             "metrics", ["precision", "recall", "f1", "roc_auc"]
+         )
+
+         # Process each classifier and save its results in a separate directory
+         standard_start_time = time.time()
+         for classifier_name in current_config.get("classifiers", {}):
+             logging.info(f"Processing results for classifier: {classifier_name}")
+
+             # Create classifier-specific directory
+             classifier_dir = output_dir / classifier_name
+             classifier_dir.mkdir(exist_ok=True)
+
+             # Create standard metrics directory
+             std_metrics_dir = classifier_dir / "standard_metrics"
+             std_metrics_dir.mkdir(exist_ok=True)
+
+             # Save standard metrics in requested formats
+             for format_type in save_metrics_formats:
+                 if format_type == "none":
+                     continue
+
+                 results_file = std_metrics_dir / f"comparison_results.{format_type}"
+                 logging.info(
+                     f"Saving standard metrics for {classifier_name} to {results_file}"
+                 )
+
+                 # save_classifier_results extracts this classifier's results from the full set
+                 framework.save_classifier_results(
+                     results_file,
+                     classifier_name=classifier_name,
+                     metric_type="standard_metrics",
+                     file_type=format_type,
+                 )
+
+             # Generate and save standard metrics visualisations
+             for format_type in save_vis_formats:
+                 if format_type == "none":
+                     continue
+
+                 if "metrics" in vis_types_to_generate or "all" in visualisations:
+                     metrics_path = std_metrics_dir / f"metrics_comparison.{format_type}"
+                     logging.info(
+                         f"Generating metrics comparison for {classifier_name} in {format_type} format..."
+                     )
+
+                     # Plot this classifier's standard metrics across techniques
+                     plot_comparison_results(
+                         results,
+                         classifier_name=classifier_name,
+                         metric_type="standard_metrics",
+                         metrics_to_plot=metrics_to_plot,
+                         save_path=str(metrics_path),
+                         display=display_visualisations,
+                     )
+
+                 if "radar" in vis_types_to_generate or "all" in visualisations:
+                     std_radar_path = (
+                         classifier_dir / f"standard_metrics_radar.{format_type}"
+                     )
+                     plot_radar_chart(
+                         results,
+                         classifier_name=classifier_name,
+                         metric_type="standard_metrics",
+                         metrics_to_plot=metrics_to_plot,
+                         save_path=std_radar_path,
+                         display=display_visualisations,
+                     )
+
+                 if "3d" in vis_types_to_generate or "all" in visualisations:
+                     std_3d_path = output_dir / "standard_metrics_3d.html"
+                     plot_3d_scatter(
+                         results,
+                         metric_type="standard_metrics",
+                         metrics_to_plot=metrics_to_plot,
+                         save_path=std_3d_path,
+                         display=display_visualisations,
+                     )
+
+         standard_total_time = time.time() - standard_start_time
+         logging.info(
+             f"Standard metrics evaluation total time: {format_time(standard_total_time)}"
+         )
+
+         # If cross-validation is enabled, create CV metrics directory and save results
+         if cv_enabled:
+             cv_start_time = time.time()
+             for classifier_name in current_config.get("classifiers", {}):
+                 # Create classifier-specific directory
+                 classifier_dir = output_dir / classifier_name
+                 classifier_dir.mkdir(exist_ok=True)
+
+                 cv_metrics_dir = classifier_dir / "cv_metrics"
+                 cv_metrics_dir.mkdir(exist_ok=True)
+
+                 # Save CV metrics in requested formats
+                 for format_type in save_metrics_formats:
+                     if format_type == "none":
+                         continue
+
+                     cv_results_file = (
+                         cv_metrics_dir / f"comparison_results.{format_type}"
+                     )
+                     logging.info(
+                         f"Saving CV metrics for {classifier_name} to {cv_results_file}"
+                     )
+
+                     framework.save_classifier_results(
+                         cv_results_file,
+                         classifier_name=classifier_name,
+                         metric_type="cv_metrics",
+                         file_type=format_type,
+                     )
+
+                 # Generate and save CV metrics visualisations
+                 for format_type in save_vis_formats:
+                     if format_type == "none":
+                         continue
+
+                     if "metrics" in vis_types_to_generate or "all" in visualisations:
+                         cv_metrics_path = (
+                             cv_metrics_dir / f"metrics_comparison.{format_type}"
+                         )
+                         logging.info(
+                             f"Generating CV metrics comparison for {classifier_name} in {format_type} format..."
+                         )
+
+                         plot_comparison_results(
+                             results,
+                             classifier_name=classifier_name,
+                             metric_type="cv_metrics",
+                             metrics_to_plot=metrics_to_plot,
+                             save_path=str(cv_metrics_path),
+                             display=display_visualisations,
+                         )
+
+                     if "radar" in vis_types_to_generate or "all" in visualisations:
+                         cv_radar_path = (
+                             classifier_dir / f"cv_metrics_radar.{format_type}"
+                         )
+                         plot_radar_chart(
+                             results,
+                             classifier_name=classifier_name,
+                             metric_type="cv_metrics",
+                             metrics_to_plot=metrics_to_plot,
+                             save_path=cv_radar_path,
+                             display=display_visualisations,
+                         )
+
+                     if "3d" in vis_types_to_generate or "all" in visualisations:
+                         cv_3d_path = output_dir / "cv_metrics_3d.html"
+                         plot_3d_scatter(
+                             results,
+                             metric_type="cv_metrics",
+                             metrics_to_plot=metrics_to_plot,
+                             save_path=cv_3d_path,
+                             display=display_visualisations,
+                         )
+
+                     if (
+                         "learning_curves" in vis_types_to_generate
+                         or "all" in visualisations
+                     ):
+                         cv_learning_curve_path = (
+                             cv_metrics_dir / f"learning_curves.{format_type}"
+                         )
+
+                         start_time = time.time()
+                         logging.info(
+                             f"Generating CV learning curves for {classifier_name} in {format_type} format..."
+                         )
+
+                         # Get learning curve parameters from config
+                         learning_curve_points = eval_config.get(
+                             "learning_curve_points", 10
+                         )
+                         learning_curve_folds = eval_config.get(
+                             "learning_curve_folds", 5
+                         )
+                         train_sizes = np.linspace(0.1, 1.0, learning_curve_points)
+
+                         framework.generate_learning_curves(
+                             classifier_name=classifier_name,
+                             train_sizes=train_sizes,
+                             n_folds=learning_curve_folds,
+                             save_path=str(cv_learning_curve_path),
+                             display=display_visualisations,
+                         )
+                         cv_learning_curves_time = time.time() - start_time
+                         logging.info(
+                             f"Successfully generated CV learning curves for {classifier_name} "
+                             f"(Time Taken: {format_time(cv_learning_curves_time)})"
+                         )
+             cv_total_time = time.time() - cv_start_time
+             logging.info(
+                 f"Cross validation metrics evaluation total time: {format_time(cv_total_time)}"
+             )
+
+         total_time = time.time() - start_time_total
+         logging.info(f"Total execution time: {format_time(total_time)}")
+
+         # Print summary of timing results
+         print("\nExecution Time Summary:\n")
+         print(f" Data Loading: {format_time(load_time)}")
+         print(f" Balancing: {format_time(balancing_time)}")
+         print(f" Training Classifiers: {format_time(training_time)}")
+         print(f" Standard Metrics Evaluation: {format_time(standard_total_time)}")
+         if cv_enabled:
+             print(f" CV Metrics Evaluation: {format_time(cv_total_time)}")
+         print(f" Total Time: {format_time(total_time)}")
+
+         print("\nResults Summary:")
+
+         # Check and print Standard Metrics if available
+         has_standard_metrics = any(
+             "standard_metrics" in technique_metrics
+             and any(m in metrics for m in technique_metrics["standard_metrics"])
+             for classifier_results in results.values()
+             for technique_metrics in classifier_results.values()
+         )
+
+         if has_standard_metrics:
+             print("\nStandard Metrics:")
+             for classifier_name, classifier_results in results.items():
+                 print(f"\n{classifier_name}:")
+                 for technique_name, technique_metrics in classifier_results.items():
+                     if "standard_metrics" in technique_metrics:
+                         std_metrics = technique_metrics["standard_metrics"]
+                         if any(m in metrics for m in std_metrics):
+                             print(f" {technique_name}:")
+                             for metric_name, value in std_metrics.items():
+                                 if metric_name in metrics:
+                                     print(f" {metric_name}: {value:.4f}")
+
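+         # CV metric keys follow the pattern "cv_<metric>_<stat>" (e.g. "cv_f1_mean");
+         # stripping the "cv_" prefix and the final "_<stat>" segment recovers the
+         # base metric name checked against the configured metrics list.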
1773
+ # Check and print Cross Validation Metrics if available
1774
+ has_cv_metrics = any(
1775
+ "cv_metrics" in technique_metrics
1776
+ and any(
1777
+ metric_name.startswith("cv_")
1778
+ and metric_name[len("cv_"):].rsplit("_", 1)[0] in metrics
1779
+ for metric_name in technique_metrics["cv_metrics"]
1780
+ )
1781
+ for classifier_results in results.values()
1782
+ for technique_metrics in classifier_results.values()
1783
+ )
1784
+
1785
+ if has_cv_metrics:
1786
+ print("\nCross Validation Metrics:")
1787
+ for classifier_name, classifier_results in results.items():
1788
+ print(f"\n{classifier_name}:")
1789
+ for technique_name, technique_metrics in classifier_results.items():
1790
+ if "cv_metrics" in technique_metrics:
1791
+ cv_metrics = technique_metrics["cv_metrics"]
1792
+
1793
+ # Check if any relevant cv metric exists for this technique
1794
+ if any(
1795
+ metric_name.startswith("cv_")
1796
+ and metric_name[len("cv_"):].rsplit("_", 1)[0] in metrics
1797
+ for metric_name in cv_metrics
1798
+ ):
1799
+ print(f" {technique_name}:")
1800
+
1801
+ # Now print only relevant metrics
1802
+ for metric_name, value in cv_metrics.items():
1803
+ if metric_name.startswith("cv_"):
1804
+ base_name = metric_name[len("cv_"):].rsplit(
1805
+ "_", 1
1806
+ )[0]
1807
+ if base_name in metrics:
1808
+ print(f" {metric_name}: {value:.4f}")
1809
+
1810
+ print(f"\nDetailed results saved to: {output_dir}")
1811
+ return 0
1812
+
1813
+ except Exception as e:
1814
+ logging.error(f"Error during comparison: {e}")
1815
+ if args.verbose:
1816
+ import traceback
1817
+
1818
+ traceback.print_exc()
1819
+ return 1
1820
+
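+ # Illustrative layout of <output_dir> after a run with csv/png formats:
+ #   balanced_datasets/
+ #   imbalanced_class_distribution.png
+ #   balanced_class_distribution.png
+ #   standard_metrics_3d.html           (and cv_metrics_3d.html with CV enabled)
+ #   <classifier_name>/
+ #       standard_metrics/comparison_results.csv
+ #       standard_metrics/metrics_comparison.png
+ #       standard_metrics_radar.png
+ #       cv_metrics/...                 (when cross-validation is enabled)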
+
+ def reset_config(args):
+     """
+     Handle the reset command.
+
+     Args:
+         args: Command line arguments from argparse
+
+     Returns:
+         int: Exit code
+     """
+     try:
+         config.initialise_config(args.config_path, force=True)
+         logging.info("Configuration has been reset to defaults")
+         return 0
+     except Exception as e:
+         logging.error(f"Failed to reset configuration: {e}")
+         return 1