balancr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,660 @@
+ """Visualisation utilities for imbalanced data analysis."""
+
+ import logging
+ from typing import Dict, Any, Optional
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+
+
+ def plot_class_distribution(
+     distribution: Dict[Any, int],
+     title: str = "Class Distribution",
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> None:
+     """Plot the distribution of classes in the dataset."""
+     if distribution is None:
+         raise TypeError("Distribution cannot be None")
+     if not distribution or not isinstance(distribution, dict):
+         raise ValueError("Distribution must be a non-empty dictionary")
+
+     plt.figure(figsize=(10, 6))
+     sns.barplot(x=list(distribution.keys()), y=list(distribution.values()))
+     plt.title(title)
+     plt.xlabel("Class")
+     plt.ylabel("Count")
+
+     # Add percentage labels on top of bars
+     total = sum(distribution.values())
+     for i, count in enumerate(distribution.values()):
+         percentage = (count / total) * 100
+         plt.text(i, count, f"{percentage:.1f}%", ha="center", va="bottom")
+
+     if save_path:
+         plt.savefig(save_path)
+
+     if display:
+         plt.show()
+
+     plt.close()
+
+
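A minimal usage sketch (editorial, not part of the packaged file), assuming plot_class_distribution is imported from this module; the label vector and counts are purely illustrative:

    from collections import Counter

    y = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]  # toy label vector with a 7:3 imbalance
    distribution = dict(Counter(y))     # {0: 7, 1: 3}

    # Write the bar chart to disk without opening a window
    plot_class_distribution(distribution, title="Original Class Distribution",
                            save_path="class_distribution.png")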
+ def plot_class_distributions_comparison(
+     distributions: Dict[str, Dict[Any, int]],
+     title: str = "Class Distribution Comparison",
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> None:
+     """
+     Compare class distributions across multiple techniques using bar plots.
+
+     Args:
+         distributions: Dictionary where keys are technique names and values are class distributions.
+         title: Title for the plot.
+         save_path: Path to save the plot (optional).
+         display: Whether to display the plot.
+
+     Example input:
+         {
+             "SMOTE": {0: 500, 1: 500},
+             "RandomUnderSampler": {0: 400, 1: 400},
+         }
+     """
+     if distributions is None:
+         raise TypeError("Distributions dictionary cannot be None")
+     if not distributions or not isinstance(distributions, dict):
+         raise ValueError("Distributions must be a non-empty dictionary")
+     if not all(isinstance(d, dict) for d in distributions.values()):
+         raise ValueError("Each distribution must be a dictionary")
+
+     # Prepare data for visualisation
+     techniques = []
+     classes = []
+     counts = []
+
+     # Process each technique
+     for technique, dist in distributions.items():
+         for cls, count in dist.items():
+             techniques.append(technique)
+             classes.append(str(cls))  # Convert class label to string for plotting
+             counts.append(count)
+
+     # Create DataFrame for seaborn plotting (pandas is imported at module level)
+     plot_data = pd.DataFrame(
+         {"Technique": techniques, "Class": classes, "Count": counts}
+     )
+
+     # Plot grouped bar chart
+     plt.figure(figsize=(12, 8))
+     ax = sns.barplot(x="Class", y="Count", hue="Technique", data=plot_data)
+
+     # Add count labels on top of each bar
+     for p in ax.patches:
+         ax.annotate(
+             f"{int(p.get_height())}",
+             (p.get_x() + p.get_width() / 2.0, p.get_height()),
+             ha="center",
+             va="bottom",
+             fontsize=10,
+             color="black",
+         )
+
+     # Title and labels
+     plt.title(title)
+     plt.xlabel("Class")
+     plt.ylabel("Count")
+     plt.legend(title="Technique")
+     plt.grid(True)
+
+     if save_path:
+         plt.savefig(save_path)
+
+     if display:
+         plt.show()
+
+     plt.close()
+
+
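A hedged usage sketch for the comparison plot above; the technique names and counts mirror the docstring's example input and are illustrative only:

    distributions = {
        "Original": {0: 900, 1: 100},
        "SMOTE": {0: 900, 1: 900},
        "RandomUnderSampler": {0: 100, 1: 100},
    }

    plot_class_distributions_comparison(
        distributions,
        title="Class Distribution Before and After Balancing",
        save_path="distribution_comparison.png",
    )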
+ def plot_comparison_results(
+     results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
+     classifier_name: str,
+     metric_type: str = "standard_metrics",
+     metrics_to_plot: Optional[list] = None,
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> None:
+     """
+     Plot comparison of different techniques for a specific classifier and metric type.
+
+     Args:
+         results: Dictionary containing nested results structure
+         classifier_name: Name of the classifier to visualise
+         metric_type: Type of metrics to plot ('standard_metrics' or 'cv_metrics')
+         metrics_to_plot: List of metrics to include in plot
+         save_path: Path to save the plot
+         display: Whether to display the plot
+     """
+     if results is None:
+         raise TypeError("Results dictionary cannot be None")
+
+     if classifier_name not in results:
+         raise ValueError(f"Classifier '{classifier_name}' not found in results")
+
+     # Extract the classifier's results
+     classifier_results = results[classifier_name]
+
+     # Create a structure for plotting with techniques as keys and metric dictionaries as values
+     plot_data = {}
+     for technique_name, technique_data in classifier_results.items():
+         if metric_type in technique_data:
+             plot_data[technique_name] = technique_data[metric_type]
+
+     if not plot_data:
+         raise ValueError(
+             f"No {metric_type} data found for classifier '{classifier_name}'"
+         )
+
+     # Default metrics to plot
+     if metrics_to_plot is None:
+         if metric_type == "standard_metrics":
+             metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
+         elif metric_type == "cv_metrics":
+             # For CV metrics, look for keys with the "cv_" prefix and "_mean" suffix
+             # while keeping the same base metric names as configured
+             metrics_to_plot = [
+                 "cv_accuracy_mean",
+                 "cv_precision_mean",
+                 "cv_recall_mean",
+                 "cv_f1_mean",
+             ]
+         else:
+             metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
+     elif (
+         metric_type == "cv_metrics"
+         and metrics_to_plot
+         and not all(m.startswith("cv_") for m in metrics_to_plot)
+     ):
+         # If the user provided standard metric names but we're plotting CV metrics,
+         # convert them to CV metric names (leaving already-prefixed names untouched)
+         metrics_to_plot = [
+             m if m.startswith("cv_") else f"cv_{m}_mean" for m in metrics_to_plot
+         ]
+
+     # Filter metrics to only include those that exist in all techniques
+     common_metrics = set.intersection(
+         *[set(metrics.keys()) for metrics in plot_data.values()]
+     )
+     available_metrics = [m for m in metrics_to_plot if m in common_metrics]
+
+     if not available_metrics:
+         # Show all available metrics in the error message
+         raise ValueError(
+             f"No common metrics found across techniques for metric type '{metric_type}'. "
+             f"Requested metrics: {metrics_to_plot}, Available metrics: {sorted(common_metrics)}"
+         )
+
+     # Convert results to a format suitable for plotting
+     techniques = list(plot_data.keys())
+     metrics_data = {
+         metric: [plot_data[tech][metric] for tech in techniques]
+         for metric in available_metrics
+     }
+
+     # Create a subplot grid that accommodates all metrics
+     n_metrics = len(available_metrics)
+     n_cols = min(3, n_metrics)  # Maximum 3 columns to ensure readability
+     n_rows = (n_metrics + n_cols - 1) // n_cols  # Ceiling division
+
+     fig, axes = plt.subplots(
+         n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
+     )
+     fig.suptitle(
+         f"{classifier_name} - Comparison of Balancing Techniques ({metric_type.replace('_', ' ').title()})",
+         size=16,
+     )
+
+     # Plot each metric
+     for idx, (metric, values) in enumerate(metrics_data.items()):
+         row = idx // n_cols
+         col = idx % n_cols
+         ax = axes[row, col]
+
+         sns.barplot(x=techniques, y=values, ax=ax)
+
+         # Set an appropriate title based on metric type
+         if metric.startswith("cv_") and metric.endswith("_mean"):
+             # For CV metrics, show "Metric Mean" format
+             base_metric = metric[3:-5]  # Remove 'cv_' prefix and '_mean' suffix
+             display_title = f'{base_metric.replace("_", " ").title()} Mean'
+         else:
+             display_title = metric.replace("_", " ").title()
+
+         ax.set_title(display_title)
+         plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
+
+         # Add value labels on top of bars
+         for i, v in enumerate(values):
+             ax.text(i, v, f"{v:.3f}", ha="center", va="bottom")
+
+     # Remove any empty subplots
+     for idx in range(len(metrics_data), axes.shape[0] * axes.shape[1]):
+         row = idx // n_cols
+         col = idx % n_cols
+         fig.delaxes(axes[row, col])
+
+     plt.tight_layout()
+
+     if save_path:
+         plt.savefig(save_path)
+
+     if display:
+         plt.show()
+
+     plt.close()
+
+
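The nested results structure is easiest to see with a concrete example. This sketch assumes plot_comparison_results is in scope; every metric value is a placeholder:

    results = {
        "RandomForestClassifier": {
            "SMOTE": {
                "standard_metrics": {"precision": 0.82, "recall": 0.79, "f1": 0.80, "roc_auc": 0.88},
            },
            "RandomUnderSampler": {
                "standard_metrics": {"precision": 0.75, "recall": 0.84, "f1": 0.79, "roc_auc": 0.86},
            },
        }
    }

    plot_comparison_results(
        results,
        classifier_name="RandomForestClassifier",
        metric_type="standard_metrics",
        save_path="technique_comparison.png",
    )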
+ def plot_learning_curves(
+     learning_curve_data: dict,
+     title: str = "Learning Curves",
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> None:
+     """
+     Plot learning curves for multiple techniques in subplots.
+
+     Args:
+         learning_curve_data: Dictionary with technique names as keys and corresponding learning curve data as values
+         title: Title of the plot
+         save_path: Optional path to save the figure
+         display: Whether to display the plot
+     """
+     # Error handling
+     if learning_curve_data is None:
+         raise TypeError("Learning curve data cannot be None")
+     if not learning_curve_data or not isinstance(learning_curve_data, dict):
+         raise ValueError("Learning curve data must be a non-empty dictionary")
+
+     for technique, data in learning_curve_data.items():
+         required_keys = {"train_sizes", "train_scores", "val_scores"}
+         if not all(key in data for key in required_keys):
+             raise ValueError(
+                 f"Learning curve data for technique '{technique}' must contain "
+                 f"'train_sizes', 'train_scores', and 'val_scores'"
+             )
+
+     num_techniques = len(learning_curve_data)
+     # Set up a grid of subplots, one for each technique
+     fig, axes = plt.subplots(num_techniques, 1, figsize=(10, 6 * num_techniques))
+     fig.suptitle(title)
+
+     # Ensure axes is iterable even if there's only one technique
+     if num_techniques == 1:
+         axes = [axes]
+
+     for idx, (technique_name, data) in enumerate(learning_curve_data.items()):
+         # Extract the train_sizes, train_scores, and val_scores from the dictionary
+         train_sizes = data["train_sizes"]
+         train_scores = data["train_scores"]
+         val_scores = data["val_scores"]
+
+         # Calculate mean and std of scores
+         train_mean = np.mean(train_scores, axis=1)
+         train_std = np.std(train_scores, axis=1)
+         val_mean = np.mean(val_scores, axis=1)
+         val_std = np.std(val_scores, axis=1)
+
+         ax = axes[idx]
+
+         ax.plot(train_sizes, train_mean, label="Training score", color="blue")
+         ax.plot(train_sizes, val_mean, label="Validation score", color="red")
+         ax.fill_between(
+             train_sizes,
+             train_mean - train_std,
+             train_mean + train_std,
+             alpha=0.1,
+             color="blue",
+         )
+         ax.fill_between(
+             train_sizes, val_mean - val_std, val_mean + val_std, alpha=0.1, color="red"
+         )
+
+         ax.set_title(f"{technique_name} - Learning Curves")
+         ax.set_xlabel("Training Examples")
+         ax.set_ylabel("Score")
+         ax.legend(loc="best")
+         ax.grid(True)
+
+     plt.tight_layout()
+
+     if save_path:
+         plt.savefig(save_path)
+
+     if display:
+         plt.show()
+
+     plt.close()
+
+
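A sketch of how the expected learning_curve_data dictionary might be assembled with scikit-learn's learning_curve; the dataset, estimator, and technique label are assumptions for illustration:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import learning_curve

    X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=42)

    train_sizes, train_scores, val_scores = learning_curve(
        LogisticRegression(max_iter=1000), X, y, cv=5, scoring="f1"
    )

    learning_curve_data = {
        "SMOTE + LogisticRegression": {
            "train_sizes": train_sizes,
            "train_scores": train_scores,
            "val_scores": val_scores,
        }
    }

    plot_learning_curves(learning_curve_data, save_path="learning_curves.png")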
+ def plot_radar_chart(
+     results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
+     classifier_name: str,
+     metric_type: str = "standard_metrics",
+     metrics_to_plot: Optional[list] = None,
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> None:
+     """
+     Plot radar (spider) chart comparing balancing techniques for a specific classifier.
+
+     Args:
+         results: Dictionary containing nested results structure
+             {classifier_name: {technique_name: {metric_type: {metric_name: value}}}}
+         classifier_name: Name of the classifier to visualise
+         metric_type: Type of metrics to plot ('standard_metrics' or 'cv_metrics')
+         metrics_to_plot: List of metrics to include in the radar chart
+         save_path: Path to save the plot
+         display: Whether to display the plot
+     """
+     if results is None:
+         raise TypeError("Results dictionary cannot be None")
+
+     if classifier_name not in results:
+         raise ValueError(f"Classifier '{classifier_name}' not found in results")
+
+     # Extract the classifier's results
+     classifier_results = results[classifier_name]
+
+     # Create a structure for plotting with techniques as keys and metric dictionaries as values
+     plot_data = {}
+     for technique_name, technique_data in classifier_results.items():
+         if metric_type in technique_data:
+             plot_data[technique_name] = technique_data[metric_type]
+
+     if not plot_data:
+         raise ValueError(
+             f"No {metric_type} data found for classifier '{classifier_name}'"
+         )
+
+     # Default metrics to plot
+     if metrics_to_plot is None:
+         if metric_type == "standard_metrics":
+             metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
+         elif metric_type == "cv_metrics":
+             # For CV metrics, we're interested in mean values, not std
+             metrics_to_plot = [
+                 "cv_accuracy_mean",
+                 "cv_precision_mean",
+                 "cv_recall_mean",
+                 "cv_f1_mean",
+             ]
+         else:
+             metrics_to_plot = ["precision", "recall", "f1", "roc_auc"]
+     elif (
+         metric_type == "cv_metrics"
+         and metrics_to_plot
+         and not all(m.startswith("cv_") for m in metrics_to_plot)
+     ):
+         # If the user provided standard metric names but we're plotting CV metrics,
+         # convert them to CV metric names with the _mean suffix (leaving
+         # already-prefixed names untouched)
+         metrics_to_plot = [
+             m if m.startswith("cv_") else f"cv_{m}_mean" for m in metrics_to_plot
+         ]
+
+     # Filter metrics to only include those that exist in all techniques
+     common_metrics = set.intersection(
+         *[set(metrics.keys()) for metrics in plot_data.values()]
+     )
+     available_metrics = [m for m in metrics_to_plot if m in common_metrics]
+
+     if not available_metrics:
+         # Show all available metrics in the error message
+         raise ValueError(
+             f"No common metrics found across techniques for metric type '{metric_type}'. "
+             f"Requested metrics: {metrics_to_plot}, Available metrics: {sorted(common_metrics)}"
+         )
+
+     # Create the figure and polar axis
+     fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
+
+     # Calculate angles for each metric (evenly spaced around the circle)
+     angles = np.linspace(0, 2 * np.pi, len(available_metrics), endpoint=False).tolist()
+
+     # Make the plot circular by appending the start point at the end
+     angles += angles[:1]
+
+     # Add metric labels
+     metric_labels = []
+     for metric in available_metrics:
+         # Format the metric name for display
+         if (
+             metric_type == "cv_metrics"
+             and metric.startswith("cv_")
+             and metric.endswith("_mean")
+         ):
+             # For CV metrics, strip the 'cv_' prefix and '_mean' suffix
+             base_metric = metric[3:-5]
+             display_label = base_metric.replace("_", " ").title()
+         else:
+             display_label = metric.replace("_", " ").title()
+         metric_labels.append(display_label)
+
+     # Set radar chart labels at appropriate angles
+     ax.set_xticks(angles[:-1])
+     ax.set_xticklabels(metric_labels)
+
+     # Get color map for different techniques
+     # Use matplotlib.colormaps instead of the deprecated plt.cm.get_cmap
+     import matplotlib as mpl
+
+     if hasattr(mpl, "colormaps"):  # Matplotlib 3.7+
+         cmap = mpl.colormaps["tab10"]
+     else:  # Fallback for older versions
+         cmap = plt.get_cmap("tab10")
+
+     # Plot each technique
+     for i, (technique_name, metrics) in enumerate(plot_data.items()):
+         # Extract values for the metrics we want to plot
+         values = [metrics[metric] for metric in available_metrics]
+
+         # Make values circular
+         values += values[:1]
+
+         # Plot the technique
+         color = cmap(i)
+         ax.plot(angles, values, "o-", linewidth=2, label=technique_name, color=color)
+         ax.fill(angles, values, alpha=0.1, color=color)
+
+     # Set title
+     title_type = (
+         "Cross-Validation Metrics"
+         if metric_type == "cv_metrics"
+         else "Standard Metrics"
+     )
+     title = f"{classifier_name} - {title_type}"
+     ax.set_title(title, size=15)
+
+     # Add a legend
+     ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0))
+
+     # Adjust layout
+     plt.tight_layout()
+
+     if save_path:
+         plt.savefig(save_path)
+
+     if display:
+         plt.show()
+
+     plt.close()
+
+
+ def plot_3d_scatter(
+     results: Dict[str, Dict[str, Dict[str, Dict[str, float]]]],
+     metric_type: str = "standard_metrics",
+     metrics_to_plot: Optional[list] = None,
+     save_path: Optional[str] = None,
+     display: bool = False,
+ ) -> Optional[Any]:
+     """
+     Create a 3D scatter plot showing the relationship between F1-score, ROC-AUC, and G-mean
+     for various classifier and balancing technique combinations.
+
+     Args:
+         results: Dictionary containing nested results structure
+         metric_type: Type of metrics to plot ('standard_metrics' or 'cv_metrics')
+         metrics_to_plot: List of metrics that were chosen to be plotted
+         save_path: Path to save the plot
+         display: Whether to display the plot
+
+     Returns:
+         The Plotly figure object, or None if the plot could not be created.
+     """
+     try:
+         import plotly.graph_objects as go
+     except ImportError:
+         logging.error(
+             "Plotly is required for 3D scatter plots. Install with: pip install plotly"
+         )
+         return None
+
+     # Determine metric keys based on the metric_type
+     if metric_type == "standard_metrics":
+         f1_key = "f1"
+         roc_auc_key = "roc_auc"
+         g_mean_key = "g_mean"
+         title_prefix = "Standard Metrics"
+     else:  # cv_metrics
+         f1_key = "cv_f1_mean"
+         roc_auc_key = "cv_roc_auc_mean"
+         g_mean_key = "cv_g_mean_mean"
+         title_prefix = "Cross-Validation Metrics"
+
+     # Check if required metrics are in metrics_to_plot
+     required_metrics = ["f1", "roc_auc", "g_mean"]
+     if metrics_to_plot is not None:
+         missing_metrics = [m for m in required_metrics if m not in metrics_to_plot]
+         if missing_metrics:
+             logging.warning(
+                 f"Cannot create 3D scatter plot. Required metrics {missing_metrics} are not in metrics_to_plot. "
+                 "Please include these metrics using 'configure-metrics'."
+             )
+             return None
+
+     # Prepare data for plotting
+     plot_data = []
+
+     # Assign a unique color to each classifier
+     classifiers = list(results.keys())
+
+     # Skip if no classifiers
+     if not classifiers:
+         logging.warning("No classifiers found in results.")
+         return None
+
+     # Build the data structure for plotting
+     for classifier_name in classifiers:
+         classifier_results = results[classifier_name]
+
+         for technique_name, technique_metrics in classifier_results.items():
+             if metric_type in technique_metrics:
+                 metrics = technique_metrics[metric_type]
+
+                 # Check if all required metrics are available
+                 if all(key in metrics for key in [f1_key, roc_auc_key, g_mean_key]):
+                     f1 = metrics[f1_key]
+                     roc_auc = metrics[roc_auc_key]
+                     g_mean = metrics[g_mean_key]
+
+                     # Only add valid data points
+                     if (
+                         isinstance(f1, (int, float))
+                         and isinstance(roc_auc, (int, float))
+                         and isinstance(g_mean, (int, float))
+                     ):
+                         plot_data.append(
+                             {
+                                 "classifier": classifier_name,
+                                 "technique": technique_name,
+                                 "f1": f1,
+                                 "roc_auc": roc_auc,
+                                 "g_mean": g_mean,
+                             }
+                         )
+                 else:
+                     missing = [
+                         k for k in [f1_key, roc_auc_key, g_mean_key] if k not in metrics
+                     ]
+                     logging.warning(
+                         f"Cannot plot data point for {classifier_name}/{technique_name}. "
+                         f"Missing metrics: {missing}"
+                     )
+
+     # Skip if no valid data points
+     if not plot_data:
+         logging.warning("No valid data points found for 3D scatter plot.")
+         return None
+
+     # Convert to pandas DataFrame for easier processing
+     df = pd.DataFrame(plot_data)
+
+     # Create interactive 3D scatter plot
+     fig = go.Figure()
+
+     # Add traces for each classifier
+     for classifier_name in df["classifier"].unique():
+         subset = df[df["classifier"] == classifier_name]
+
+         fig.add_trace(
+             go.Scatter3d(
+                 x=subset["roc_auc"],  # X-axis: ROC-AUC
+                 y=subset["g_mean"],  # Y-axis: G-mean
+                 z=subset["f1"],  # Z-axis: F1-score
+                 mode="markers",
+                 marker=dict(
+                     size=10,
+                     opacity=0.8,
+                 ),
+                 name=classifier_name,
+                 text=[
+                     f"Classifier: {row['classifier']}<br>Technique: {row['technique']}<br>"
+                     f"F1: {row['f1']:.4f}<br>ROC-AUC: {row['roc_auc']:.4f}<br>G-mean: {row['g_mean']:.4f}"
+                     for _, row in subset.iterrows()
+                 ],
+                 hoverinfo="text",
+             )
+         )
+
+     # Update layout
+     fig.update_layout(
+         title=f"{title_prefix}: F1-score vs ROC-AUC vs G-mean",
+         scene=dict(
+             xaxis_title="ROC-AUC",
+             yaxis_title="G-mean",
+             zaxis_title="F1-score",
+             xaxis=dict(range=[0, 1]),
+             yaxis=dict(range=[0, 1]),
+             zaxis=dict(range=[0, 1]),
+         ),
+         margin=dict(l=0, r=0, b=0, t=50),
+         legend=dict(
+             x=0.01,
+             y=0.99,
+             traceorder="normal",
+             font=dict(size=12),
+         ),
+         autosize=True,
+         height=700,
+     )
+
+     # Save or display the figure
+     if save_path:
+         # Convert Path object to string if needed
+         save_path_str = str(save_path)
+
+         # Ensure .html extension for interactive plot
+         if not save_path_str.endswith(".html"):
+             save_path_str += ".html"
+
+         # Save as interactive HTML
+         fig.write_html(save_path_str)
+         logging.info(f"3D scatter plot saved to {save_path_str}")
+
+     if display:
+         fig.show()
+
+     return fig  # Return the figure object for potential further customisation
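Finally, a usage sketch for the Plotly-based 3D scatter plot; the results dictionary below is illustrative only, and g_mean is included because the function requires it alongside f1 and roc_auc:

    results = {
        "RandomForestClassifier": {
            "SMOTE": {"standard_metrics": {"f1": 0.80, "roc_auc": 0.88, "g_mean": 0.83}},
            "RandomUnderSampler": {"standard_metrics": {"f1": 0.79, "roc_auc": 0.86, "g_mean": 0.84}},
        }
    }

    # Writes an interactive HTML file and returns the Plotly figure (or None on failure)
    fig = plot_3d_scatter(results, metric_type="standard_metrics",
                          save_path="metrics_3d_scatter")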