balancr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- balancr/__init__.py +13 -0
- balancr/base.py +14 -0
- balancr/classifier_registry.py +300 -0
- balancr/cli/__init__.py +0 -0
- balancr/cli/commands.py +1838 -0
- balancr/cli/config.py +165 -0
- balancr/cli/main.py +778 -0
- balancr/cli/utils.py +101 -0
- balancr/data/__init__.py +5 -0
- balancr/data/loader.py +59 -0
- balancr/data/preprocessor.py +556 -0
- balancr/evaluation/__init__.py +19 -0
- balancr/evaluation/metrics.py +442 -0
- balancr/evaluation/visualisation.py +660 -0
- balancr/imbalance_analyser.py +677 -0
- balancr/technique_registry.py +284 -0
- balancr/techniques/__init__.py +4 -0
- balancr/techniques/custom/__init__.py +0 -0
- balancr/techniques/custom/example_custom_technique.py +27 -0
- balancr-0.1.0.dist-info/LICENSE +21 -0
- balancr-0.1.0.dist-info/METADATA +536 -0
- balancr-0.1.0.dist-info/RECORD +25 -0
- balancr-0.1.0.dist-info/WHEEL +5 -0
- balancr-0.1.0.dist-info/entry_points.txt +2 -0
- balancr-0.1.0.dist-info/top_level.txt +1 -0
balancr/cli/commands.py
ADDED
@@ -0,0 +1,1838 @@
"""
commands.py - Command handlers for the balancr CLI.

This module contains the implementation of all command functions that are
registered in main.py
"""

import shutil
import time
from datetime import datetime
import importlib.util  # imported explicitly: `import importlib` alone does not guarantee importlib.util
import logging
import json
import inspect
from balancr.data import DataPreprocessor
import numpy as np
from pathlib import Path

from balancr.evaluation import (
    plot_comparison_results,
    plot_radar_chart,
    plot_3d_scatter,
)
import pandas as pd

from . import config
from balancr import BaseBalancer

# Will be used to interact with the core balancing framework
try:
    from balancr import BalancingFramework
    from balancr import TechniqueRegistry
    from balancr import ClassifierRegistry
except ImportError as e:
    logging.error(f"Could not import balancing framework: {str(e)}")
    logging.error("Ensure the balancr framework is installed correctly.")
    BalancingFramework = None
    TechniqueRegistry = None
    ClassifierRegistry = None

def load_data(args):
    """
    Handle the load-data command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    logging.info(f"Loading data from {args.file_path}")

    # Validate file exists
    if not Path(args.file_path).exists():
        logging.error(f"File not found: {args.file_path}")
        return 1

    # Update configuration with data file settings
    settings = {
        "data_file": args.file_path,
        "target_column": args.target_column,
    }

    if args.feature_columns:
        settings["feature_columns"] = args.feature_columns

    try:
        # Validate that the file can be loaded
        if BalancingFramework is not None:
            # This is just a validation check, not storing the framework instance
            framework = BalancingFramework()

            # Try to get correlation_threshold from config
            try:
                current_config = config.load_config(args.config_path)
                correlation_threshold = current_config["preprocessing"].get(
                    "correlation_threshold", 0.95
                )
            except Exception:
                # Fall back to default if config can't be loaded or doesn't have the value
                correlation_threshold = 0.95

            print(f"Correlation Threshold: {correlation_threshold}")
            framework.load_data(
                args.file_path,
                args.target_column,
                args.feature_columns,
                correlation_threshold=correlation_threshold,
            )

            # Get and display class distribution
            distribution = framework.inspect_class_distribution(display=False)
            print("\nClass Distribution:")
            for cls, count in distribution.items():
                print(f"  Class {cls}: {count} samples")

            total = sum(distribution.values())
            for cls, count in distribution.items():
                pct = (count / total) * 100
                print(f"  Class {cls}: {pct:.2f}%")

        # Update config with new settings
        config.update_config(args.config_path, settings)
        logging.info(
            f"Data configuration saved: {args.file_path}, target: {args.target_column}"
        )
        return 0

    except Exception as e:
        logging.error(f"Failed to load data: {e}")
        return 1

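For reference, a plausible invocation of this handler. The exact flag spellings are wired up in main.py, which this diff section does not show, so treat them as assumptions:

    balancr load-data data/transactions.csv --target-column Class --feature-columns amount age
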
def preprocess(args):
    """
    Handle the preprocess command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    logging.info("Configuring preprocessing options")

    # Initialise basic preprocessing settings
    settings = {
        "preprocessing": {
            "handle_missing": args.handle_missing,
            "handle_constant_features": args.handle_constant_features,
            "handle_correlations": args.handle_correlations,
            "correlation_threshold": args.correlation_threshold,
            "scale": args.scale,
            "encode": args.encode,
            "save_preprocessed": args.save_preprocessed,
        }
    }

    try:
        current_config = config.load_config(args.config_path)

        # Check if categorical features are specified
        if args.categorical_features:
            # Process categorical features if a dataset is available
            if "data_file" in current_config:
                data_file = current_config["data_file"]
                target_column = current_config.get("target_column")

                logging.info(f"Loading dataset from {data_file} for encoding analysis")

                try:
                    # Read the dataset directly
                    df = pd.read_csv(data_file)

                    # Validate target column exists
                    if target_column and target_column not in df.columns:
                        logging.warning(
                            f"Target column '{target_column}' not found in dataset"
                        )

                    # Create preprocessor and determine encoding types
                    preprocessor = DataPreprocessor()
                    categorical_encodings = preprocessor.assign_encoding_types(
                        df=df,
                        categorical_columns=args.categorical_features,
                        encoding_type=args.encode,
                        hash_components=args.hash_components,
                        ordinal_columns=args.ordinal_features,
                    )

                    # Add categorical feature encodings to settings
                    settings["preprocessing"][
                        "categorical_features"
                    ] = categorical_encodings

                    # Display encoding recommendations
                    print("\nCategorical feature encoding assignments:")
                    for column, encoding in categorical_encodings.items():
                        print(f"  {column}: {encoding}")
                except Exception as e:
                    logging.error(f"Error analysing dataset for encoding: {e}")
                    logging.info(
                        "Storing categorical features without encoding recommendations"
                    )
                    # If analysis fails, just store the categorical features with the default encoding
                    settings["preprocessing"]["categorical_features"] = {
                        col: args.encode for col in args.categorical_features
                    }
            else:
                logging.warning(
                    "No dataset configured. Cannot analyse categorical features."
                )
                logging.info("Storing categorical features with default encoding")
                # Store categorical features with the specified encoding type
                settings["preprocessing"]["categorical_features"] = {
                    col: args.encode for col in args.categorical_features
                }

        # Update config
        config.update_config(args.config_path, settings)
        logging.info("Preprocessing configuration saved")

        # Display the preprocessing settings
        print("\nPreprocessing Configuration:")
        print(f"  Handle Missing Values: {args.handle_missing}")
        print(f"  Handle Constant Features: {args.handle_constant_features}")
        print(f"  Handle Feature Correlations: {args.handle_correlations}")
        print(f"  Correlation Threshold: {args.correlation_threshold}")
        print(f"  Feature Scaling: {args.scale}")
        print(f"  Default Categorical Encoding: {args.encode}")
        print(f"  Save Preprocessed Data to File: {args.save_preprocessed}")

        if args.categorical_features:
            print(f"  Categorical Features: {', '.join(args.categorical_features)}")

        if args.ordinal_features:
            print(f"  Ordinal Features: {', '.join(args.ordinal_features)}")

        return 0

    except Exception as e:
        logging.error(f"Failed to configure preprocessing: {e}")
        return 1

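The `settings` dict written here serialises into the JSON config file roughly as follows (column names and values are illustrative; the `["hash", 8]` form matches the hash-encoding check in run_comparison further down):

    "preprocessing": {
      "handle_missing": "mean",
      "handle_constant_features": true,
      "handle_correlations": true,
      "correlation_threshold": 0.95,
      "scale": "standard",
      "encode": "onehot",
      "save_preprocessed": true,
      "categorical_features": {
        "colour": "onehot",
        "city": ["hash", 8]
      }
    }
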
def select_techniques(args):
    """
    Handle the select-techniques command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    # List available techniques if requested
    if args.list_available and BalancingFramework is not None:
        print("Listing available balancing techniques...")
        try:
            framework = BalancingFramework()
            techniques = framework.list_available_techniques()

            print("\nAvailable Techniques:")

            # Print custom techniques
            if techniques.get("custom"):
                print("\nCustom Techniques:")
                for technique in techniques["custom"]:
                    print(f"  - {technique}")

            # Print imblearn techniques
            if techniques.get("imblearn"):
                print("\nImbalanced-Learn Techniques:")
                for technique in sorted(techniques["imblearn"]):
                    print(f"  - {technique}")

            return 0

        except Exception as e:
            logging.error(f"Failed to list techniques: {e}")
            return 1

    # When not listing but selecting techniques
    logging.info(f"Selecting balancing techniques: {', '.join(args.techniques)}")

    try:
        # Validate techniques exist if framework is available
        balancing_techniques = {}
        if BalancingFramework is not None:
            framework = BalancingFramework()
            available = framework.list_available_techniques()
            all_techniques = available.get("custom", []) + available.get("imblearn", [])

            invalid_techniques = [t for t in args.techniques if t not in all_techniques]

            if invalid_techniques:
                logging.error(f"Invalid techniques: {', '.join(invalid_techniques)}")
                logging.info(
                    "Use 'balancr select-techniques --list-available' to see available techniques"
                )
                return 1

            # Create technique configurations with default parameters
            for technique_name in args.techniques:
                # Get default parameters for this technique
                params = framework.technique_registry.get_technique_default_params(
                    technique_name
                )
                balancing_techniques[technique_name] = params

        # Read existing config
        current_config = config.load_config(args.config_path)

        if args.append and "balancing_techniques" in current_config:
            # Append mode: update existing techniques
            existing_techniques = current_config.get("balancing_techniques", {})

            # Add new techniques to the existing ones
            if BalancingFramework is not None:
                existing_techniques.update(balancing_techniques)

            # Update config with merged values
            settings = {
                "balancing_techniques": existing_techniques,
                "include_original_data": args.include_original_data,
            }

            config.update_config(args.config_path, settings)

            print(f"\nAdded balancing techniques: {', '.join(args.techniques)}")
            print(f"Total techniques: {', '.join(existing_techniques.keys())}")
        else:
            # Replace mode: create a completely new config entry with the
            # selected techniques
            new_config = dict(current_config)  # shallow copy
            new_config["balancing_techniques"] = (
                balancing_techniques if BalancingFramework is not None else {}
            )
            new_config["include_original_data"] = args.include_original_data

            # Directly write the entire config to replace the file
            config_path = Path(args.config_path)
            with open(config_path, "w") as f:
                json.dump(new_config, f, indent=2)

            print(f"\nReplaced balancing techniques with: {', '.join(args.techniques)}")

            # Return early since we've manually written the config
            print("Default parameters have been added to the configuration file.")
            print("You can modify them by editing the configuration file.")
            return 0

        print("Default parameters have been added to the configuration file.")
        print("You can modify them by editing the configuration file.")
        return 0

    except Exception as e:
        logging.error(f"Failed to select techniques: {e}")
        return 1

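The resulting config entry pairs each selected technique with its default parameters. An abridged sketch, using imbalanced-learn's SMOTE defaults as the example (technique choice and parameter set are illustrative):

    "balancing_techniques": {
      "SMOTE": {
        "sampling_strategy": "auto",
        "random_state": null,
        "k_neighbors": 5
      }
    },
    "include_original_data": false
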
def register_techniques(args):
    """
    Handle the register-techniques command.

    This command allows users to register custom balancing techniques from
    Python files or folders for use in comparisons.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    try:
        # Handle removal operations
        if args.remove or args.remove_all:
            return _remove_techniques(args)

        # Ensure framework is available
        if TechniqueRegistry is None:
            logging.error(
                "Technique registry not available. Please check installation."
            )
            return 1

        registry = TechniqueRegistry()

        # Track successfully registered techniques
        registered_techniques = []

        # Process file path
        if args.file_path:
            file_path = Path(args.file_path)

            if not file_path.exists():
                logging.error(f"File not found: {file_path}")
                return 1

            if not file_path.is_file() or file_path.suffix.lower() != ".py":
                logging.error(f"Not a Python file: {file_path}")
                return 1

            # Register techniques from the file
            registered = _register_from_file(
                registry, file_path, args.name, args.class_name, args.overwrite
            )
            registered_techniques.extend(registered)

        # Process folder path
        elif args.folder_path:
            folder_path = Path(args.folder_path)

            if not folder_path.exists():
                logging.error(f"Folder not found: {folder_path}")
                return 1

            if not folder_path.is_dir():
                logging.error(f"Not a directory: {folder_path}")
                return 1

            # Register techniques from all Python files in the folder
            for py_file in folder_path.glob("*.py"):
                registered = _register_from_file(
                    registry,
                    py_file,
                    None,  # Don't use custom name for folder scanning
                    None,  # Don't use class name for folder scanning
                    args.overwrite,
                )
                registered_techniques.extend(registered)

        # Print summary
        if registered_techniques:
            print("\nSuccessfully registered techniques:")
            for technique in registered_techniques:
                print(f"  - {technique}")

            # Suggestion for next steps
            print("\nYou can now use these techniques in comparisons. For example:")
            print(f"  balancr select-techniques {registered_techniques[0]}")
            return 0
        else:
            logging.warning("No valid balancing techniques found to register.")
            return 1

    except Exception as e:
        logging.error(f"Error registering techniques: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        return 1

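A plausible invocation: `--class-name` and `--overwrite` appear verbatim in the log messages below, so those flags are confirmed; treating the file path as a positional argument and the `--name` spelling are assumptions:

    balancr register-techniques ./my_sampler.py --name MySampler --class-name RandomDuplicator --overwrite
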
def _register_from_file(
    registry, file_path, custom_name=None, class_name=None, overwrite=False
):
    """
    Register technique classes from a Python file and copy to custom techniques directory.

    Args:
        registry: TechniqueRegistry instance
        file_path: Path to the Python file
        custom_name: Custom name to register the technique under
        class_name: Name of specific class to register
        overwrite: Whether to overwrite existing techniques

    Returns:
        list: Names of successfully registered techniques
    """
    registered_techniques = []

    try:
        # Create custom techniques directory if it doesn't exist
        custom_dir = Path.home() / ".balancr" / "custom_techniques"
        custom_dir.mkdir(parents=True, exist_ok=True)

        # Import the module dynamically
        module_name = file_path.stem
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            logging.error(f"Could not load module from {file_path}")
            return registered_techniques

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Find all classes that inherit from BaseBalancer
        technique_classes = []
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                obj.__module__ == module_name  # Only consider classes defined in this module
                and issubclass(obj, BaseBalancer)
                and obj is not BaseBalancer  # Exclude the base class itself
            ):
                technique_classes.append((name, obj))

        # If no valid classes found
        if not technique_classes:
            logging.warning(f"No valid technique classes found in {file_path}")
            logging.info("Classes must inherit from balancr.base.BaseBalancer")
            return registered_techniques

        # If class_name is specified, filter to just that class
        if class_name:
            technique_classes = [
                (name, cls) for name, cls in technique_classes if name == class_name
            ]
            if not technique_classes:
                logging.error(
                    f"Class '{class_name}' not found in {file_path} or doesn't inherit from BaseBalancer"
                )
                return registered_techniques

        # If requesting a custom name but multiple classes found and no class_name specified
        if custom_name and len(technique_classes) > 1 and not class_name:
            logging.error(
                f"Multiple technique classes found in {file_path}, but custom name provided. "
                "Please specify which class to register with --class-name."
            )
            return registered_techniques

        # Register techniques
        for name, cls in technique_classes:
            # If this is a specifically requested class with a custom name
            if class_name and name == class_name and custom_name:
                register_name = custom_name
            else:
                register_name = name

            try:
                # Check if technique already exists
                existing_techniques = registry.list_available_techniques()
                if (
                    register_name in existing_techniques.get("custom", [])
                    and not overwrite
                ):
                    logging.warning(
                        f"Technique '{register_name}' already exists. "
                        "Use --overwrite to replace it."
                    )
                    continue

                # Register the technique
                registry.register_custom_technique(register_name, cls)
                registered_techniques.append(register_name)
                logging.info(f"Successfully registered technique: {register_name}")

            except Exception as e:
                logging.error(f"Error registering technique '{register_name}': {e}")

        # For successfully registered techniques, copy the file
        if registered_techniques:
            # Generate a unique filename (in case multiple files have same name)
            dest_file = custom_dir / f"{file_path.stem}_{hash(str(file_path))}.py"
            shutil.copy2(file_path, dest_file)
            logging.debug(f"Copied {file_path} to {dest_file}")

            # Create a metadata file to map technique names to files
            metadata_file = custom_dir / "techniques_metadata.json"
            metadata = {}
            if metadata_file.exists():
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)

            # Update metadata with new techniques
            class_mapping = {cls_name: cls_name for cls_name, _ in technique_classes}
            if class_name and custom_name:
                class_mapping[custom_name] = class_name

            for technique_name in registered_techniques:
                original_class = class_mapping.get(technique_name, technique_name)
                metadata[technique_name] = {
                    "file": str(dest_file),
                    "class_name": original_class,
                    "registered_at": datetime.now().isoformat(),
                }

            with open(metadata_file, "w") as f:
                json.dump(metadata, f, indent=2)

    except Exception as e:
        logging.error(f"Error processing file {file_path}: {e}")

    return registered_techniques

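A minimal sketch of a file this loader would accept. The package ships its own example (balancr/techniques/custom/example_custom_technique.py, listed in the file summary above); the `balance(X, y)` hook shown here is an assumption about BaseBalancer's interface, which this diff section does not show:

    import numpy as np
    from balancr import BaseBalancer


    class RandomDuplicator(BaseBalancer):
        """Oversample minority classes by duplicating random rows."""

        def balance(self, X, y):  # assumed abstract hook on BaseBalancer
            classes, counts = np.unique(y, return_counts=True)
            target = counts.max()
            X_parts, y_parts = [X], [y]
            for cls, count in zip(classes, counts):
                if count < target:
                    idx = np.where(y == cls)[0]
                    extra = np.random.choice(idx, target - count, replace=True)
                    X_parts.append(X[extra])
                    y_parts.append(y[extra])
            return np.concatenate(X_parts), np.concatenate(y_parts)
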
def _remove_techniques(args):
    """
    Remove custom techniques as specified in the args.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    # Get path to custom techniques directory
    custom_dir = Path.home() / ".balancr" / "custom_techniques"
    metadata_file = custom_dir / "techniques_metadata.json"

    # Check if metadata file exists
    if not metadata_file.exists():
        logging.error("No custom techniques have been registered.")
        return 1

    # Load metadata
    with open(metadata_file, "r") as f:
        metadata = json.load(f)

    # If no custom techniques
    if not metadata:
        logging.error("No custom techniques have been registered.")
        return 1

    # Remove all custom techniques
    if args.remove_all:
        logging.info("Removing all custom techniques...")

        # Remove all technique files
        file_paths = set(info["file"] for info in metadata.values())
        for file_path in file_paths:
            try:
                Path(file_path).unlink(missing_ok=True)
            except Exception as e:
                logging.warning(f"Error removing file {file_path}: {e}")

        # Clear metadata
        metadata = {}
        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)

        print("All custom techniques have been removed.")
        return 0

    # Remove specific techniques
    removed_techniques = []
    for technique_name in args.remove:
        if technique_name in metadata:
            # Note the file path (we'll check if it's used by other techniques)
            file_path = metadata[technique_name]["file"]

            # Remove from metadata
            del metadata[technique_name]
            removed_techniques.append(technique_name)

            # Check if the file is still used by other techniques
            file_still_used = any(
                info["file"] == file_path for info in metadata.values()
            )

            # If not used, remove the file
            if not file_still_used:
                try:
                    Path(file_path).unlink(missing_ok=True)
                except Exception as e:
                    logging.warning(f"Error removing file {file_path}: {e}")
        else:
            logging.warning(f"Technique '{technique_name}' not found.")

    # Save updated metadata
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)

    if removed_techniques:
        print("\nRemoved techniques:")
        for technique in removed_techniques:
            print(f"  - {technique}")
        return 0
    else:
        logging.error("No matching techniques were found.")
        return 1

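Both the registration and removal paths revolve around techniques_metadata.json, which maps each registered name to its copied source file. An illustrative entry (path and timestamp are examples):

    {
      "RandomDuplicator": {
        "file": "/home/user/.balancr/custom_techniques/my_sampler_8741203395.py",
        "class_name": "RandomDuplicator",
        "registered_at": "2025-03-01T14:22:05.123456"
      }
    }
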
def select_classifier(args):
    """
    Handle the select-classifier command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    # Check if we should list available classifiers
    if args.list_available:
        return list_available_classifiers(args)

    logging.info(f"Selecting classifiers: {', '.join(args.classifiers)}")

    # Create classifier registry
    if ClassifierRegistry is None:
        logging.error("Classifier registry not available. Please check installation.")
        return 1

    registry = ClassifierRegistry()

    # Get classifier configurations
    classifier_configs = {}

    for classifier_name in args.classifiers:
        # Get the classifier class
        classifier_class = registry.get_classifier_class(classifier_name)

        if classifier_class is None:
            logging.error(f"Classifier '{classifier_name}' not found.")
            logging.info(
                "Use 'balancr select-classifier --list-available' to see available classifiers."
            )
            continue

        # Get default parameters
        params = get_classifier_default_params(classifier_class)
        classifier_configs[classifier_name] = params

    # If no valid classifiers were found
    if not classifier_configs:
        logging.error("No valid classifiers selected.")
        return 1

    try:
        # Read existing config (we need this regardless of append mode)
        current_config = config.load_config(args.config_path)

        if args.append:
            # Append mode: update existing classifiers
            existing_classifiers = current_config.get("classifiers", {})
            existing_classifiers.update(classifier_configs)
            settings = {"classifiers": existing_classifiers}

            print(f"\nAdded classifiers: {', '.join(classifier_configs.keys())}")
            print(f"Total classifiers: {', '.join(existing_classifiers.keys())}")
        else:
            # Replace mode: create a completely new config entry with the
            # selected classifiers (a shallow copy of the current config is sufficient)
            new_config = dict(current_config)
            new_config["classifiers"] = classifier_configs

            # Write the entire config directly to replace the file
            config_path = Path(args.config_path)
            with open(config_path, "w") as f:
                json.dump(new_config, f, indent=2)

            print(
                f"\nReplaced classifiers with: {', '.join(classifier_configs.keys())}"
            )

            # Return early since we've manually written the config
            print("Default parameters have been added to the configuration file.")
            print("You can modify them by editing the configuration or using the CLI.")
            return 0

        # Only reach here in append mode
        config.update_config(args.config_path, settings)

        print("Default parameters have been added to the configuration file.")
        print("You can modify them by editing the configuration or using the CLI.")

        return 0
    except Exception as e:
        logging.error(f"Failed to select classifiers: {e}")
        return 1

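The "classifiers" entry written here pairs each name with its constructor defaults, as captured by get_classifier_default_params below. An abridged sketch using scikit-learn's actual RandomForestClassifier defaults:

    "classifiers": {
      "RandomForestClassifier": {
        "n_estimators": 100,
        "criterion": "gini",
        "max_depth": null,
        "random_state": null
      }
    }
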
def list_available_classifiers(args):
    """
    List all available classifiers.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    if ClassifierRegistry is None:
        logging.error("Classifier registry not available. Please check installation.")
        return 1

    registry = ClassifierRegistry()
    classifiers = registry.list_available_classifiers()

    print("\nAvailable Classifiers:")

    # Print custom classifiers if any
    if "custom" in classifiers and classifiers["custom"]:
        print("\nCustom Classifiers:")
        for module_name, clf_list in classifiers["custom"].items():
            if clf_list:
                print(f"\n  {module_name.capitalize()}:")
                for clf in sorted(clf_list):
                    print(f"    - {clf}")

    # Print sklearn classifiers by module
    if "sklearn" in classifiers:
        print("\nScikit-learn Classifiers:")
        for module_name, clf_list in classifiers["sklearn"].items():
            print(f"\n  {module_name.capitalize()}:")
            for clf in sorted(clf_list):
                print(f"    - {clf}")

    return 0

def get_classifier_default_params(classifier_class):
    """
    Extract default parameters from a classifier class.

    Args:
        classifier_class: The classifier class to inspect

    Returns:
        Dictionary of parameter names and their default values
    """
    params = {}

    try:
        # Get the signature of the __init__ method
        sig = inspect.signature(classifier_class.__init__)

        # Process each parameter
        for name, param in sig.parameters.items():
            # Skip 'self' parameter
            if name == "self":
                continue

            # Get default value if it exists
            if param.default is not inspect.Parameter.empty:
                # Handle special case for None (JSON uses null)
                if param.default is None:
                    params[name] = None
                # Handle other types that can be serialised to JSON
                elif isinstance(param.default, (int, float, str, bool, list, dict)):
                    params[name] = param.default
                else:
                    # Convert non-JSON-serialisable defaults to string representation
                    params[name] = str(param.default)
            else:
                # For parameters without defaults, use None
                params[name] = None

    except Exception as e:
        logging.warning(
            f"Error extracting parameters from {classifier_class.__name__}: {e}"
        )

    return params

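A quick check of the extraction logic (assuming scikit-learn is installed; the printed values are its actual defaults):

    from sklearn.ensemble import RandomForestClassifier

    params = get_classifier_default_params(RandomForestClassifier)
    print(params["n_estimators"])  # 100
    print(params["max_depth"])     # None, serialised as null in JSON
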
def register_classifiers(args):
    """
    Handle the register-classifiers command.

    This command allows users to register custom classifiers from
    Python files or folders, or remove existing custom classifiers.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    try:
        # Ensure framework is available
        if ClassifierRegistry is None:
            logging.error(
                "Classifier registry not available. Please check installation."
            )
            return 1

        # Handle removal operations
        if args.remove or args.remove_all:
            return _remove_classifiers(args)

        registry = ClassifierRegistry()

        # Track successfully registered classifiers
        registered_classifiers = []

        # Process file path
        if args.file_path:
            file_path = Path(args.file_path)

            if not file_path.exists():
                logging.error(f"File not found: {file_path}")
                return 1

            if not file_path.is_file() or file_path.suffix.lower() != ".py":
                logging.error(f"Not a Python file: {file_path}")
                return 1

            # Register classifiers from the file
            registered = _register_classifier_from_file(
                registry, file_path, args.name, args.class_name, args.overwrite
            )
            registered_classifiers.extend(registered)

        # Process folder path
        elif args.folder_path:
            folder_path = Path(args.folder_path)

            if not folder_path.exists():
                logging.error(f"Folder not found: {folder_path}")
                return 1

            if not folder_path.is_dir():
                logging.error(f"Not a directory: {folder_path}")
                return 1

            # Register classifiers from all Python files in the folder
            for py_file in folder_path.glob("*.py"):
                registered = _register_classifier_from_file(
                    registry,
                    py_file,
                    None,  # Don't use custom name for folder scanning
                    None,  # Don't use class name for folder scanning
                    args.overwrite,
                )
                registered_classifiers.extend(registered)

        # Print summary
        if registered_classifiers:
            print("\nSuccessfully registered classifiers:")
            for classifier in registered_classifiers:
                print(f"  - {classifier}")

            print("\nYou can now use these classifiers in comparisons. For example:")
            print(f"  balancr select-classifier {registered_classifiers[0]}")
            return 0
        else:
            logging.warning("No valid classifiers found to register.")
            return 1

    except Exception as e:
        logging.error(f"Error registering classifiers: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        return 1

def _register_classifier_from_file(
    registry, file_path, custom_name=None, class_name=None, overwrite=False
):
    """
    Register classifier classes from a Python file.

    Args:
        registry: ClassifierRegistry instance
        file_path: Path to the Python file
        custom_name: Custom name to register the classifier under
        class_name: Name of specific class to register
        overwrite: Whether to overwrite existing classifiers

    Returns:
        list: Names of successfully registered classifiers
    """
    registered_classifiers = []

    try:
        # Create custom classifiers directory if it doesn't exist
        custom_dir = Path.home() / ".balancr" / "custom_classifiers"
        custom_dir.mkdir(parents=True, exist_ok=True)

        # Import the module dynamically
        module_name = file_path.stem
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            logging.error(f"Could not load module from {file_path}")
            return registered_classifiers

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Ensure sklearn BaseEstimator is available
        try:
            from sklearn.base import BaseEstimator
        except ImportError:
            logging.error(
                "scikit-learn is not available. Please install it using: pip install scikit-learn"
            )
            return registered_classifiers

        # Find all classes that inherit from BaseEstimator and have fit/predict methods
        classifier_classes = []
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                obj.__module__ == module_name  # Only consider classes defined in this module
                and issubclass(obj, BaseEstimator)
                and hasattr(obj, "fit")
                and hasattr(obj, "predict")
            ):
                classifier_classes.append((name, obj))

        # If no valid classes found
        if not classifier_classes:
            logging.warning(f"No valid classifier classes found in {file_path}")
            logging.info(
                "Classes must inherit from sklearn.base.BaseEstimator and implement fit and predict methods"
            )
            return registered_classifiers

        # If class_name is specified, filter to just that class
        if class_name:
            classifier_classes = [
                (name, cls) for name, cls in classifier_classes if name == class_name
            ]
            if not classifier_classes:
                logging.error(
                    f"Class '{class_name}' not found in {file_path} or doesn't meet classifier requirements"
                )
                return registered_classifiers

        # If requesting a custom name but multiple classes found and no class_name specified
        if custom_name and len(classifier_classes) > 1 and not class_name:
            logging.error(
                f"Multiple classifier classes found in {file_path}, but custom name provided. "
                "Please specify which class to register with --class-name."
            )
            return registered_classifiers

        # Register the classifiers
        for name, cls in classifier_classes:
            # If this is a specifically requested class with a custom name
            if class_name and name == class_name and custom_name:
                register_name = custom_name
            else:
                register_name = name

            try:
                # Check if classifier already exists
                existing_classifiers = registry.list_available_classifiers()
                flat_existing = []
                for module_classifiers in existing_classifiers.get(
                    "custom", {}
                ).values():
                    flat_existing.extend(module_classifiers)

                if register_name in flat_existing and not overwrite:
                    logging.warning(
                        f"Classifier '{register_name}' already exists. "
                        "Use --overwrite to replace it."
                    )
                    continue

                # Register the classifier
                registry.register_custom_classifier(register_name, cls)
                registered_classifiers.append(register_name)
                logging.info(f"Successfully registered classifier: {register_name}")

            except Exception as e:
                logging.error(f"Error registering classifier '{register_name}': {e}")

        # For successfully registered classifiers, copy the file
        if registered_classifiers:
            # Generate a unique filename (in case multiple files have same name)
            dest_file = custom_dir / f"{file_path.stem}_{hash(str(file_path))}.py"
            shutil.copy2(file_path, dest_file)
            logging.debug(f"Copied {file_path} to {dest_file}")

            # Create a metadata file to map classifier names to files
            metadata_file = custom_dir / "classifiers_metadata.json"
            metadata = {}
            if metadata_file.exists():
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)

            # Update metadata with new classifiers
            class_mapping = {cls_name: cls_name for cls_name, _ in classifier_classes}
            if class_name and custom_name:
                class_mapping[custom_name] = class_name

            for classifier_name in registered_classifiers:
                original_class = class_mapping.get(classifier_name, classifier_name)
                metadata[classifier_name] = {
                    "file": str(dest_file),
                    "class_name": original_class,
                    "registered_at": datetime.now().isoformat(),
                }

            with open(metadata_file, "w") as f:
                json.dump(metadata, f, indent=2)

    except Exception as e:
        logging.error(f"Error processing file {file_path}: {e}")

    return registered_classifiers

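A minimal file that passes the checks above: the class is defined in the module itself, inherits from BaseEstimator, and implements fit and predict. An illustrative sketch, not part of the package:

    import numpy as np
    from sklearn.base import BaseEstimator, ClassifierMixin


    class MajorityClassClassifier(BaseEstimator, ClassifierMixin):
        """Always predicts the most frequent class seen during fit."""

        def fit(self, X, y):
            classes, counts = np.unique(y, return_counts=True)
            self.majority_ = classes[np.argmax(counts)]
            return self

        def predict(self, X):
            return np.full(len(X), self.majority_)
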
def _remove_classifiers(args):
    """
    Remove custom classifiers as specified in the args.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    # Get path to custom classifiers directory
    custom_dir = Path.home() / ".balancr" / "custom_classifiers"
    metadata_file = custom_dir / "classifiers_metadata.json"

    # Check if metadata file exists
    if not metadata_file.exists():
        logging.error("No custom classifiers have been registered.")
        return 1

    # Load metadata
    with open(metadata_file, "r") as f:
        metadata = json.load(f)

    # If no custom classifiers
    if not metadata:
        logging.error("No custom classifiers have been registered.")
        return 1

    # Remove all custom classifiers
    if args.remove_all:
        logging.info("Removing all custom classifiers...")

        # Remove all classifier files
        file_paths = set(info["file"] for info in metadata.values())
        for file_path in file_paths:
            try:
                Path(file_path).unlink(missing_ok=True)
            except Exception as e:
                logging.warning(f"Error removing file {file_path}: {e}")

        # Clear metadata
        metadata = {}
        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)

        print("All custom classifiers have been removed.")
        return 0

    # Remove specific classifiers
    removed_classifiers = []
    for classifier_name in args.remove:
        if classifier_name in metadata:
            # Note the file path (we'll check if it's used by other classifiers)
            file_path = metadata[classifier_name]["file"]

            # Remove from metadata
            del metadata[classifier_name]
            removed_classifiers.append(classifier_name)

            # Check if the file is still used by other classifiers
            file_still_used = any(
                info["file"] == file_path for info in metadata.values()
            )

            # If not used, remove the file
            if not file_still_used:
                try:
                    Path(file_path).unlink(missing_ok=True)
                except Exception as e:
                    logging.warning(f"Error removing file {file_path}: {e}")
        else:
            logging.warning(f"Classifier '{classifier_name}' not found.")

    # Save updated metadata
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)

    if removed_classifiers:
        print("\nRemoved classifiers:")
        for classifier in removed_classifiers:
            print(f"  - {classifier}")
        return 0
    else:
        logging.error("No matching classifiers were found.")
        return 1

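Of the metrics listed above, specificity and g_mean are the two without direct scikit-learn scorers. The usual binary-case definitions, which evaluation/metrics.py presumably implements in some form (a sketch, not the package's code):

    import numpy as np
    from sklearn.metrics import confusion_matrix

    y_true, y_pred = [0, 0, 1, 1], [0, 1, 1, 1]
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    specificity = tn / (tn + fp)             # true negative rate
    recall = tp / (tp + fn)                  # sensitivity
    g_mean = np.sqrt(recall * specificity)   # geometric mean of the two rates
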
def configure_metrics(args):
    """
    Handle the configure-metrics command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    # Define all available metrics
    all_metrics = [
        "accuracy",
        "precision",
        "recall",
        "f1",
        "roc_auc",
        "specificity",
        "g_mean",
        "average_precision",
    ]

    # If 'all' is specified, use all available metrics
    if "all" in args.metrics:
        metrics_to_use = all_metrics
        metrics_str = "all available metrics"
    else:
        metrics_to_use = args.metrics
        metrics_str = ", ".join(args.metrics)

    logging.info(f"Configuring metrics: {metrics_str}")

    # Update configuration with metrics settings
    settings = {
        "output": {"metrics": metrics_to_use, "save_metrics_formats": args.save_formats}
    }

    try:
        # Update existing output settings if they exist
        current_config = config.load_config(args.config_path)
        if "output" in current_config:
            current_output = current_config["output"]
            # Merge with existing output settings without overwriting other output options
            settings["output"] = {**current_output, **settings["output"]}

        config.update_config(args.config_path, settings)

        # Display confirmation
        print("\nMetrics Configuration:")
        print(f"  Metrics: {metrics_str}")
        print(f"  Save Formats: {', '.join(args.save_formats)}")

        return 0

    except Exception as e:
        logging.error(f"Failed to configure metrics: {e}")
        return 1

def configure_visualisations(args):
    """
    Handle the configure-visualisations command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    types_str = "all visualisations" if "all" in args.types else ", ".join(args.types)
    logging.info(f"Configuring visualisations: {types_str}")

    # Update configuration with visualisation settings
    settings = {
        "output": {
            "visualisations": args.types,
            "display_visualisations": args.display,
            "save_vis_formats": args.save_formats,
        }
    }

    try:
        # Update existing output settings if they exist
        current_config = config.load_config(args.config_path)
        if "output" in current_config:
            current_output = current_config["output"]
            # Merge with existing output settings without overwriting other output options
            settings["output"] = {**current_output, **settings["output"]}

        config.update_config(args.config_path, settings)

        # Display confirmation
        print("\nVisualisation Configuration:")
        print(f"  Types: {types_str}")
        print(f"  Display During Execution: {'Yes' if args.display else 'No'}")
        print(f"  Save Formats: {', '.join(args.save_formats)}")

        return 0

    except Exception as e:
        logging.error(f"Failed to configure visualisations: {e}")
        return 1

def configure_evaluation(args):
    """
    Handle the configure-evaluation command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    logging.info("Configuring evaluation settings")

    # Update configuration with evaluation settings
    settings = {
        "evaluation": {
            "test_size": args.test_size,
            "cross_validation": args.cross_validation,
            "random_state": args.random_state,
            "learning_curve_folds": args.learning_curve_folds,
            "learning_curve_points": args.learning_curve_points,
        }
    }

    try:
        config.update_config(args.config_path, settings)

        # Display confirmation
        print("\nEvaluation Configuration:")
        print(f"  Test Size: {args.test_size}")
        print(f"  Cross-Validation Folds: {args.cross_validation}")
        print(f"  Random State: {args.random_state}")
        print(f"  Learning Curve Folds: {args.learning_curve_folds}")
        print(f"  Learning Curve Points: {args.learning_curve_points}")

        return 0

    except Exception as e:
        logging.error(f"Failed to configure evaluation: {e}")
        return 1

def format_time(seconds):
    """Format time in seconds to minutes and seconds"""
    minutes = int(seconds // 60)
    remaining_seconds = seconds % 60
    return f"{minutes}mins, {remaining_seconds:.2f}secs"

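For example:

    format_time(125.5)  # -> "2mins, 5.50secs"
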
def run_comparison(args):
|
1319
|
+
"""
|
1320
|
+
Handle the run command.
|
1321
|
+
|
1322
|
+
Args:
|
1323
|
+
args: Command line arguments from argparse
|
1324
|
+
|
1325
|
+
Returns:
|
1326
|
+
int: Exit code
|
1327
|
+
"""
|
1328
|
+
start_time_total = time.time()
|
1329
|
+
|
1330
|
+
# Load current configuration
|
1331
|
+
try:
|
1332
|
+
current_config = config.load_config(args.config_path)
|
1333
|
+
except Exception as e:
|
1334
|
+
logging.error(f"Failed to load configuration: {e}")
|
1335
|
+
return 1
|
1336
|
+
|
1337
|
+
# Check if all required settings are configured
|
1338
|
+
required_settings = ["data_file", "target_column", "balancing_techniques"]
|
1339
|
+
missing_settings = [s for s in required_settings if s not in current_config]
|
1340
|
+
|
1341
|
+
if missing_settings:
|
1342
|
+
logging.error(f"Missing required configuration: {', '.join(missing_settings)}")
|
1343
|
+
logging.info("Please configure all required settings before running comparison")
|
1344
|
+
return 1
|
1345
|
+
|
1346
|
+
# Ensure balancing framework is available
|
1347
|
+
if BalancingFramework is None:
|
1348
|
+
logging.error("Balancing framework not available. Please check installation")
|
1349
|
+
return 1
|
1350
|
+
|
1351
|
+
# Prepare output directory
|
1352
|
+
output_dir = Path(args.output_dir)
|
1353
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
1354
|
+
|
1355
|
+
# Get output and evaluation settings with defaults
|
1356
|
+
output_config = current_config.get("output", {})
|
1357
|
+
metrics = output_config.get("metrics", ["precision", "recall", "f1", "roc_auc"])
|
1358
|
+
visualisations = output_config.get("visualisations", ["all"])
|
1359
|
+
display_visualisations = output_config.get("display_visualisations", False)
|
1360
|
+
save_metrics_formats = output_config.get("save_metrics_formats", ["csv"])
|
1361
|
+
save_vis_formats = output_config.get("save_vis_formats", ["png"])
|
1362
|
+
|
1363
|
+
eval_config = current_config.get("evaluation", {})
|
1364
|
+
test_size = eval_config.get("test_size", 0.2)
|
1365
|
+
cv_enabled = eval_config.get("cross_validation", 0) > 0
|
1366
|
+
cv_folds = eval_config.get("cross_validation", 5)
|
1367
|
+
random_state = eval_config.get("random_state", 42)
|
1368
|
+
include_original = current_config.get("include_original_data", False)
|
1369
|
+
|
1370
|
+
balancing_techniques = current_config.get("balancing_techniques", {})
|
1371
|
+
technique_names = list(balancing_techniques.keys())
|
1372
|
+
logging.info(f"Running comparison with techniques: {', '.join(technique_names)}")
|
1373
|
+
logging.info(f"Results will be saved to: {output_dir}")
|
1374
|
+
|
1375
|
+
try:
|
1376
|
+
# Initialise the framework
|
1377
|
+
framework = BalancingFramework()
|
1378
|
+
|
1379
|
+
# Load data
|
1380
|
+
start_time = time.time()
|
1381
|
+
logging.info(f"Loading data from {current_config['data_file']}")
|
1382
|
+
feature_columns = current_config.get("feature_columns", None)
|
1383
|
+
framework.load_data(
|
1384
|
+
current_config["data_file"],
|
1385
|
+
current_config["target_column"],
|
1386
|
+
feature_columns,
|
1387
|
+
)
|
1388
|
+
load_time = time.time() - start_time
|
1389
|
+
logging.info(f"Data loading completed (Time Taken: {format_time(load_time)})")
|
1390
|
+
|
1391
|
+
# Apply preprocessing if configured
|
1392
|
+
if "preprocessing" in current_config:
|
1393
|
+
logging.info("Applying preprocessing...")
|
1394
|
+
preproc = current_config["preprocessing"]
|
1395
|
+
|
1396
|
+
handle_missing = preproc.get("handle_missing", "mean")
|
1397
|
+
scale = preproc.get("scale", "standard")
|
1398
|
+
categorical_features = preproc.get("categorical_features", {})
|
1399
|
+
handle_constant_features = preproc.get("handle_constant_features", None)
|
1400
|
+
handle_correlations = preproc.get("handle_correlations", None)
|
1401
|
+
save_preprocessed_file = preproc.get("save_preprocessed", True)
|
1402
|
+
|
1403
|
+
# Extract hash components information
|
1404
|
+
hash_components_dict = {}
|
1405
|
+
for feature, encoding in categorical_features.items():
|
1406
|
+
if isinstance(encoding, list) and encoding[0] == "hash":
|
1407
|
+
hash_components_dict[feature] = encoding[
|
1408
|
+
1
|
1409
|
+
] # Store the n_components value
|
1410
|
+
|
1411
|
+
framework.preprocess_data(
|
1412
|
+
handle_missing=handle_missing,
|
1413
|
+
scale=scale,
|
1414
|
+
categorical_features=categorical_features,
|
1415
|
+
hash_components_dict=hash_components_dict,
|
1416
|
+
handle_constant_features=handle_constant_features,
|
1417
|
+
handle_correlations=handle_correlations,
|
1418
|
+
)
|
1419
|
+
logging.info("Data preprocessing applied")
|
1420
|
+
|
1421
|
+
# Save preprocessed dataset to a new file
|
1422
|
+
if save_preprocessed_file:
|
1423
|
+
try:
|
1424
|
+
if "data_file" in current_config and current_config["data_file"]:
|
1425
|
+
original_path = Path(current_config["data_file"])
|
1426
|
+
preprocessed_path = (
|
1427
|
+
original_path.parent
|
1428
|
+
/ f"{original_path.stem}_preprocessed{original_path.suffix}"
|
1429
|
+
)
|
1430
|
+
|
1431
|
+
# Copy DataFrame with the preprocessed data
|
1432
|
+
preprocessed_df = framework.X.copy()
|
1433
|
+
|
1434
|
+
# Add the target column
|
1435
|
+
target_column = current_config.get("target_column")
|
1436
|
+
if target_column:
|
1437
|
+
preprocessed_df[target_column] = framework.y
|
1438
|
+
|
1439
|
+
# Save to the new file
|
1440
|
+
if original_path.suffix.lower() == ".csv":
|
1441
|
+
preprocessed_df.to_csv(preprocessed_path, index=False)
|
1442
|
+
elif original_path.suffix.lower() in [".xlsx", ".xls"]:
|
1443
|
+
preprocessed_df.to_excel(preprocessed_path, index=False)
|
1444
|
+
|
1445
|
+
logging.info(
|
1446
|
+
f"Saved preprocessed dataset to: {preprocessed_path}"
|
1447
|
+
)
|
1448
|
+
except Exception as e:
|
1449
|
+
logging.warning(f"Could not save preprocessed dataset: {e}")
|
1450
|
+
|
1451
|
+
# Apply balancing techniques
|
1452
|
+
start_time = time.time()
|
1453
|
+
logging.info("Applying balancing techniques...")
|
1454
|
+
framework.apply_balancing_techniques(
|
1455
|
+
technique_names,
|
1456
|
+
test_size=test_size,
|
1457
|
+
random_state=random_state,
|
1458
|
+
technique_params=balancing_techniques,
|
1459
|
+
include_original=include_original,
|
1460
|
+
)
|
1461
|
+
balancing_time = time.time() - start_time
|
1462
|
+
logging.info(
|
1463
|
+
f"Balancing techniques applied successfully (Time Taken: {format_time(balancing_time)})"
|
1464
|
+
)
|
1465
|
+

        # Save balanced datasets at the root level
        balanced_dir = output_dir / "balanced_datasets"
        balanced_dir.mkdir(exist_ok=True)
        logging.info(f"Saving balanced datasets to {balanced_dir}")
        framework.generate_balanced_data(
            folder_path=str(balanced_dir),
            techniques=technique_names,
            file_format="csv",
        )
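        # One file per requested technique should now exist under
        # balanced_datasets/ (exact file naming is up to generate_balanced_data).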

        # Determine which visualisation types to generate
        if "all" in visualisations:
            vis_types_to_generate = ["metrics", "distribution", "learning_curves"]
        else:
            vis_types_to_generate = visualisations
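        # "radar" and "3d" are absent from the "all" expansion above; the
        # checks below also test `"all" in visualisations`, so both are still
        # produced when "all" is selected.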

        # Save class distribution visualisations at the root level
        for format_type in save_vis_formats:
            if format_type == "none":
                continue

            if "distribution" in vis_types_to_generate or "all" in visualisations:
                # Original (imbalanced) class distribution
                logging.info(
                    f"Generating imbalanced class distribution in {format_type} format..."
                )
                imbalanced_plot_path = (
                    output_dir / f"imbalanced_class_distribution.{format_type}"
                )
                framework.inspect_class_distribution(
                    save_path=str(imbalanced_plot_path), display=display_visualisations
                )
                logging.info(
                    f"Imbalanced class distribution saved to {imbalanced_plot_path}"
                )

                # Balanced class distributions comparison
                logging.info(
                    f"Generating balanced class distribution comparison in {format_type} format..."
                )
                balanced_plot_path = (
                    output_dir / f"balanced_class_distribution.{format_type}"
                )
                framework.compare_balanced_class_distributions(
                    save_path=str(balanced_plot_path),
                    display=display_visualisations,
                )
                logging.info(
                    f"Balanced class distribution comparison saved to {balanced_plot_path}"
                )

        # Train classifiers
        start_time = time.time()
        logging.info("Training classifiers on balanced datasets...")
        classifiers = current_config.get("classifiers", {})
        if not classifiers:
            logging.warning(
                "No classifiers configured. Using default RandomForestClassifier."
            )

        # Train classifiers with the balanced datasets
        results = framework.train_classifiers(
            classifier_configs=classifiers, enable_cv=cv_enabled, cv_folds=cv_folds
        )

        training_time = time.time() - start_time
        logging.info(
            f"Training classifiers complete (Time Taken: {format_time(training_time)})"
        )
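
        # `results` is keyed classifier -> technique -> metric group, i.e.
        # results[classifier_name][technique_name]["standard_metrics"] (plus
        # "cv_metrics" when cross-validation is enabled); the summary printing
        # at the end of this function relies on this shape.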

        # Process each classifier and save its results in a separate directory
        standard_start_time = time.time()
        for classifier_name in current_config.get("classifiers", {}):
            logging.info(f"Processing results for classifier: {classifier_name}")

            # Create classifier-specific directory
            classifier_dir = output_dir / classifier_name
            classifier_dir.mkdir(exist_ok=True)

            # Create standard metrics directory
            std_metrics_dir = classifier_dir / "standard_metrics"
            std_metrics_dir.mkdir(exist_ok=True)

            # Save standard metrics in requested formats
            for format_type in save_metrics_formats:
                if format_type == "none":
                    continue

                results_file = std_metrics_dir / f"comparison_results.{format_type}"
                logging.info(
                    f"Saving standard metrics for {classifier_name} to {results_file}"
                )

                # save_classifier_results extracts a single classifier's results
                framework.save_classifier_results(
                    results_file,
                    classifier_name=classifier_name,
                    metric_type="standard_metrics",
                    file_type=format_type,
                )

            # Generate and save standard metrics visualisations
            for format_type in save_vis_formats:
                if format_type == "none":
                    continue

                # Hoisted out of the "metrics" branch so the radar/3d branches
                # below can use it even when "metrics" was not selected
                metrics_to_plot = current_config.get("output", {}).get(
                    "metrics", ["precision", "recall", "f1", "roc_auc"]
                )

                if "metrics" in vis_types_to_generate or "all" in visualisations:
                    metrics_path = std_metrics_dir / f"metrics_comparison.{format_type}"
                    logging.info(
                        f"Generating metrics comparison for {classifier_name} in {format_type} format..."
                    )
                    # plot_comparison_results can plot a single classifier's data
                    plot_comparison_results(
                        results,
                        classifier_name=classifier_name,
                        metric_type="standard_metrics",
                        metrics_to_plot=metrics_to_plot,
                        save_path=str(metrics_path),
                        display=display_visualisations,
                    )

                if "radar" in vis_types_to_generate or "all" in visualisations:
                    std_radar_path = (
                        classifier_dir / f"standard_metrics_radar.{format_type}"
                    )
                    plot_radar_chart(
                        results,
                        classifier_name=classifier_name,
                        metric_type="standard_metrics",
                        metrics_to_plot=metrics_to_plot,
                        save_path=std_radar_path,
                        display=display_visualisations,
                    )

                if "3d" in vis_types_to_generate or "all" in visualisations:
                    std_3d_path = output_dir / "standard_metrics_3d.html"
                    plot_3d_scatter(
                        results,
                        metric_type="standard_metrics",
                        metrics_to_plot=metrics_to_plot,
                        save_path=std_3d_path,
                        display=display_visualisations,
                    )

        standard_total_time = time.time() - standard_start_time
        logging.info(
            f"Standard metrics evaluation total time: {format_time(standard_total_time)}"
        )
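
        # Per-classifier layout produced above:
        #   <output_dir>/<classifier>/standard_metrics/comparison_results.<fmt>
        #   <output_dir>/<classifier>/standard_metrics/metrics_comparison.<fmt>
        #   <output_dir>/<classifier>/standard_metrics_radar.<fmt>
        # plus a shared <output_dir>/standard_metrics_3d.html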

        # If cross-validation is enabled, create CV metrics directory and save results
        if cv_enabled:
            cv_start_time = time.time()
            for classifier_name in current_config.get("classifiers", {}):
                # Create classifier-specific directory
                classifier_dir = output_dir / classifier_name
                classifier_dir.mkdir(exist_ok=True)

                cv_metrics_dir = classifier_dir / "cv_metrics"
                cv_metrics_dir.mkdir(exist_ok=True)

                # Save CV metrics in requested formats
                for format_type in save_metrics_formats:
                    if format_type == "none":
                        continue

                    cv_results_file = (
                        cv_metrics_dir / f"comparison_results.{format_type}"
                    )
                    logging.info(
                        f"Saving CV metrics for {classifier_name} to {cv_results_file}"
                    )

                    framework.save_classifier_results(
                        cv_results_file,
                        classifier_name=classifier_name,
                        metric_type="cv_metrics",
                        file_type=format_type,
                    )

                # Generate and save CV metrics visualisations
                for format_type in save_vis_formats:
                    if format_type == "none":
                        continue

                    # Hoisted out of the "metrics" branch, as in the standard
                    # metrics loop above
                    metrics_to_plot = current_config.get("output", {}).get(
                        "metrics", ["precision", "recall", "f1", "roc_auc"]
                    )

                    if "metrics" in vis_types_to_generate or "all" in visualisations:
                        cv_metrics_path = (
                            cv_metrics_dir / f"metrics_comparison.{format_type}"
                        )
                        logging.info(
                            f"Generating CV metrics comparison for {classifier_name} in {format_type} format..."
                        )
                        plot_comparison_results(
                            results,
                            classifier_name=classifier_name,
                            metric_type="cv_metrics",
                            metrics_to_plot=metrics_to_plot,
                            save_path=str(cv_metrics_path),
                            display=display_visualisations,
                        )

                    if "radar" in vis_types_to_generate or "all" in visualisations:
                        cv_radar_path = (
                            classifier_dir / f"cv_metrics_radar.{format_type}"
                        )
                        plot_radar_chart(
                            results,
                            classifier_name=classifier_name,
                            metric_type="cv_metrics",
                            metrics_to_plot=metrics_to_plot,
                            save_path=cv_radar_path,
                            display=display_visualisations,
                        )

                    if "3d" in vis_types_to_generate or "all" in visualisations:
                        cv_3d_path = output_dir / "cv_metrics_3d.html"
                        plot_3d_scatter(
                            results,
                            metric_type="cv_metrics",
                            metrics_to_plot=metrics_to_plot,
                            save_path=cv_3d_path,
                            display=display_visualisations,
                        )

                    if (
                        "learning_curves" in vis_types_to_generate
                        or "all" in visualisations
                    ):
                        cv_learning_curve_path = (
                            cv_metrics_dir / f"learning_curves.{format_type}"
                        )

                        start_time = time.time()
                        logging.info(
                            f"Generating CV learning curves for {classifier_name} in {format_type} format..."
                        )

                        # Get learning curve parameters from config
                        learning_curve_points = eval_config.get(
                            "learning_curve_points", 10
                        )
                        learning_curve_folds = eval_config.get(
                            "learning_curve_folds", 5
                        )
                        train_sizes = np.linspace(0.1, 1.0, learning_curve_points)
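                        # With the defaults this is np.linspace(0.1, 1.0, 10),
                        # i.e. training-set fractions 10%, 20%, ..., 100%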

                        framework.generate_learning_curves(
                            classifier_name=classifier_name,
                            train_sizes=train_sizes,
                            n_folds=learning_curve_folds,
                            save_path=str(cv_learning_curve_path),
                            display=display_visualisations,
                        )
                        cv_learning_curves_time = time.time() - start_time
                        logging.info(
                            f"Successfully generated CV learning curves for {classifier_name} "
                            f"(Time Taken: {format_time(cv_learning_curves_time)})"
                        )

            cv_total_time = time.time() - cv_start_time
            logging.info(
                f"Cross validation metrics evaluation total time: {format_time(cv_total_time)}"
            )

        total_time = time.time() - start_time_total
        logging.info(f"Total execution time: {format_time(total_time)}")

        # Print summary of timing results
        print("\nExecution Time Summary:\n")
        print(f"  Data Loading: {format_time(load_time)}")
        print(f"  Balancing: {format_time(balancing_time)}")
        print(f"  Training Classifiers: {format_time(training_time)}")
        print(f"  Standard Metrics Evaluation: {format_time(standard_total_time)}")
        if cv_enabled:
            print(f"  CV Metrics Evaluation: {format_time(cv_total_time)}")
        print(f"  Total Time: {format_time(total_time)}")

        print("\nResults Summary:")

        # Check and print Standard Metrics if available
        has_standard_metrics = any(
            "standard_metrics" in technique_metrics
            and any(m in metrics for m in technique_metrics["standard_metrics"])
            for classifier_results in results.values()
            for technique_metrics in classifier_results.values()
        )

        if has_standard_metrics:
            print("\nStandard Metrics:")
            for classifier_name, classifier_results in results.items():
                print(f"\n{classifier_name}:")
                for technique_name, technique_metrics in classifier_results.items():
                    if "standard_metrics" in technique_metrics:
                        std_metrics = technique_metrics["standard_metrics"]
                        if any(m in metrics for m in std_metrics):
                            print(f"  {technique_name}:")
                            for metric_name, value in std_metrics.items():
                                if metric_name in metrics:
                                    print(f"    {metric_name}: {value:.4f}")

        # Check and print Cross Validation Metrics if available
        has_cv_metrics = any(
            "cv_metrics" in technique_metrics
            and any(
                metric_name.startswith("cv_")
                and metric_name[len("cv_"):].rsplit("_", 1)[0] in metrics
                for metric_name in technique_metrics["cv_metrics"]
            )
            for classifier_results in results.values()
            for technique_metrics in classifier_results.values()
        )
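
        # CV metric keys follow a "cv_<metric>_<suffix>" pattern (for example
        # "cv_f1_mean"; the exact suffix is the framework's choice, assumed
        # here), so stripping the "cv_" prefix and the final "_<suffix>"
        # recovers the base metric name checked against the `metrics` list.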
        if has_cv_metrics:
            print("\nCross Validation Metrics:")
            for classifier_name, classifier_results in results.items():
                print(f"\n{classifier_name}:")
                for technique_name, technique_metrics in classifier_results.items():
                    if "cv_metrics" in technique_metrics:
                        cv_metrics = technique_metrics["cv_metrics"]

                        # Check if any relevant cv metric exists for this technique
                        if any(
                            metric_name.startswith("cv_")
                            and metric_name[len("cv_"):].rsplit("_", 1)[0] in metrics
                            for metric_name in cv_metrics
                        ):
                            print(f"  {technique_name}:")

                            # Now print only relevant metrics
                            for metric_name, value in cv_metrics.items():
                                if metric_name.startswith("cv_"):
                                    base_name = metric_name[len("cv_"):].rsplit("_", 1)[0]
                                    if base_name in metrics:
                                        print(f"    {metric_name}: {value:.4f}")

        print(f"\nDetailed results saved to: {output_dir}")
        return 0

    except Exception as e:
        logging.error(f"Error during comparison: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        return 1


def reset_config(args):
    """
    Handle the reset command.

    Args:
        args: Command line arguments from argparse

    Returns:
        int: Exit code
    """
    try:
        config.initialise_config(args.config_path, force=True)
        logging.info("Configuration has been reset to defaults")
        return 0
    except Exception as e:
        logging.error(f"Failed to reset configuration: {e}")
        return 1
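
# reset_config is registered as a CLI subcommand in main.py; a hypothetical
# invocation (entry-point and subcommand names assumed) would be:
#   $ balancr reset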