balancr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- balancr/__init__.py +13 -0
- balancr/base.py +14 -0
- balancr/classifier_registry.py +300 -0
- balancr/cli/__init__.py +0 -0
- balancr/cli/commands.py +1838 -0
- balancr/cli/config.py +165 -0
- balancr/cli/main.py +778 -0
- balancr/cli/utils.py +101 -0
- balancr/data/__init__.py +5 -0
- balancr/data/loader.py +59 -0
- balancr/data/preprocessor.py +556 -0
- balancr/evaluation/__init__.py +19 -0
- balancr/evaluation/metrics.py +442 -0
- balancr/evaluation/visualisation.py +660 -0
- balancr/imbalance_analyser.py +677 -0
- balancr/technique_registry.py +284 -0
- balancr/techniques/__init__.py +4 -0
- balancr/techniques/custom/__init__.py +0 -0
- balancr/techniques/custom/example_custom_technique.py +27 -0
- balancr-0.1.0.dist-info/LICENSE +21 -0
- balancr-0.1.0.dist-info/METADATA +536 -0
- balancr-0.1.0.dist-info/RECORD +25 -0
- balancr-0.1.0.dist-info/WHEEL +5 -0
- balancr-0.1.0.dist-info/entry_points.txt +2 -0
- balancr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,660 @@
|
|
1
|
+
"""Visualisation utilities for imbalanced data analysis."""
|
2
|
+
|
3
|
+
from typing import Dict, Any, Optional
|
4
|
+
import numpy as np
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import seaborn as sns
|
7
|
+
import logging
|
8
|
+
import pandas as pd
|
9
|
+
|
10
|
+
|
11
|
+
def plot_class_distribution(
    distribution: Dict[Any, int],
    title: str = "Class Distribution",
    save_path: Optional[str] = None,
    display: bool = False,
) -> None:
    """Plot the number of samples per class as a bar chart.

    Each bar is annotated with that class's share of the total count.

    Args:
        distribution: Mapping of class label to sample count.
        title: Title drawn above the chart.
        save_path: If truthy, the figure is written there with ``plt.savefig``.
        display: If True, the figure is shown with ``plt.show``.

    Raises:
        TypeError: If ``distribution`` is None.
        ValueError: If ``distribution`` is empty or not a dictionary.
    """
    if distribution is None:
        raise TypeError("Distribution cannot be None")
    if not distribution or not isinstance(distribution, dict):
        raise ValueError("Distribution must be a non-empty dictionary")

    labels = list(distribution.keys())
    counts = list(distribution.values())
    total = sum(counts)

    plt.figure(figsize=(10, 6))
    try:
        sns.barplot(x=labels, y=counts)
        plt.title(title)
        plt.xlabel("Class")
        plt.ylabel("Count")

        # Add percentage labels on top of bars
        for position, count in enumerate(counts):
            # When every count is zero there is no meaningful share; report
            # 0.0% instead of dividing by zero.
            percentage = (count / total) * 100 if total else 0.0
            plt.text(
                position, count, f"{percentage:.1f}%", ha="center", va="bottom"
            )

        if save_path:
            plt.savefig(save_path)

        if display:
            plt.show()
    finally:
        # Always release the figure, even if saving or showing failed.
        plt.close()
|
42
|
+
|
43
|
+
|
44
|
+
def plot_class_distributions_comparison(
    distributions: Dict[str, Dict[Any, int]],
    title: str = "Class Distribution Comparison",
    save_path: Optional[str] = None,
    display: bool = False,
) -> None:
    """
    Compare class distributions across multiple techniques using bar plots.

    One group of bars is drawn per class, with one bar per technique, and
    each bar is annotated with its integer count.

    Args:
        distributions: Dictionary where keys are technique names and values are class distributions.
        title: Title for the plot.
        save_path: Path to save the plot (optional).
        display: Whether to show the plot with ``plt.show``.

    Raises:
        TypeError: If ``distributions`` is None.
        ValueError: If ``distributions`` is empty, is not a dictionary, or
            contains a value that is not a dictionary.

    Example input:
        {
            "SMOTE": {0: 500, 1: 500},
            "RandomUnderSampler": {0: 400, 1: 400},
        }
    """
    if distributions is None:
        raise TypeError("Distributions dictionary cannot be None")
    if not distributions or not isinstance(distributions, dict):
        raise ValueError("Distributions must be a non-empty dictionary")
    if not all(isinstance(d, dict) for d in distributions.values()):
        raise ValueError("Each distribution must be a dictionary")

    # Long-format rows (one per technique/class pair) for seaborn.
    # Class labels are stringified so they are plotted as categories.
    rows = [
        {"Technique": technique, "Class": str(cls), "Count": count}
        for technique, dist in distributions.items()
        for cls, count in dist.items()
    ]
    plot_data = pd.DataFrame(rows, columns=["Technique", "Class", "Count"])

    # Plot grouped bar chart
    plt.figure(figsize=(12, 8))
    try:
        ax = sns.barplot(x="Class", y="Count", hue="Technique", data=plot_data)

        # Values for each bar
        for patch in ax.patches:
            height = patch.get_height()
            # NOTE(review): depending on the seaborn version, ax.patches can
            # hold NaN-height bars (technique lacking a class) or zero-width
            # placeholder patches; int(NaN) would raise and placeholders would
            # produce stray "0" labels, so both are skipped.
            if np.isnan(height) or patch.get_width() == 0:
                continue
            ax.annotate(
                f"{int(height)}",
                (patch.get_x() + patch.get_width() / 2.0, height),
                ha="center",
                va="bottom",
                fontsize=10,
                color="black",
            )

        # Title and labels
        plt.title(title)
        plt.xlabel("Class")
        plt.ylabel("Count")
        plt.legend(title="Technique")
        plt.grid(True)

        if save_path:
            plt.savefig(save_path)

        if display:
            plt.show()
    finally:
        # Always release the figure, even if saving or showing failed.
        plt.close()
|
119
|
+
|
120
|
+
|
121
|
+
def plot_comparison_results(
    results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    classifier_name: str,
    metric_type: str = "standard_metrics",
    metrics_to_plot: Optional[list] = None,
    save_path: Optional[str] = None,
    display: bool = False,
) -> None:
    """
    Plot comparison of different techniques for a specific classifier and metric type.

    One bar-chart subplot is drawn per metric (at most three per row), with
    one bar per balancing technique.

    Args:
        results: Dictionary containing nested results structure
            {classifier_name: {technique_name: {metric_type: {metric_name: value}}}}
        classifier_name: Name of the classifier to visualise
        metric_type: Type of metrics to plot ('standard_metrics' or 'cv_metrics')
        metrics_to_plot: List of metrics to include in plot. For 'cv_metrics',
            names without a 'cv_' prefix are converted to 'cv_<name>_mean'.
        save_path: Path to save the plot
        display: Whether to display the plot

    Raises:
        TypeError: If ``results`` is None.
        ValueError: If the classifier is missing from ``results``, no
            technique has data for ``metric_type``, or none of the requested
            metrics is present for every technique.
    """
    if results is None:
        raise TypeError("Results dictionary cannot be None")

    if classifier_name not in results:
        raise ValueError(f"Classifier '{classifier_name}' not found in results")

    # Keep only techniques that actually have data for this metric type:
    # {technique_name: {metric_name: value}}
    plot_data = {
        technique_name: technique_data[metric_type]
        for technique_name, technique_data in results[classifier_name].items()
        if metric_type in technique_data
    }

    if not plot_data:
        raise ValueError(
            f"No {metric_type} data found for classifier '{classifier_name}'"
        )

    if metrics_to_plot is None:
        # Default metrics to plot
        if metric_type == "cv_metrics":
            metrics_to_plot = [
                "cv_accuracy_mean",
                "cv_precision_mean",
                "cv_recall_mean",
                "cv_f1_mean",
            ]
        else:
            metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
    elif metric_type == "cv_metrics":
        # Convert standard metric names to CV names, but leave names that are
        # already in CV form untouched so a mixed list does not turn
        # 'cv_f1_mean' into 'cv_cv_f1_mean_mean'.
        metrics_to_plot = [
            m if m.startswith("cv_") else f"cv_{m}_mean" for m in metrics_to_plot
        ]

    # Filter metrics to only include those that exist in all techniques
    common_metrics = set.intersection(
        *[set(metrics.keys()) for metrics in plot_data.values()]
    )
    available_metrics = [m for m in metrics_to_plot if m in common_metrics]

    if not available_metrics:
        # Show all available metrics in the error message
        raise ValueError(
            f"No common metrics found across techniques for metric type '{metric_type}'. "
            f"Requested metrics: {metrics_to_plot}, Available metrics: {sorted(common_metrics)}"
        )

    # Convert results to suitable format for plotting
    techniques = list(plot_data.keys())
    metrics_data = {
        metric: [plot_data[tech][metric] for tech in techniques]
        for metric in available_metrics
    }

    # Create subplot grid that accommodates all metrics
    n_metrics = len(available_metrics)
    n_cols = min(3, n_metrics)  # Maximum 3 columns to ensure readability
    n_rows = (n_metrics + n_cols - 1) // n_cols  # Ceiling division

    # squeeze=False keeps `axes` two-dimensional even for a single subplot.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    try:
        fig.suptitle(
            f"{classifier_name} - Comparison of Balancing Techniques ({metric_type.replace('_', ' ').title()})",
            size=16,
        )

        # Plot each metric
        for idx, (metric, values) in enumerate(metrics_data.items()):
            row, col = divmod(idx, n_cols)
            ax = axes[row, col]

            sns.barplot(x=techniques, y=values, ax=ax)

            if metric.startswith("cv_") and metric.endswith("_mean"):
                # For CV metrics, show "Metric Mean" format
                base_metric = metric[3:-5]  # Remove 'cv_' prefix and '_mean' suffix
                display_title = f'{base_metric.replace("_", " ").title()} Mean'
            else:
                display_title = metric.replace("_", " ").title()

            ax.set_title(display_title)
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

            # Add value labels on top of bars
            for position, value in enumerate(values):
                ax.text(position, value, f"{value:.3f}", ha="center", va="bottom")

        # Remove any empty subplots left in the last grid row
        for idx in range(len(metrics_data), axes.shape[0] * axes.shape[1]):
            row, col = divmod(idx, n_cols)
            fig.delaxes(axes[row, col])

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)

        if display:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
255
|
+
|
256
|
+
|
257
|
+
def plot_learning_curves(
    learning_curve_data: dict,
    title: str = "Learning Curves",
    save_path: Optional[str] = None,
    display: bool = False,
) -> None:
    """
    Plot learning curves for multiple techniques in subplots.

    One subplot is drawn per technique, stacked vertically. Each shows the
    mean training and validation score with a shaded band of one standard
    deviation either side.

    Args:
        learning_curve_data: Dictionary with technique names as keys and corresponding learning curve data as values.
            Each value must be a dictionary with 'train_sizes', 'train_scores'
            and 'val_scores'. Scores are averaged along axis 1.
            # NOTE(review): presumably (n_sizes, n_folds) arrays — confirm with caller.
        title: Title of the plot, drawn as the figure-level title
        save_path: Optional path to save the figure
        display: Whether to display the plot

    Raises:
        TypeError: If ``learning_curve_data`` is None.
        ValueError: If ``learning_curve_data`` is empty or not a dictionary,
            or if any technique's entry is not a dictionary holding the
            three required keys.
    """
    # Error handling
    if learning_curve_data is None:
        raise TypeError("Learning curve data cannot be None")
    if not learning_curve_data or not isinstance(learning_curve_data, dict):
        raise ValueError("Learning curve data must be a non-empty dictionary")

    required_keys = ("train_sizes", "train_scores", "val_scores")
    for technique, data in learning_curve_data.items():
        # A non-dict entry (e.g. None) is reported with the same ValueError
        # instead of failing inside the membership test.
        if not isinstance(data, dict) or not all(
            key in data for key in required_keys
        ):
            raise ValueError(
                f"Learning curve data for technique '{technique}' must contain "
                f"'train_sizes', 'train_scores', and 'val_scores'"
            )

    num_techniques = len(learning_curve_data)
    # One row per technique; squeeze=False keeps `axes` indexable as a 2-D
    # array even when there is only a single technique.
    fig, axes = plt.subplots(
        num_techniques, 1, figsize=(10, 6 * num_techniques), squeeze=False
    )
    try:
        for ax, (technique_name, data) in zip(
            axes[:, 0], learning_curve_data.items()
        ):
            train_sizes = data["train_sizes"]

            # Calculate mean and std of scores
            train_mean = np.mean(data["train_scores"], axis=1)
            train_std = np.std(data["train_scores"], axis=1)
            val_mean = np.mean(data["val_scores"], axis=1)
            val_std = np.std(data["val_scores"], axis=1)

            ax.plot(train_sizes, train_mean, label="Training score", color="blue")
            ax.plot(train_sizes, val_mean, label="Validation score", color="red")
            ax.fill_between(
                train_sizes,
                train_mean - train_std,
                train_mean + train_std,
                alpha=0.1,
                color="blue",
            )
            ax.fill_between(
                train_sizes,
                val_mean - val_std,
                val_mean + val_std,
                alpha=0.1,
                color="red",
            )

            ax.set_title(f"{technique_name} - Learning Curves")
            ax.set_xlabel("Training Examples")
            ax.set_ylabel("Score")
            ax.legend(loc="best")
            ax.grid(True)

        # The `title` argument was previously accepted but never drawn.
        fig.suptitle(title)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)

        if display:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
335
|
+
|
336
|
+
|
337
|
+
def plot_radar_chart(
    results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    classifier_name: str,
    metric_type: str = "standard_metrics",
    metrics_to_plot: Optional[list] = None,
    save_path: Optional[str] = None,
    display: bool = False,
) -> None:
    """
    Plot radar (spider) chart comparing balancing techniques for a specific classifier.

    Each metric gets one spoke and each technique is drawn as a closed,
    lightly filled polygon.

    Args:
        results: Dictionary containing nested results structure
            {classifier_name: {technique_name: {metric_type: {metric_name: value}}}}
        classifier_name: Name of the classifier to visualise
        metric_type: Type of metrics to plot ('standard_metrics' or 'cv_metrics')
        metrics_to_plot: List of metrics to include in the radar chart. For
            'cv_metrics', names without a 'cv_' prefix are converted to
            'cv_<name>_mean'.
        save_path: Path to save the plot
        display: Whether to display the plot

    Raises:
        TypeError: If ``results`` is None.
        ValueError: If the classifier is missing from ``results``, no
            technique has data for ``metric_type``, or none of the requested
            metrics is present for every technique.
    """
    if results is None:
        raise TypeError("Results dictionary cannot be None")

    if classifier_name not in results:
        raise ValueError(f"Classifier '{classifier_name}' not found in results")

    # Keep only techniques that actually have data for this metric type:
    # {technique_name: {metric_name: value}}
    plot_data = {
        technique_name: technique_data[metric_type]
        for technique_name, technique_data in results[classifier_name].items()
        if metric_type in technique_data
    }

    if not plot_data:
        raise ValueError(
            f"No {metric_type} data found for classifier '{classifier_name}'"
        )

    if metrics_to_plot is None:
        # Default metrics to plot
        if metric_type == "cv_metrics":
            # For CV metrics, we're interested in mean values, not std
            metrics_to_plot = [
                "cv_accuracy_mean",
                "cv_precision_mean",
                "cv_recall_mean",
                "cv_f1_mean",
            ]
        else:
            metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
    elif metric_type == "cv_metrics":
        # Convert standard metric names to CV names, but leave names that are
        # already in CV form untouched so a mixed list does not turn
        # 'cv_f1_mean' into 'cv_cv_f1_mean_mean'.
        metrics_to_plot = [
            m if m.startswith("cv_") else f"cv_{m}_mean" for m in metrics_to_plot
        ]

    # Filter metrics to only include those that exist in all techniques
    common_metrics = set.intersection(
        *[set(metrics.keys()) for metrics in plot_data.values()]
    )
    available_metrics = [m for m in metrics_to_plot if m in common_metrics]

    if not available_metrics:
        # Show all available metrics in the error message
        raise ValueError(
            f"No common metrics found across techniques for metric type '{metric_type}'. "
            f"Requested metrics: {metrics_to_plot}, Available metrics: {sorted(common_metrics)}"
        )

    # Use the matplotlib.colormaps registry where it exists rather than the
    # deprecated plt.cm.get_cmap; fall back to plt.get_cmap otherwise.
    import matplotlib as mpl

    if hasattr(mpl, "colormaps"):
        cmap = mpl.colormaps["tab10"]
    else:  # Fallback for older versions
        cmap = plt.get_cmap("tab10")

    # Create the figure and polar axis
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
    try:
        # Calculate angles for each metric (evenly spaced around the circle)
        angles = np.linspace(
            0, 2 * np.pi, len(available_metrics), endpoint=False
        ).tolist()

        # Make the plot circular by appending the start point at the end
        angles += angles[:1]

        # Build human-readable spoke labels
        metric_labels = []
        for metric in available_metrics:
            if (
                metric_type == "cv_metrics"
                and metric.startswith("cv_")
                and metric.endswith("_mean")
            ):
                # For CV metrics, show only the base metric name
                base_metric = metric[3:-5]  # Remove 'cv_' prefix and '_mean' suffix
                display_label = f'{base_metric.replace("_", " ").title()}'
            else:
                display_label = metric.replace("_", " ").title()
            metric_labels.append(display_label)

        # Set radar chart labels at appropriate angles (closing angle excluded)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(metric_labels)

        # Plot each technique
        for i, (technique_name, metrics) in enumerate(plot_data.items()):
            # Extract values for the metrics we want to plot
            values = [metrics[metric] for metric in available_metrics]

            # Make values circular
            values += values[:1]

            # Wrap the index so techniques beyond the colormap's size cycle
            # through the palette instead of all receiving the same colour.
            color = cmap(i % cmap.N)
            ax.plot(
                angles, values, "o-", linewidth=2, label=technique_name, color=color
            )
            ax.fill(angles, values, alpha=0.1, color=color)

        # Set title
        title_type = (
            "Cross-Validation Metrics"
            if metric_type == "cv_metrics"
            else "Standard Metrics"
        )
        title = f"{classifier_name} - {title_type}"
        ax.set_title(title, size=15)

        # Add a legend, placed outside the plot area
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0))

        # Adjust layout
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)

        if display:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
486
|
+
|
487
|
+
|
488
|
+
def plot_3d_scatter(
    results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
    metric_type: str = "standard_metrics",
    metrics_to_plot: Optional[list] = None,
    save_path: Optional[str] = None,
    display: bool = False,
) -> Optional[Any]:
    """
    Create a 3D scatter plot showing the relationship between F1-score, ROC-AUC, and G-mean
    for various classifier and balancing technique combinations.

    Args:
        results: Dictionary containing nested results structure
            {classifier_name: {technique_name: {metric_type: {metric_name: value}}}}
        metric_type: Type of metrics to plot. 'standard_metrics' reads the
            'f1', 'roc_auc' and 'g_mean' keys; any other value reads the
            'cv_f1_mean', 'cv_roc_auc_mean' and 'cv_g_mean_mean' keys.
        metrics_to_plot: List of metrics that were chosen to be plotted. If
            given and it lacks any of 'f1', 'roc_auc', 'g_mean', nothing is
            plotted.
        save_path: Path to save the plot; '.html' is appended when missing
        display: Whether to display the plot

    Returns:
        The Plotly figure, or None when Plotly is not installed, a required
        metric was not selected, or no valid data point exists.
    """
    try:
        import plotly.graph_objects as go
    except ImportError:
        logging.error(
            "Plotly is required for 3D scatter plots. Install with: pip install plotly"
        )
        return None

    # Determine metric keys based on the metric_type
    if metric_type == "standard_metrics":
        f1_key = "f1"
        roc_auc_key = "roc_auc"
        g_mean_key = "g_mean"
        title_prefix = "Standard Metrics"
    else:  # cv_metrics
        f1_key = "cv_f1_mean"
        roc_auc_key = "cv_roc_auc_mean"
        g_mean_key = "cv_g_mean_mean"
        title_prefix = "Cross-Validation Metrics"
    metric_keys = (f1_key, roc_auc_key, g_mean_key)

    # Check if required metrics are in metrics_to_plot
    required_metrics = ["f1", "roc_auc", "g_mean"]
    if metrics_to_plot is not None:
        missing_metrics = [m for m in required_metrics if m not in metrics_to_plot]
        if missing_metrics:
            logging.warning(
                "Cannot create 3D scatter plot. Required metrics %s are not in metrics_to_plot. "
                "Please include these metrics using 'configure-metrics'.",
                missing_metrics,
            )
            return None

    # Skip if no classifiers
    if not list(results.keys()):
        logging.warning("No classifiers found in results.")
        return None

    # Accept NumPy scalars (e.g. np.float32) as well as built-in numbers.
    numeric_types = (int, float, np.integer, np.floating)

    # Build the data structure for plotting: one row per
    # classifier/technique pair that has all three metrics.
    plot_data = []
    for classifier_name, classifier_results in results.items():
        for technique_name, technique_metrics in classifier_results.items():
            if metric_type not in technique_metrics:
                continue
            metrics = technique_metrics[metric_type]

            missing = [k for k in metric_keys if k not in metrics]
            if missing:
                logging.warning(
                    "Cannot plot data point for %s/%s. Missing metrics: %s",
                    classifier_name,
                    technique_name,
                    missing,
                )
                continue

            f1, roc_auc, g_mean = (metrics[k] for k in metric_keys)

            # Only add valid data points; say so instead of dropping silently
            if not all(isinstance(v, numeric_types) for v in (f1, roc_auc, g_mean)):
                logging.warning(
                    "Skipping data point for %s/%s: metric values are not numeric.",
                    classifier_name,
                    technique_name,
                )
                continue

            plot_data.append(
                {
                    "classifier": classifier_name,
                    "technique": technique_name,
                    "f1": f1,
                    "roc_auc": roc_auc,
                    "g_mean": g_mean,
                }
            )

    # Skip if no valid data points
    if not plot_data:
        logging.warning("No valid data points found for 3D scatter plot.")
        return None

    # Convert to pandas DataFrame for easier processing
    df = pd.DataFrame(plot_data)

    # Create interactive 3D scatter plot
    fig = go.Figure()

    # Add one trace per classifier
    for classifier_name in df["classifier"].unique():
        subset = df[df["classifier"] == classifier_name]
        hover_text = [
            f"Classifier: {row['classifier']}<br>Technique: {row['technique']}<br>"
            f"F1: {row['f1']:.4f}<br>ROC-AUC: {row['roc_auc']:.4f}<br>G-mean: {row['g_mean']:.4f}"
            for _, row in subset.iterrows()
        ]

        fig.add_trace(
            go.Scatter3d(
                x=subset["roc_auc"],  # X-axis: ROC-AUC
                y=subset["g_mean"],  # Y-axis: G-mean
                z=subset["f1"],  # Z-axis: F1-score
                mode="markers",
                marker=dict(
                    size=10,
                    opacity=0.8,
                ),
                name=classifier_name,
                text=hover_text,
                hoverinfo="text",
            )
        )

    # Update layout; all three axes are fixed to the [0, 1] range
    fig.update_layout(
        title=f"{title_prefix}: F1-score vs ROC-AUC vs G-mean",
        scene=dict(
            xaxis_title="ROC-AUC",
            yaxis_title="G-mean",
            zaxis_title="F1-score",
            xaxis=dict(range=[0, 1]),
            yaxis=dict(range=[0, 1]),
            zaxis=dict(range=[0, 1]),
        ),
        margin=dict(l=0, r=0, b=0, t=50),
        legend=dict(
            x=0.01,
            y=0.99,
            traceorder="normal",
            font=dict(size=12),
        ),
        autosize=True,
        height=700,
    )

    # Save or display the figure
    if save_path:
        # Convert Path object to string if needed
        save_path_str = str(save_path)

        # Ensure .html extension for interactive plot
        if not save_path_str.endswith(".html"):
            save_path_str += ".html"

        # Save as interactive HTML
        fig.write_html(save_path_str)
        logging.info("3D scatter plot saved to %s", save_path_str)

    if display:
        fig.show()

    return fig  # Return the figure object for potential further customisation
|