dragon-ml-toolbox 5.3.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

@@ -12,7 +12,7 @@ from .path_manager import sanitize_filename, make_fullpath, list_files_by_extens
  import torch
  from tqdm import trange
  from ._logger import _LOGGER
- from .keys import ModelSaveKeys
+ from .keys import EnsembleKeys
  from ._script_info import _script_info
  from .SQL import DatabaseManager
  from contextlib import nullcontext
@@ -48,9 +48,9 @@ class ObjectiveFunction():
          self.is_hybrid = False if binary_features <= 0 else True
          self.use_noise = add_noise
          self._artifact = deserialize_object(trained_model_path, verbose=False, raise_on_error=True)
-         self.model = self._get_from_artifact(ModelSaveKeys.MODEL)
-         self.feature_names: Optional[list[str]] = self._get_from_artifact(ModelSaveKeys.FEATURES) # type: ignore
-         self.target_name: Optional[str] = self._get_from_artifact(ModelSaveKeys.TARGET) # type: ignore
+         self.model = self._get_from_artifact(EnsembleKeys.MODEL)
+         self.feature_names: Optional[list[str]] = self._get_from_artifact(EnsembleKeys.FEATURES) # type: ignore
+         self.target_name: Optional[str] = self._get_from_artifact(EnsembleKeys.TARGET) # type: ignore
          self.task = task
          self.check_model() # check for classification models and None values
 
@@ -126,7 +126,7 @@ class ObjectiveFunction():
          if self._artifact is None:
              raise TypeError("Load model error")
          val = self._artifact.get(key)
-         if key == ModelSaveKeys.FEATURES:
+         if key == EnsembleKeys.FEATURES:
              result = val if isinstance(val, list) and val else None
          else:
              result = val if val else None
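Both hunks above replace ModelSaveKeys with EnsembleKeys when reading a serialized ensemble artifact, so downstream code written against 5.3.0 must switch key classes. A minimal sketch of the updated access pattern, assuming the deserialized artifact is a dictionary keyed by the MODEL, FEATURES and TARGET members used above; the import paths and file name are placeholders, not part of this diff:

from ml_tools.keys import EnsembleKeys               # was: ModelSaveKeys in 5.3.0 (module path is a placeholder)
from ml_tools.utilities import deserialize_object    # placeholder import path

# Load a trained ensemble artifact previously saved by the toolbox (file name is a placeholder).
artifact = deserialize_object("results/model_artifact.joblib", verbose=False, raise_on_error=True)

model = artifact.get(EnsembleKeys.MODEL)              # trained estimator
feature_names = artifact.get(EnsembleKeys.FEATURES)   # list[str] of feature names
target_name = artifact.get(EnsembleKeys.TARGET)       # str, target column name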
@@ -0,0 +1,639 @@
+ import pandas as pd
+ import numpy as np
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import Colormap
+ from matplotlib import rcdefaults
+ import shap
+ import xgboost as xgb
+ import lightgbm as lgb
+ from sklearn.model_selection import learning_curve
+ from sklearn.calibration import CalibrationDisplay
+ from sklearn.metrics import (accuracy_score,
+                              classification_report,
+                              ConfusionMatrixDisplay,
+                              mean_absolute_error,
+                              mean_squared_error,
+                              r2_score,
+                              roc_curve,
+                              roc_auc_score,
+                              precision_recall_curve,
+                              average_precision_score)
+ from pathlib import Path
+ from typing import Union, Optional, Literal
+
+ from .path_manager import sanitize_filename, make_fullpath
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+
+
+ __all__ = [
+     "evaluate_model_classification",
+     "plot_roc_curve",
+     "plot_precision_recall_curve",
+     "plot_calibration_curve",
+     "evaluate_model_regression",
+     "get_shap_values",
+     "plot_learning_curves",
+ ]
+
+
+ # function to evaluate the model and save metrics (Classification)
+ def evaluate_model_classification(
+     model,
+     model_name: str,
+     save_dir: Union[str,Path],
+     x_test_scaled: np.ndarray,
+     single_y_test: np.ndarray,
+     target_name: str,
+     figsize: tuple = (10, 8),
+     base_fontsize: int = 24,
+     cmap: Colormap = plt.cm.Blues, # type: ignore
+     heatmap_cmap: str = "viridis"
+ ) -> np.ndarray:
+     """
+     Evaluates a classification model, saves the classification report (text and heatmap) and the confusion matrix plot.
+
+     Parameters:
+         model: Trained classifier with .predict() method
+         model_name: Identifier for the model
+         save_dir: Directory where results are saved
+         x_test_scaled: Feature matrix for test set
+         single_y_test: True targets
+         target_name: Target name
+         figsize: Size of the confusion matrix figure (width, height)
+         base_fontsize: Font size used for title, axis labels and ticks
+         heatmap_cmap: Colormap for the classification report heatmap.
+         cmap: Color map for the confusion matrix. Examples include:
+             - plt.cm.Blues (default)
+             - plt.cm.Greens
+             - plt.cm.Oranges
+             - plt.cm.Purples
+             - plt.cm.Reds
+             - plt.cm.cividis
+             - plt.cm.inferno
+
+     Returns:
+         y_pred: Predicted class labels
+     """
+     save_path = make_fullpath(save_dir, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+
+     y_pred = model.predict(x_test_scaled)
+     accuracy = accuracy_score(single_y_test, y_pred)
+
+     # Generate report as dictionary for the heatmap
+     report_dict = classification_report(
+         single_y_test,
+         y_pred,
+         target_names=["Negative", "Positive"],
+         output_dict=True
+     )
+
+     # text report to save
+     report_text = classification_report(
+         single_y_test,
+         y_pred,
+         target_names=["Negative", "Positive"],
+         output_dict=False
+     )
+
+     # Save text report
+
+     report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
+     with open(report_path, "w") as f:
+         f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
+         f.write("Classification Report:\n")
+         f.write(report_text) # type: ignore
+
+     # Create and save the classification report heatmap
+     try:
+         report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
+         plt.figure(figsize=figsize)
+         sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
+                     annot_kws={"size": base_fontsize - 4})
+         plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+         plt.xticks(fontsize=base_fontsize - 2)
+         plt.yticks(fontsize=base_fontsize - 2)
+
+         heatmap_path = save_path / f"Classification_Report_{sanitized_target_name}.svg"
+         plt.savefig(heatmap_path, format="svg", bbox_inches="tight")
+         plt.close()
+     except Exception as e:
+         _LOGGER.error(f"❌ Could not generate classification report heatmap for {target_name}: {e}")
+
+     # Create confusion matrix
+     fig, ax = plt.subplots(figsize=figsize)
+     disp = ConfusionMatrixDisplay.from_predictions(
+         y_true=single_y_test,
+         y_pred=y_pred,
+         display_labels=["Negative", "Positive"],
+         cmap=cmap,
+         normalize="true",
+         ax=ax
+     )
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.set_xlabel("Predicted label", fontsize=base_fontsize)
+     ax.set_ylabel("True label", fontsize=base_fontsize)
+
+     # Turn off gridlines
+     ax.grid(False)
+
+     # Manually update font size of cell texts
+     for text in ax.texts:
+         text.set_fontsize(base_fontsize+4)
+
+     fig.tight_layout()
+     fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
+     fig.savefig(fig_path, format="svg", bbox_inches="tight") # type: ignore
+     plt.close(fig)
+
+     return y_pred
+
+ # Function to save ROC and ROC AUC (Classification)
+ def plot_roc_curve(
+     true_labels: np.ndarray,
+     probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str,Path],
+     color: str = "darkorange",
+     figure_size: tuple = (10, 10),
+     linewidth: int = 2,
+     base_fontsize: int = 24,
+     input_features: Optional[np.ndarray] = None,
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the ROC curve and computes AUC for binary classification. Positive class is assumed to be in the second column of the probabilities array.
+
+     Parameters:
+         true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+         probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
+         model_name: Identifier for the model.
+         target_name: str, Target name.
+         save_directory: str or Path, path to directory where figure is saved.
+         color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
+             - Named colors: "darkorange", "blue", "red", "green", "black"
+             - Hex codes: "#1f77b4", "#ff7f0e"
+             - RGB tuples: (0.2, 0.4, 0.6)
+             - Colormap value: plt.cm.viridis(0.6)
+         figure_size: Tuple for figure size (width, height).
+         linewidth: int, width of the plotted ROC line.
+         base_fontsize: int, base font size for the title, axis labels, ticks and legend.
+         input_features: np.ndarray of shape (n_samples, n_features), required if a model is passed.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+
+     # Determine predicted probabilities
+     if isinstance(probabilities_or_model, np.ndarray):
+         # Input is already probabilities
+         if probabilities_or_model.ndim == 2: # type: ignore
+             y_score = probabilities_or_model[:, 1] # type: ignore
+         else:
+             y_score = probabilities_or_model
+
+     elif hasattr(probabilities_or_model, "predict_proba"):
+         if input_features is None:
+             raise ValueError("input_features must be provided when using a classifier.")
+
+         try:
+             classes = probabilities_or_model.classes_ # type: ignore
+             positive_class_index = list(classes).index(1)
+         except (AttributeError, ValueError):
+             positive_class_index = 1
+
+         y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+
+     else:
+         raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+     # ROC and AUC
+     fpr, tpr, _ = roc_curve(true_labels, y_score)
+     auc_score = roc_auc_score(true_labels, y_score)
+
+     # Plot
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
+     ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+     ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.legend(loc="lower right", fontsize=base_fontsize)
+     ax.grid(True)
+
+     # Save figure
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+
+     return fig
+
+
+ # Precision-Recall curve (Classification)
+ def plot_precision_recall_curve(
+     true_labels: np.ndarray,
+     probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str, Path],
+     color: str = "teal",
+     figure_size: tuple = (10, 10),
+     linewidth: int = 2,
+     base_fontsize: int = 24,
+     input_features: Optional[np.ndarray] = None,
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the Precision-Recall curve and computes Average Precision (AP) for binary classification.
+
+     Parameters:
+         true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+         probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
+         model_name: Identifier for the model.
+         target_name: Name of the target variable.
+         save_directory: Path to the directory where the figure will be saved.
+         color: str, color of the PR curve.
+         figure_size: Tuple for figure size (width, height).
+         linewidth: int, width of the plotted PR line.
+         base_fontsize: int, base font size for titles and labels.
+         input_features: np.ndarray, required if a model object is passed instead of probabilities.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+     # Determine predicted probabilities for the positive class
+     if isinstance(probabilities_or_model, np.ndarray):
+         if probabilities_or_model.ndim == 2:
+             y_score = probabilities_or_model[:, 1]
+         else:
+             y_score = probabilities_or_model
+
+     elif hasattr(probabilities_or_model, "predict_proba"):
+         if input_features is None:
+             raise ValueError("input_features must be provided when using a classifier.")
+         try:
+             classes = probabilities_or_model.classes_ # type: ignore
+             positive_class_index = list(classes).index(1)
+         except (AttributeError, ValueError):
+             positive_class_index = 1
+         y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+     else:
+         raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+     # Calculate PR curve and AP score
+     precision, recall, _ = precision_recall_curve(true_labels, y_score)
+     ap_score = average_precision_score(true_labels, y_score)
+
+     # Plot
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.plot(recall, precision, color=color, lw=linewidth, label=f"AP = {ap_score:.2f}")
+
+     ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("Recall", fontsize=base_fontsize)
+     ax.set_ylabel("Precision", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize)
+     ax.legend(loc="lower left", fontsize=base_fontsize)
+     ax.grid(True)
+     fig.tight_layout()
+
+     # Save figure
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"PR_Curve_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+     plt.close(fig)
+
+     return fig
+
+
+ # Calibration curve (classification)
+ def plot_calibration_curve(
+     model,
+     model_name: str,
+     save_dir: Union[str, Path],
+     x_test: np.ndarray,
+     y_test: np.ndarray,
+     target_name: str,
+     figure_size: tuple = (10, 10),
+     base_fontsize: int = 24,
+     n_bins: int = 15
+ ) -> plt.Figure: # type: ignore
+     """
+     Plots the calibration curve (reliability diagram) for a classifier.
+
+     Parameters:
+         model: Trained classifier with .predict_proba() method.
+         model_name: Identifier for the model.
+         save_dir: Directory where the plot will be saved.
+         x_test: Feature matrix for the test set.
+         y_test: True labels for the test set.
+         target_name: Name of the target variable.
+         figure_size: Tuple for figure size (width, height).
+         base_fontsize: Base font size for titles and labels.
+         n_bins: Number of bins to discretize predictions into.
+
+     Returns:
+         fig: matplotlib Figure object
+     """
+     fig, ax = plt.subplots(figsize=figure_size)
+
+     disp = CalibrationDisplay.from_estimator(
+         model,
+         x_test,
+         y_test,
+         n_bins=n_bins,
+         ax=ax
+     )
+
+     ax.set_title(f"{model_name} - Calibration Plot for {target_name}", fontsize=base_fontsize)
+     ax.tick_params(axis='both', labelsize=base_fontsize - 2)
+     ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
+     ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
+     ax.legend(fontsize=base_fontsize - 4)
+     fig.tight_layout()
+
+     # Save figure
+     save_path = make_fullpath(save_dir, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+     full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+     plt.close(fig)
+
+     return fig
+
+
+ # function to evaluate the model and save metrics (Regression)
+ def evaluate_model_regression(model, model_name: str,
+                               save_dir: Union[str,Path],
+                               x_test_scaled: np.ndarray, single_y_test: np.ndarray,
+                               target_name: str,
+                               figure_size: tuple = (12, 8),
+                               alpha_transparency: float = 0.5,
+                               base_fontsize: int = 24,
+                               hist_bins: int = 30):
+     # Generate predictions
+     y_pred = model.predict(x_test_scaled)
+
+     # Calculate regression metrics
+     mae = mean_absolute_error(single_y_test, y_pred)
+     mse = mean_squared_error(single_y_test, y_pred)
+     rmse = np.sqrt(mse)
+     r2 = r2_score(single_y_test, y_pred)
+
+     # Create formatted report
+     sanitized_target_name = sanitize_filename(target_name)
+     save_path = make_fullpath(save_dir, make=True)
+     report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
+     with open(report_path, "w") as f:
+         f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
+         f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
+         f.write(f"Mean Squared Error (MSE): {mse:.4f}\n")
+         f.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}\n")
+         f.write(f"R² Score: {r2:.4f}\n")
+
+     # Generate and save residual plot
+     residuals = single_y_test - y_pred
+
+     plt.figure(figsize=figure_size)
+     plt.scatter(y_pred, residuals, alpha=alpha_transparency)
+     plt.axhline(0, color='red', linestyle='--')
+     plt.xlabel("Predicted Values", fontsize=base_fontsize)
+     plt.ylabel("Residuals", fontsize=base_fontsize)
+     plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plt.tight_layout()
+     residual_path = save_path / f"Residuals_Plot_{sanitized_target_name}.svg"
+     plt.savefig(residual_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     # Create true vs predicted values plot
+     plt.figure(figsize=figure_size)
+     plt.scatter(single_y_test, y_pred, alpha=alpha_transparency)
+     plt.plot([single_y_test.min(), single_y_test.max()],
+              [single_y_test.min(), single_y_test.max()],
+              'k--', lw=2)
+     plt.xlabel('True Values', fontsize=base_fontsize)
+     plt.ylabel('Predictions', fontsize=base_fontsize)
+     plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plot_path = save_path / f"True_Vs_Predict_Plot_{sanitized_target_name}.svg"
+     plt.savefig(plot_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     # Generate and save histogram of residuals
+     plt.figure(figsize=figure_size)
+     sns.histplot(residuals, bins=hist_bins, kde=True)
+     plt.xlabel("Residual Value", fontsize=base_fontsize)
+     plt.ylabel("Frequency", fontsize=base_fontsize)
+     plt.title(f"{model_name} - Distribution of Residuals for {target_name}", fontsize=base_fontsize)
+     plt.grid(True)
+     plt.tight_layout()
+     hist_path = save_path / f"Residuals_Distribution_{sanitized_target_name}.svg"
+     plt.savefig(hist_path, bbox_inches='tight', format="svg")
+     plt.close()
+
+     return y_pred
+
+
+ # Get SHAP values
+ def get_shap_values(
+     model,
+     model_name: str,
+     save_dir: Union[str, Path],
+     features_to_explain: np.ndarray,
+     feature_names: list[str],
+     target_name: str,
+     task: Literal["classification", "regression"],
+     max_display_features: int = 10,
+     figsize: tuple = (16, 20),
+     base_fontsize: int = 38,
+ ):
+     """
+     Universal SHAP explainer for regression and classification.
+     * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+
+     * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+
+     * Use the entire dataset to get the global view.
+
+     Parameters:
+         task: 'regression' or 'classification'.
+         features_to_explain: Should match the model's training data format, including scaling.
+         save_dir: Directory to save visualizations.
+     """
+     sanitized_target_name = sanitize_filename(target_name)
+     global_save_path = make_fullpath(save_dir, make=True)
+
+     def _apply_plot_style():
+         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+         for style in styles:
+             if style in plt.style.available or style == 'default':
+                 plt.style.use(style)
+                 break
+
+     def _configure_rcparams():
+         plt.rc('font', size=base_fontsize)
+         plt.rc('axes', titlesize=base_fontsize)
+         plt.rc('axes', labelsize=base_fontsize)
+         plt.rc('xtick', labelsize=base_fontsize)
+         plt.rc('ytick', labelsize=base_fontsize + 2)
+         plt.rc('legend', fontsize=base_fontsize)
+         plt.rc('figure', titlesize=base_fontsize)
+
+     def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
+         _apply_plot_style()
+         _configure_rcparams()
+         plt.figure(figsize=figsize)
+
+         shap.summary_plot(
+             shap_values=shap_values,
+             features=features,
+             feature_names=feature_names,
+             plot_type=plot_type,
+             show=False,
+             plot_size=figsize,
+             max_display=max_display_features,
+             alpha=0.7,
+             # color='viridis'
+         )
+
+         ax = plt.gca()
+         ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+         plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
+         for tick in ax.get_xticklabels():
+             tick.set_fontsize(base_fontsize)
+             tick.set_rotation(30)
+         for tick in ax.get_yticklabels():
+             tick.set_fontsize(base_fontsize + 2)
+
+         if plot_type == "dot":
+             cb = plt.gcf().axes[-1]
+             cb.set_ylabel("", size=1)
+             cb.tick_params(labelsize=base_fontsize - 2)
+
+         plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
+         plt.close()
+         rcdefaults()
+
+     def _plot_for_classification(shap_values, class_names):
+         is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
+         if is_multiclass:
+             for class_shap, class_name in zip(shap_values, class_names):
+                 for plot_type in ["bar", "dot"]:
+                     _create_shap_plot(
+                         shap_values=class_shap,
+                         features=features_to_explain,
+                         save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
+                         plot_type=plot_type,
+                         title=f"{model_name} - {target_name} (Class {class_name})"
+                     )
+         else:
+             values = shap_values[1] if isinstance(shap_values, list) else shap_values
+             for plot_type in ["bar", "dot"]:
+                 _create_shap_plot(
+                     shap_values=values,
+                     features=features_to_explain,
+                     save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                     plot_type=plot_type,
+                     title=f"{model_name} - {target_name}"
+                 )
+
+     def _plot_for_regression(shap_values):
+         for plot_type in ["bar", "dot"]:
+             _create_shap_plot(
+                 shap_values=shap_values,
+                 features=features_to_explain,
+                 save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                 plot_type=plot_type,
+                 title=f"{model_name} - {target_name}"
+             )
+     #START_O
+
+     explainer = shap.TreeExplainer(model)
+     shap_values = explainer.shap_values(features_to_explain)
+
+     if task == 'classification':
+         try:
+             class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+         except Exception:
+             class_names = list(range(len(shap_values)))
+         _plot_for_classification(shap_values, class_names)
+     else:
+         _plot_for_regression(shap_values)
+
+
+ # Learning curves for regression and classification
+ def plot_learning_curves(
+     estimator,
+     X: np.ndarray,
+     y: np.ndarray,
+     task: Literal["classification", "regression"],
+     model_name: str,
+     target_name: str,
+     save_directory: Union[str, Path],
+     cv: int = 5,
+     n_jobs: int = -1,
+     figure_size: tuple = (12, 8),
+     base_fontsize: int = 24
+ ):
+     """
+     Generates and saves a plot of the learning curves for a given estimator
+     to diagnose bias vs. variance.
+
+     Computationally expensive, requires a fresh, unfitted instance of the model.
+     """
+     save_path = make_fullpath(save_directory, make=True)
+     sanitized_target_name = sanitize_filename(target_name)
+
+     # Select scoring metric based on task
+     scoring = "accuracy" if task == "classification" else "r2"
+
+     train_sizes_abs, train_scores, val_scores, *_ = learning_curve(
+         estimator, X, y,
+         cv=cv,
+         n_jobs=n_jobs,
+         train_sizes=np.linspace(0.1, 1.0, 10),
+         scoring=scoring
+     )
+
+     train_scores_mean = np.mean(train_scores, axis=1)
+     train_scores_std = np.std(train_scores, axis=1)
+     val_scores_mean = np.mean(val_scores, axis=1)
+     val_scores_std = np.std(val_scores, axis=1)
+
+     fig, ax = plt.subplots(figsize=figure_size)
+     ax.grid(True)
+
+     # Plot the mean scores
+     ax.plot(train_sizes_abs, train_scores_mean, 'o-', color="r", label="Training score")
+     ax.plot(train_sizes_abs, val_scores_mean, 'o-', color="g", label="Cross-validation score")
+
+     # Plot the standard deviation bands
+     ax.fill_between(train_sizes_abs, train_scores_mean - train_scores_std,
+                     train_scores_mean + train_scores_std, alpha=0.1, color="r")
+     ax.fill_between(train_sizes_abs, val_scores_mean - val_scores_std,
+                     val_scores_mean + val_scores_std, alpha=0.1, color="g")
+
+     ax.set_title(f"{model_name} - Learning Curve for {target_name}", fontsize=base_fontsize)
+     ax.set_xlabel("Training examples", fontsize=base_fontsize)
+     ax.set_ylabel(f"Score ({scoring})", fontsize=base_fontsize)
+     ax.legend(loc="best", fontsize=base_fontsize - 4)
+     ax.tick_params(axis='both', labelsize=base_fontsize - 4)
+     fig.tight_layout()
+
+     # Save figure
+     full_save_path = save_path / f"Learning_Curve_{sanitized_target_name}.svg"
+     fig.savefig(full_save_path, bbox_inches="tight", format="svg")
+     plt.close(fig)
+
+
+ def info():
+     _script_info(__all__)
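The file added in 6.0.0 bundles the evaluation helpers listed in its __all__. A minimal usage sketch follows, assuming the module is importable from the installed package; the import path, directory names, and data below are placeholders for illustration and are not part of the diff:

import numpy as np
from xgboost import XGBClassifier

from ml_tools.ML_evaluation import (   # placeholder module path
    evaluate_model_classification,
    plot_roc_curve,
    get_shap_values,
)

# Placeholder data: replace with a real train/test split and a properly tuned model.
rng = np.random.default_rng(0)
X_train, X_test = rng.random((80, 5)), rng.random((20, 5))
y_train, y_test = rng.integers(0, 2, 80), rng.integers(0, 2, 20)
model = XGBClassifier().fit(X_train, y_train)

# Saves the classification report (text + heatmap) and confusion matrix as SVG files.
y_pred = evaluate_model_classification(
    model, model_name="XGBClassifier", save_dir="eval_results",
    x_test_scaled=X_test, single_y_test=y_test, target_name="example_target",
)

# Saves the ROC curve; input_features is required because a model (not probabilities) is passed.
plot_roc_curve(
    true_labels=y_test, probabilities_or_model=model, model_name="XGBClassifier",
    target_name="example_target", save_directory="eval_results", input_features=X_test,
)

# Saves SHAP bar and dot summary plots using shap.TreeExplainer.
get_shap_values(
    model, model_name="XGBClassifier", save_dir="eval_results",
    features_to_explain=X_test, feature_names=[f"f{i}" for i in range(5)],
    target_name="example_target", task="classification",
)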