dragon-ml-toolbox 5.3.1__py3-none-any.whl → 6.0.1__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

@@ -0,0 +1,639 @@
+ import pandas as pd
+ import numpy as np
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import Colormap
+ from matplotlib import rcdefaults
+ import shap
+ import xgboost as xgb
+ import lightgbm as lgb
+ from sklearn.model_selection import learning_curve
+ from sklearn.calibration import CalibrationDisplay
+ from sklearn.metrics import (accuracy_score,
+                              classification_report,
+                              ConfusionMatrixDisplay,
+                              mean_absolute_error,
+                              mean_squared_error,
+                              r2_score,
+                              roc_curve,
+                              roc_auc_score,
+                              precision_recall_curve,
+                              average_precision_score)
+ from pathlib import Path
+ from typing import Union, Optional, Literal
+
+ from .path_manager import sanitize_filename, make_fullpath
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+
+
+ __all__ = [
+     "evaluate_model_classification",
+     "plot_roc_curve",
+     "plot_precision_recall_curve",
+     "plot_calibration_curve",
+     "evaluate_model_regression",
+     "get_shap_values",
+     "plot_learning_curves",
+ ]
+
+
+ # function to evaluate the model and save metrics (Classification)
+ def evaluate_model_classification(
+     model,
+     model_name: str,
+     save_dir: Union[str,Path],
+     x_test_scaled: np.ndarray,
+     single_y_test: np.ndarray,
+     target_name: str,
+     figsize: tuple = (10, 8),
+     base_fontsize: int = 24,
+     cmap: Colormap = plt.cm.Blues, # type: ignore
+     heatmap_cmap: str = "viridis"
+ ) -> np.ndarray:
+     """
+     Evaluates a classification model, then saves the classification report (text and heatmap) and the confusion matrix plot.
+
+     Parameters:
+         model: Trained classifier with a .predict() method
+         model_name: Identifier for the model
+         save_dir: Directory where results are saved
+         x_test_scaled: Feature matrix for the test set
+         single_y_test: True targets
+         target_name: Target name
+         figsize: Size of the confusion matrix figure (width, height)
+         base_fontsize: Font size used for the title, axis labels and ticks
+         heatmap_cmap: Colormap for the classification report heatmap.
+         cmap: Colormap for the confusion matrix. Examples include:
+             - plt.cm.Blues (default)
+             - plt.cm.Greens
+             - plt.cm.Oranges
+             - plt.cm.Purples
+             - plt.cm.Reds
+             - plt.cm.cividis
+             - plt.cm.inferno
+
+     Returns:
+         y_pred: Predicted class labels
+     """
+     save_path = make_fullpath(save_dir, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+
+     y_pred = model.predict(x_test_scaled)
+     accuracy = accuracy_score(single_y_test, y_pred)
+
+     # Generate report as dictionary for the heatmap
+     report_dict = classification_report(
+         single_y_test,
+         y_pred,
+         target_names=["Negative", "Positive"],
+         output_dict=True
+     )
+
+     # text report to save
+     report_text = classification_report(
+         single_y_test,
+         y_pred,
+         target_names=["Negative", "Positive"],
+         output_dict=False
+     )
+
+     # Save text report
+     report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
+     with open(report_path, "w") as f:
+         f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
+         f.write("Classification Report:\n")
+         f.write(report_text) # type: ignore
+
+     # Create and save the classification report heatmap
+     try:
+         report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
+         plt.figure(figsize=figsize)
+         sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
+                     annot_kws={"size": base_fontsize - 4})
+         plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+         plt.xticks(fontsize=base_fontsize - 2)
+         plt.yticks(fontsize=base_fontsize - 2)
+
+         heatmap_path = save_path / f"Classification_Report_{sanitized_target_name}.svg"
+         plt.savefig(heatmap_path, format="svg", bbox_inches="tight")
+         plt.close()
+     except Exception as e:
+         _LOGGER.error(f"❌ Could not generate classification report heatmap for {target_name}: {e}")
+
+     # Create confusion matrix
+     fig, ax = plt.subplots(figsize=figsize)
+     disp = ConfusionMatrixDisplay.from_predictions(
+         y_true=single_y_test,
+         y_pred=y_pred,
+         display_labels=["Negative", "Positive"],
+         cmap=cmap,
+         normalize="true",
+         ax=ax
+     )
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.set_xlabel("Predicted label", fontsize=base_fontsize)
+     ax.set_ylabel("True label", fontsize=base_fontsize)
+
+     # Turn off gridlines
+     ax.grid(False)
+
+     # Manually update font size of cell texts
+     for text in ax.texts:
+         text.set_fontsize(base_fontsize + 4)
+
+     fig.tight_layout()
+     fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
+     fig.savefig(fig_path, format="svg", bbox_inches="tight") # type: ignore
+     plt.close(fig)
+
+     return y_pred
+
+ # Function to save ROC and ROC AUC (Classification)
+ def plot_roc_curve(
+     true_labels: np.ndarray,
+     probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str,Path],
+     color: str = "darkorange",
+     figure_size: tuple = (10, 10),
+     linewidth: int = 2,
+     base_fontsize: int = 24,
+     input_features: Optional[np.ndarray] = None,
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the ROC curve and computes the AUC for binary classification. The positive class is assumed to be in the second column of the probabilities array.
+
+     Parameters:
+         true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+         probabilities_or_model: Either predicted probabilities (ndarray) or a trained model with a `.predict_proba()` method.
+         model_name: Identifier for the model.
+         target_name: str, target name.
+         save_directory: str or Path, directory where the figure is saved.
+         color: Color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
+             - Named colors: "darkorange", "blue", "red", "green", "black"
+             - Hex codes: "#1f77b4", "#ff7f0e"
+             - RGB tuples: (0.2, 0.4, 0.6)
+             - Colormap value: plt.cm.viridis(0.6)
+         figure_size: Tuple for figure size (width, height).
+         linewidth: int, width of the plotted ROC line.
+         base_fontsize: int, base font size for the title, axis labels, ticks, and legend.
+         input_features: np.ndarray of shape (n_samples, n_features), required if a model is passed.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+
+     # Determine predicted probabilities
+     if isinstance(probabilities_or_model, np.ndarray):
+         # Input is already probabilities
+         if probabilities_or_model.ndim == 2: # type: ignore
+             y_score = probabilities_or_model[:, 1] # type: ignore
+         else:
+             y_score = probabilities_or_model
+
+     elif hasattr(probabilities_or_model, "predict_proba"):
+         if input_features is None:
+             raise ValueError("input_features must be provided when using a classifier.")
+
+         try:
+             classes = probabilities_or_model.classes_ # type: ignore
+             positive_class_index = list(classes).index(1)
+         except (AttributeError, ValueError):
+             positive_class_index = 1
+
+         y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+
+     else:
+         raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+     # ROC and AUC
+     fpr, tpr, _ = roc_curve(true_labels, y_score)
+     auc_score = roc_auc_score(true_labels, y_score)
+
+     # Plot
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
+     ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+     ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.legend(loc="lower right", fontsize=base_fontsize)
+     ax.grid(True)
+
+     # Save figure
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+
+     return fig
+
+
+ # Precision-Recall curve (Classification)
+ def plot_precision_recall_curve(
+     true_labels: np.ndarray,
+     probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str, Path],
+     color: str = "teal",
+     figure_size: tuple = (10, 10),
+     linewidth: int = 2,
+     base_fontsize: int = 24,
+     input_features: Optional[np.ndarray] = None,
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the Precision-Recall curve and computes Average Precision (AP) for binary classification.
+
+     Parameters:
+         true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+         probabilities_or_model: Either predicted probabilities (ndarray) or a trained model with a `.predict_proba()` method.
+         model_name: Identifier for the model.
+         target_name: Name of the target variable.
+         save_directory: Path to the directory where the figure will be saved.
+         color: str, color of the PR curve.
+         figure_size: Tuple for figure size (width, height).
+         linewidth: int, width of the plotted PR line.
+         base_fontsize: int, base font size for titles and labels.
+         input_features: np.ndarray, required if a model object is passed instead of probabilities.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+     # Determine predicted probabilities for the positive class
+     if isinstance(probabilities_or_model, np.ndarray):
+         if probabilities_or_model.ndim == 2:
+             y_score = probabilities_or_model[:, 1]
+         else:
+             y_score = probabilities_or_model
+
+     elif hasattr(probabilities_or_model, "predict_proba"):
+         if input_features is None:
+             raise ValueError("input_features must be provided when using a classifier.")
+         try:
+             classes = probabilities_or_model.classes_ # type: ignore
+             positive_class_index = list(classes).index(1)
+         except (AttributeError, ValueError):
+             positive_class_index = 1
+         y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+     else:
+         raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+     # Calculate PR curve and AP score
+     precision, recall, _ = precision_recall_curve(true_labels, y_score)
+     ap_score = average_precision_score(true_labels, y_score)
+
+     # Plot
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.plot(recall, precision, color=color, lw=linewidth, label=f"AP = {ap_score:.2f}")
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("Recall", fontsize=base_fontsize)
+     ax.set_ylabel("Precision", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.legend(loc="lower left", fontsize=base_fontsize)
+     ax.grid(True)
+     fig.tight_layout()
+
+     # Save figure
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"PR_Curve_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+     plt.close(fig)
+
+     return fig
+
+
+ # Calibration curve (classification)
+ def plot_calibration_curve(
+     model,
+     model_name: str,
+     save_dir: Union[str, Path],
+     x_test: np.ndarray,
+     y_test: np.ndarray,
+     target_name: str,
+     figure_size: tuple = (10, 10),
+     base_fontsize: int = 24,
+     n_bins: int = 15
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the calibration curve (reliability diagram) for a classifier.
+
+     Parameters:
+         model: Trained classifier with a .predict_proba() method.
+         model_name: Identifier for the model.
+         save_dir: Directory where the plot will be saved.
+         x_test: Feature matrix for the test set.
+         y_test: True labels for the test set.
+         target_name: Name of the target variable.
+         figure_size: Tuple for figure size (width, height).
+         base_fontsize: Base font size for titles and labels.
+         n_bins: Number of bins to discretize predictions into.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+     fig, ax = plt.subplots(figsize=figure_size)
+
+     disp = CalibrationDisplay.from_estimator(
+         model,
+         x_test,
+         y_test,
+         n_bins=n_bins,
+         ax=ax
+     )
+
+     ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize - 2)
+     ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
+     ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
+     ax.legend(fontsize=base_fontsize - 4)
+     fig.tight_layout()
+
+     # Save figure
+     save_path = make_fullpath(save_dir, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+     plt.close(fig)
+
+     return fig
+
+
+ # function to evaluate the model and save metrics (Regression)
+ def evaluate_model_regression(model, model_name: str,
+                               save_dir: Union[str,Path],
+                               x_test_scaled: np.ndarray, single_y_test: np.ndarray,
+                               target_name: str,
+                               figure_size: tuple = (12, 8),
+                               alpha_transparency: float = 0.5,
+                               base_fontsize: int = 24,
+                               hist_bins: int = 30):
+     # Generate predictions
+     y_pred = model.predict(x_test_scaled)
+
+     # Calculate regression metrics
+     mae = mean_absolute_error(single_y_test, y_pred)
+     mse = mean_squared_error(single_y_test, y_pred)
+     rmse = np.sqrt(mse)
+     r2 = r2_score(single_y_test, y_pred)
+
+     # Create formatted report
+     sanitized_target_name = sanitize_filename(target_name)
+     save_path = make_fullpath(save_dir, make=True)
+     report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
+     with open(report_path, "w") as f:
+         f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
+         f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
+         f.write(f"Mean Squared Error (MSE): {mse:.4f}\n")
+         f.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}\n")
+         f.write(f"R² Score: {r2:.4f}\n")
+
+     # Generate and save residual plot
+     residuals = single_y_test - y_pred
+
+     plt.figure(figsize=figure_size)
+     plt.scatter(y_pred, residuals, alpha=alpha_transparency)
+     plt.axhline(0, color='red', linestyle='--')
+     plt.xlabel("Predicted Values", fontsize=base_fontsize)
+     plt.ylabel("Residuals", fontsize=base_fontsize)
+     plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plt.tight_layout()
+     residual_path = save_path / f"Residuals_Plot_{sanitized_target_name}.svg"
+     plt.savefig(residual_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     # Create true vs predicted values plot
+     plt.figure(figsize=figure_size)
+     plt.scatter(single_y_test, y_pred, alpha=alpha_transparency)
+     plt.plot([single_y_test.min(), single_y_test.max()],
+              [single_y_test.min(), single_y_test.max()],
+              'k--', lw=2)
+     plt.xlabel('True Values', fontsize=base_fontsize)
+     plt.ylabel('Predictions', fontsize=base_fontsize)
+     plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plot_path = save_path / f"True_Vs_Predict_Plot_{sanitized_target_name}.svg"
+     plt.savefig(plot_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     # Generate and save histogram of residuals
+     plt.figure(figsize=figure_size)
+     sns.histplot(residuals, bins=hist_bins, kde=True)
+     plt.xlabel("Residual Value", fontsize=base_fontsize)
+     plt.ylabel("Frequency", fontsize=base_fontsize)
+     plt.title(f"{model_name} - Distribution of Residuals for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plt.tight_layout()
+     hist_path = save_path / f"Residuals_Distribution_{sanitized_target_name}.svg"
+     plt.savefig(hist_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     return y_pred
+
+
+ # Get SHAP values
+ def get_shap_values(
+     model,
+     model_name: str,
+     save_dir: Union[str, Path],
+     features_to_explain: np.ndarray,
+     feature_names: list[str],
+     target_name: str,
+     task: Literal["classification", "regression"],
+     max_display_features: int = 10,
+     figsize: tuple = (16, 20),
+     base_fontsize: int = 38,
+ ):
+     """
+     Universal SHAP explainer for regression and classification.
+
+     * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+     * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+     * Use the entire dataset to get the global view.
+
+     Parameters:
+         task: 'regression' or 'classification'.
+         features_to_explain: Should match the model's training data format, including scaling.
+         save_dir: Directory to save visualizations.
+     """
+     sanitized_target_name = sanitize_filename(target_name)
+     global_save_path = make_fullpath(save_dir, make=True)
+
+     def _apply_plot_style():
+         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+         for style in styles:
+             if style in plt.style.available or style == 'default':
+                 plt.style.use(style)
+                 break
+
+     def _configure_rcparams():
+         plt.rc('font', size=base_fontsize)
+         plt.rc('axes', titlesize=base_fontsize)
+         plt.rc('axes', labelsize=base_fontsize)
+         plt.rc('xtick', labelsize=base_fontsize)
+         plt.rc('ytick', labelsize=base_fontsize + 2)
+         plt.rc('legend', fontsize=base_fontsize)
+         plt.rc('figure', titlesize=base_fontsize)
+
+     def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
+         _apply_plot_style()
+         _configure_rcparams()
+         plt.figure(figsize=figsize)
+
+         shap.summary_plot(
+             shap_values=shap_values,
+             features=features,
+             feature_names=feature_names,
+             plot_type=plot_type,
+             show=False,
+             plot_size=figsize,
+             max_display=max_display_features,
+             alpha=0.7,
+             # color='viridis'
+         )
+
+         ax = plt.gca()
+         ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+         plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
+         for tick in ax.get_xticklabels():
+             tick.set_fontsize(base_fontsize)
+             tick.set_rotation(30)
+         for tick in ax.get_yticklabels():
+             tick.set_fontsize(base_fontsize + 2)
+
+         if plot_type == "dot":
+             cb = plt.gcf().axes[-1]
+             cb.set_ylabel("", size=1)
+             cb.tick_params(labelsize=base_fontsize - 2)
+
+         plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
+         plt.close()
+         rcdefaults()
+
+     def _plot_for_classification(shap_values, class_names):
+         is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
+         if is_multiclass:
+             for class_shap, class_name in zip(shap_values, class_names):
+                 for plot_type in ["bar", "dot"]:
+                     _create_shap_plot(
+                         shap_values=class_shap,
+                         features=features_to_explain,
+                         save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
+                         plot_type=plot_type,
+                         title=f"{model_name} - {target_name} (Class {class_name})"
+                     )
+         else:
+             values = shap_values[1] if isinstance(shap_values, list) else shap_values
+             for plot_type in ["bar", "dot"]:
+                 _create_shap_plot(
+                     shap_values=values,
+                     features=features_to_explain,
+                     save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                     plot_type=plot_type,
+                     title=f"{model_name} - {target_name}"
+                 )
+
+     def _plot_for_regression(shap_values):
+         for plot_type in ["bar", "dot"]:
+             _create_shap_plot(
+                 shap_values=shap_values,
+                 features=features_to_explain,
+                 save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                 plot_type=plot_type,
+                 title=f"{model_name} - {target_name}"
+             )
+
+     explainer = shap.TreeExplainer(model)
+     shap_values = explainer.shap_values(features_to_explain)
+
+     if task == 'classification':
+         try:
+             class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+         except Exception:
+             class_names = list(range(len(shap_values)))
+         _plot_for_classification(shap_values, class_names)
+     else:
+         _plot_for_regression(shap_values)
+
+
+ # Learning curves for regression and classification
+ def plot_learning_curves(
+     estimator,
+     X: np.ndarray,
+     y: np.ndarray,
+     task: Literal["classification", "regression"],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str, Path],
+     cv: int = 5,
+     n_jobs: int = -1,
+     figure_size: tuple = (12, 8),
+     base_fontsize: int = 24
+ ):
+     """
+     Generates and saves a plot of the learning curves for a given estimator
+     to diagnose bias vs. variance.
+
+     Computationally expensive; requires a fresh, unfitted instance of the model.
+     """
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+
+     # Select scoring metric based on task
+     scoring = "accuracy" if task == "classification" else "r2"
+
+     train_sizes_abs, train_scores, val_scores, *_ = learning_curve(
+         estimator, X, y,
+         cv=cv,
+         n_jobs=n_jobs,
+         train_sizes=np.linspace(0.1, 1.0, 10),
+         scoring=scoring
+     )
+
+     train_scores_mean = np.mean(train_scores, axis=1)
+     train_scores_std = np.std(train_scores, axis=1)
+     val_scores_mean = np.mean(val_scores, axis=1)
+     val_scores_std = np.std(val_scores, axis=1)
+
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.grid(True)
+
+     # Plot the mean scores
+     ax.plot(train_sizes_abs, train_scores_mean, 'o-', color="r", label="Training score")
+     ax.plot(train_sizes_abs, val_scores_mean, 'o-', color="g", label="Cross-validation score")
+
+     # Plot the standard deviation bands
+     ax.fill_between(train_sizes_abs, train_scores_mean - train_scores_std,
+                     train_scores_mean + train_scores_std, alpha=0.1, color="r")
+     ax.fill_between(train_sizes_abs, val_scores_mean - val_scores_std,
+                     val_scores_mean + val_scores_std, alpha=0.1, color="g")
+
+     ax.set_title(f"{model_name} - Learning Curve for {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("Training examples", fontsize=base_fontsize)
+     ax.set_ylabel(f"Score ({scoring})", fontsize=base_fontsize)
+     ax.legend(loc="best", fontsize=base_fontsize - 4)
+     ax.tick_params(axis='both', labelsize=base_fontsize - 4)
+     fig.tight_layout()
+
+     # Save figure
+     full_save_path = save_path / f"Learning_Curve_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg")
+     plt.close(fig)
+
+
+ def info():
+     _script_info(__all__)
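For orientation, here is a minimal, hypothetical usage sketch of the evaluation helpers added above. The import path `dragon_ml_toolbox.ML_evaluation` is an assumption (the module's file name is not shown in this excerpt), and the synthetic data, model, and output directory are purely illustrative.

    # Hypothetical usage sketch; the import path below is assumed, not taken from this diff.
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split

    from dragon_ml_toolbox.ML_evaluation import (  # assumed module path
        evaluate_model_classification, plot_roc_curve)

    # Toy binary-classification data and a fitted model, for illustration only
    X, y = make_classification(n_samples=500, n_features=10, random_state=0)
    x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
    model = RandomForestClassifier(random_state=0).fit(x_train, y_train)

    # Saves the classification report (text + heatmap) and confusion matrix, returns predictions
    y_pred = evaluate_model_classification(
        model=model,
        model_name="RandomForest",
        save_dir="evaluation_results",
        x_test_scaled=x_test,
        single_y_test=y_test,
        target_name="my_target",
    )

    # ROC curve from the fitted model (input_features is required when a model is passed)
    plot_roc_curve(
        true_labels=y_test,
        probabilities_or_model=model,
        model_name="RandomForest",
        target_name="my_target",
        save_directory="evaluation_results",
        input_features=x_test,
    )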
@@ -1,7 +1,7 @@
  from ._script_info import _script_info
  from ._logger import _LOGGER
  from .path_manager import make_fullpath, list_files_by_extension
- from .keys import ModelSaveKeys
+ from .keys import EnsembleKeys

  from typing import Union, Literal, Dict, Any, Optional, List
  from pathlib import Path
@@ -49,9 +49,9 @@ class InferenceHandler:
  verbose=self.verbose,
  raise_on_error=True) # type: ignore

- model: Any = full_object[ModelSaveKeys.MODEL]
- target_name: str = full_object[ModelSaveKeys.TARGET]
- feature_names_list: List[str] = full_object[ModelSaveKeys.FEATURES]
+ model: Any = full_object[EnsembleKeys.MODEL]
+ target_name: str = full_object[EnsembleKeys.TARGET]
+ feature_names_list: List[str] = full_object[EnsembleKeys.FEATURES]

  # Check that feature names match
  if self._feature_names is None:
@@ -102,8 +102,8 @@ class InferenceHandler:
  else: # Classification
  label = model.predict(features)[0]
  probabilities = model.predict_proba(features)[0]
- results[target_name] = {ModelSaveKeys.CLASSIFICATION_LABEL: label,
- ModelSaveKeys.CLASSIFICATION_PROBABILITIES: probabilities}
+ results[target_name] = {EnsembleKeys.CLASSIFICATION_LABEL: label,
+ EnsembleKeys.CLASSIFICATION_PROBABILITIES: probabilities}

  if self.verbose:
  _LOGGER.info("✅ Inference process complete.")
@@ -170,15 +170,15 @@ def model_report(
  # --- 2. Deserialize and Extract Info ---
  try:
  full_object: dict = _deserialize_object(model_p) # type: ignore
- model = full_object[ModelSaveKeys.MODEL]
- target = full_object[ModelSaveKeys.TARGET]
- features = full_object[ModelSaveKeys.FEATURES]
+ model = full_object[EnsembleKeys.MODEL]
+ target = full_object[EnsembleKeys.TARGET]
+ features = full_object[EnsembleKeys.FEATURES]
  except FileNotFoundError:
  _LOGGER.error(f"❌ Model file not found at '{model_p}'")
  raise
  except (KeyError, TypeError) as e:
  _LOGGER.error(
- f"❌ The serialized object is missing required keys '{ModelSaveKeys.MODEL}', '{ModelSaveKeys.TARGET}', '{ModelSaveKeys.FEATURES}'"
+ f"❌ The serialized object is missing required keys '{EnsembleKeys.MODEL}', '{EnsembleKeys.TARGET}', '{EnsembleKeys.FEATURES}'"
  )
  raise e
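The remaining hunks are a mechanical rename of the serialization key container from `ModelSaveKeys` to `EnsembleKeys`, keeping the same attribute names (`MODEL`, `TARGET`, `FEATURES`, `CLASSIFICATION_LABEL`, `CLASSIFICATION_PROBABILITIES`). A minimal migration sketch for downstream code follows; it assumes the package is importable as `dragon_ml_toolbox` and uses a toy dictionary in place of a real deserialized artifact.

    # Hypothetical migration sketch; assumes the package imports as dragon_ml_toolbox.
    from dragon_ml_toolbox.keys import EnsembleKeys  # 5.x code imported ModelSaveKeys instead

    # A toy artifact shaped like the dict that InferenceHandler deserializes
    artifact = {
        EnsembleKeys.MODEL: "a fitted estimator would go here",
        EnsembleKeys.TARGET: "my_target",
        EnsembleKeys.FEATURES: ["feature_a", "feature_b"],
    }

    # 5.x code indexed the artifact with ModelSaveKeys; in 6.x the same lookups use EnsembleKeys
    model = artifact[EnsembleKeys.MODEL]
    target_name = artifact[EnsembleKeys.TARGET]
    feature_names = artifact[EnsembleKeys.FEATURES]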