dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation_multi.py
CHANGED
```diff
@@ -19,13 +19,17 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys
+from .ML_configuration import (MultiTargetRegressionMetricsFormat,
+                               _BaseRegressionFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               _BaseMultiLabelFormat)
 
 
 __all__ = [
@@ -34,12 +38,15 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]
 
+DPI_value = 250
+
 
 def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[MultiTargetRegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -53,6 +60,7 @@ def multi_target_regression_metrics(
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (object): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -66,8 +74,19 @@ def multi_target_regression_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
 
     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
@@ -90,9 +109,13 @@ def multi_target_regression_metrics(
 
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
-        ax_res.scatter(pred_i, residuals,
-
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_res.scatter(pred_i, residuals,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -103,9 +126,16 @@ def multi_target_regression_metrics(
         plt.close(fig_res)
 
         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
-        ax_tvp.scatter(true_i, pred_i,
-
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=format_config.ideal_line_color) # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -120,14 +150,18 @@ def multi_target_regression_metrics(
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
 
 
 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-
+    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -138,17 +172,17 @@ def multi_label_classification_metrics(
 
     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-
-            binary predictions for metrics like the confusion matrix.
+        config (object): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -156,25 +190,35 @@ def multi_label_classification_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
-    #
-
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
         f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)
 
@@ -182,28 +226,79 @@ def multi_label_classification_metrics(
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
 
-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
        report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text) # type: ignore
 
-        # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-        ConfusionMatrixDisplay.from_predictions(true_i,
+        # --- Save Confusion Matrix (uses y_pred) ---
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=format_config.cmap, # Use config cmap
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(format_config.font_size) # Use config font_size
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
         plt.close(fig_cm)
 
-        # --- Save ROC Curve ---
-        fpr, tpr,
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -212,17 +307,20 @@ def multi_label_classification_metrics(
         plt.savefig(roc_path)
         plt.close(fig_roc)
 
-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
 
@@ -235,7 +333,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = '
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
 ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +347,7 @@ def multi_target_shap_summary_plot(
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +358,7 @@ def multi_target_shap_summary_plot(
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -285,10 +383,9 @@ def multi_target_shap_summary_plot(
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
     # Convert all data to numpy
```
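
The headline change to `multi_target_regression_metrics` is the optional `config` parameter. Below is a minimal usage sketch; the constructor keyword arguments for `MultiTargetRegressionMetricsFormat` are assumed from the attributes the new code reads (`font_size`, `scatter_alpha`, `scatter_color`, `residual_line_color`, `ideal_line_color`), which this diff does not confirm, so check `ml_tools/ML_configuration.py` before relying on them.

```python
# Hedged sketch of the new config-driven regression evaluation.
# The MultiTargetRegressionMetricsFormat kwargs below are assumptions
# inferred from the attributes read in the diff, not a confirmed API.
import numpy as np

from ml_tools.ML_configuration import MultiTargetRegressionMetricsFormat
from ml_tools.ML_evaluation_multi import multi_target_regression_metrics

rng = np.random.default_rng(0)
y_true = rng.normal(size=(100, 2))                      # (n_samples, n_targets)
y_pred = y_true + rng.normal(scale=0.1, size=(100, 2))  # noisy predictions

fmt = MultiTargetRegressionMetricsFormat(               # assumed kwargs
    font_size=12,
    scatter_alpha=0.5,
    scatter_color="steelblue",
    residual_line_color="crimson",
    ideal_line_color="crimson",
)

multi_target_regression_metrics(
    y_true,
    y_pred,
    target_names=["target_a", "target_b"],
    save_dir="reports/regression_multi",
    config=fmt,   # omit to fall back to _BaseRegressionFormat() defaults
)
```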
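
`multi_label_classification_metrics` now takes binary `y_pred` alongside `y_prob`, moving thresholding to the caller; the removed docstring fragment ("binary predictions for metrics like the confusion matrix") suggests the old signature binarized internally. A minimal sketch of the new calling convention, assuming a plain 0.5 cut-off:

```python
# The function no longer binarizes y_prob itself; the caller supplies y_pred.
import numpy as np

from ml_tools.ML_evaluation_multi import multi_label_classification_metrics

rng = np.random.default_rng(0)
y_prob = rng.uniform(size=(200, 3))                      # model probabilities
y_true = (rng.uniform(size=(200, 3)) < 0.5).astype(int)  # ground-truth labels
y_pred = (y_prob >= 0.5).astype(int)                     # caller-side thresholding

multi_label_classification_metrics(
    y_true,
    y_pred,
    y_prob,
    target_names=["label_a", "label_b", "label_c"],
    save_dir="reports/classification_multi",  # config=None -> _BaseMultiLabelFormat()
)
```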
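
The new per-label threshold export uses Youden's J statistic, J = TPR − FPR, maximized over the ROC curve, which selects the operating point with the best balance of sensitivity and specificity. A standalone sketch of the same computation, handy for re-binarizing probabilities before a later evaluation run:

```python
# Standalone version of the Youden's J selection the per-label loop performs.
import numpy as np
from sklearn.metrics import roc_curve

def optimal_youden_threshold(y_true_1d: np.ndarray, y_prob_1d: np.ndarray) -> float:
    """Return the probability threshold maximizing TPR - FPR."""
    fpr, tpr, thresholds = roc_curve(y_true_1d, y_prob_1d)
    return float(thresholds[np.argmax(tpr - fpr)])

# Hypothetical follow-up: binarize one label with its tuned threshold and
# pass the result back in as y_pred on the next evaluation run.
# t = optimal_youden_threshold(y_true[:, 0], y_prob[:, 0])
# y_pred[:, 0] = (y_prob[:, 0] >= t).astype(int)
```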
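
Finally, `multi_target_shap_summary_plot` now defaults to `explainer_type='kernel'` (the old default is truncated in this view). KernelExplainer is model-agnostic but, as the updated warning notes, slow and memory-hungry; `'deep'` remains available for models `shap.DeepExplainer` supports. A call sketch follows; only `target_names`, `save_dir`, `device`, and `explainer_type` are visible as named parameters in this diff, so the leading arguments are shown positionally with placeholder names:

```python
# Hedged call sketch; model, X_background, and X_explain are placeholders,
# and the positional argument order is assumed from the code in the diff.
# import torch
# from ml_tools.ML_evaluation_multi import multi_target_shap_summary_plot
#
# multi_target_shap_summary_plot(
#     model,                              # the trained multi-target model
#     X_background,                       # np.ndarray or torch.Tensor
#     X_explain,
#     target_names=["target_a", "target_b"],
#     save_dir="reports/shap_multi",
#     device=torch.device("cpu"),
#     explainer_type="deep",              # override the new 'kernel' default
# )
```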
|