dragon-ml-toolbox 13.3.2__py3-none-any.whl → 13.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/RECORD +10 -10
- ml_tools/ML_datasetmaster.py +61 -20
- ml_tools/ML_evaluation.py +20 -12
- ml_tools/ML_evaluation_multi.py +5 -6
- ml_tools/ML_trainer.py +17 -9
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-13.3.2.dist-info → dragon_ml_toolbox-13.5.0.dist-info}/RECORD
CHANGED

@@ -1,18 +1,18 @@
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
+dragon_ml_toolbox-13.5.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-13.5.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
 ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
 ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
 ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
 ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
 ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
-ml_tools/ML_datasetmaster.py,sha256=
-ml_tools/ML_evaluation.py,sha256=
-ml_tools/ML_evaluation_multi.py,sha256=
+ml_tools/ML_datasetmaster.py,sha256=6caWbq6eu1RE9V51gmceD71PtMctJRjFuLvkkK5ChiY,36271
+ml_tools/ML_evaluation.py,sha256=li77AuP53pCzgrj6p-jTCNtPFgS9Y9XnMWIZn1ulTBM,18946
+ml_tools/ML_evaluation_multi.py,sha256=rJKdgtq-9I7oaI7PRzq7aIZ84XdNV0xzlVePZW4nj0k,16095
 ml_tools/ML_inference.py,sha256=yq2gdN6s_OUYC5ZLQrIJC5BA5H33q8UKODXwb-_0M2c,23549
 ml_tools/ML_models.py,sha256=UVWJHPLVIvFno_csCHH1FwBfTwQ5nX0V8F1TbOByZ4I,31388
 ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
 ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
-ml_tools/ML_trainer.py,sha256=
+ml_tools/ML_trainer.py,sha256=ZxeOagXW5adFhYIH-oMTlcrLU6VHe4R1EROI7yypNwQ,29665
 ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
 ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
 ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
@@ -35,7 +35,7 @@ ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJR
 ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
 ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
 ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
+dragon_ml_toolbox-13.5.0.dist-info/METADATA,sha256=EwOjL8T9Vnk1cg7vsDY4JaK9ovZtIkeIN2LcAiN-nvg,6166
+dragon_ml_toolbox-13.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-13.5.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-13.5.0.dist-info/RECORD,,
ml_tools/ML_datasetmaster.py
CHANGED
@@ -126,8 +126,8 @@ class _BaseDatasetMaker(ABC):
         else:
             _LOGGER.info("No continuous features listed in schema. Scaler will not be fitted.")
 
-        X_train_values = X_train.
-        X_test_values = X_test.
+        X_train_values = X_train.to_numpy()
+        X_test_values = X_test.to_numpy()
 
         # continuous_feature_indices is derived
         if self.scaler is None and continuous_feature_indices:
@@ -253,26 +253,42 @@ class DatasetMaker(_BaseDatasetMaker):
                  pandas_df: pandas.DataFrame,
                  schema: FeatureSchema,
                  kind: Literal["regression", "classification"],
+                 scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
                  test_size: float = 0.2,
-                 random_state: int = 42
-                 scaler: Optional[PytorchScaler] = None):
+                 random_state: int = 42):
         """
         Args:
             pandas_df (pandas.DataFrame):
                 The pre-processed input DataFrame containing all columns. (features and single target).
             schema (FeatureSchema):
                 The definitive schema object from data_exploration.
-            kind (
+            kind ("regression" | "classification"):
                 The type of ML task. This determines the data type of the labels.
+            scaler ("fit" | "none" | PytorchScaler):
+                Strategy for data scaling:
+                - "fit": Fit a new PytorchScaler on continuous features.
+                - "none": Do not scale data (e.g., for TabularTransformer).
+                - PytorchScaler instance: Use a pre-fitted scaler to transform data.
             test_size (float):
                 The proportion of the dataset to allocate to the test split.
             random_state (int):
                 The seed for the random number of generator for reproducibility.
-
-                A pre-fitted PytorchScaler instance, if None a new scaler will be created.
+
         """
         super().__init__()
-
+
+        _apply_scaling: bool = False
+        if scaler == "fit":
+            self.scaler = None # To be created
+            _apply_scaling = True
+        elif scaler == "none":
+            self.scaler = None
+        elif isinstance(scaler, PytorchScaler):
+            self.scaler = scaler # Use the provided one
+            _apply_scaling = True
+        else:
+            _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
+            raise ValueError()
 
         # --- 1. Identify features (from schema) ---
         self._feature_names = list(schema.feature_names)
@@ -310,9 +326,14 @@ class DatasetMaker(_BaseDatasetMaker):
         label_dtype = torch.float32 if kind == "regression" else torch.int64
 
         # --- 4. Scale (using the schema) ---
-
-
-
+        if _apply_scaling:
+            X_train_final, X_test_final = self._prepare_scaler(
+                X_train, y_train, X_test, label_dtype, schema
+            )
+        else:
+            _LOGGER.info("Features have not been scaled as specified.")
+            X_train_final = X_train.to_numpy()
+            X_test_final = X_test.to_numpy()
 
         # --- 5. Create Datasets ---
         self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
@@ -336,9 +357,9 @@ class DatasetMakerMulti(_BaseDatasetMaker):
                  pandas_df: pandas.DataFrame,
                  target_columns: List[str],
                  schema: FeatureSchema,
+                 scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
                  test_size: float = 0.2,
-                 random_state: int = 42
-                 scaler: Optional[PytorchScaler] = None):
+                 random_state: int = 42):
         """
         Args:
             pandas_df (pandas.DataFrame):
@@ -348,20 +369,35 @@ class DatasetMakerMulti(_BaseDatasetMaker):
                 List of target column names.
             schema (FeatureSchema):
                 The definitive schema object from data_exploration.
+            scaler ("fit" | "none" | PytorchScaler):
+                Strategy for data scaling:
+                - "fit": Fit a new PytorchScaler on continuous features.
+                - "none": Do not scale data (e.g., for TabularTransformer).
+                - PytorchScaler instance: Use a pre-fitted scaler to transform data.
             test_size (float):
                 The proportion of the dataset to allocate to the test split.
             random_state (int):
                 The seed for the random number generator for reproducibility.
-            scaler (PytorchScaler | None):
-                A pre-fitted PytorchScaler instance.
 
         ## Note:
             For multi-binary classification, the most common PyTorch loss function is nn.BCEWithLogitsLoss.
             This loss function requires the labels to be torch.float32 which is the same type required for regression (multi-regression) tasks.
         """
         super().__init__()
-
-
+
+        _apply_scaling: bool = False
+        if scaler == "fit":
+            self.scaler = None
+            _apply_scaling = True
+        elif scaler == "none":
+            self.scaler = None
+        elif isinstance(scaler, PytorchScaler):
+            self.scaler = scaler # Use the provided one
+            _apply_scaling = True
+        else:
+            _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
+            raise ValueError()
+
         # --- 1. Get features and targets from schema/args ---
         self._feature_names = list(schema.feature_names)
         self._target_names = target_columns
@@ -403,9 +439,14 @@ class DatasetMakerMulti(_BaseDatasetMaker):
         label_dtype = torch.float32
 
         # --- 4. Scale (using the schema) ---
-
-
-
+        if _apply_scaling:
+            X_train_final, X_test_final = self._prepare_scaler(
+                X_train, y_train, X_test, label_dtype, schema
+            )
+        else:
+            _LOGGER.info("Features have not been scaled as specified.")
+            X_train_final = X_train.to_numpy()
+            X_test_final = X_test.to_numpy()
 
         # --- 5. Create Datasets ---
         # _PytorchDataset now correctly handles y_train (a DataFrame)
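
Both DatasetMaker and DatasetMakerMulti now require an explicit `scaler` strategy instead of an optional pre-fitted instance. A minimal usage sketch of the new signature, assuming a prepared DataFrame `df`, a `FeatureSchema` named `schema`, a previously fitted `PytorchScaler` named `pretrained`, and target columns "t1"/"t2" (all hypothetical, not taken from the diff):

from ml_tools.ML_datasetmaster import DatasetMaker, DatasetMakerMulti

# Fit a new PytorchScaler on the continuous features declared in the schema.
maker = DatasetMaker(pandas_df=df, schema=schema, kind="regression", scaler="fit")

# Skip scaling entirely (e.g., for a TabularTransformer-style model).
maker_raw = DatasetMaker(pandas_df=df, schema=schema, kind="classification", scaler="none")

# Reuse a pre-fitted scaler; multi-target datasets follow the same pattern.
maker_multi = DatasetMakerMulti(pandas_df=df, target_columns=["t1", "t2"],
                                schema=schema, scaler=pretrained)

Anything other than "fit", "none", or a PytorchScaler instance now logs an error and raises ValueError, so call sites that relied on the old default `scaler=None` need updating.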
ml_tools/ML_evaluation.py
CHANGED
@@ -258,7 +258,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = '
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -270,7 +270,7 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
@@ -285,7 +285,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -309,10 +309,9 @@ def shap_summary_plot(model,
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Ensure data is np.ndarray
@@ -348,14 +347,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1)
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
@@ -366,7 +377,7 @@ def shap_summary_plot(model,
 
     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +400,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +409,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
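
Two behaviors added above are easy to miss: single-output SHAP arrays shaped (N, F, 1) are squeezed to (N, F) before plotting, and the explained instances are wrapped in a pandas DataFrame so the summary plots label features by name. A standalone sketch of that logic with made-up data (shapes and feature names are illustrative, not from the package):

import numpy as np
import pandas as pd

instances = np.random.rand(5, 3)                   # 5 explained instances, 3 features
feature_names = ["age", "dose", "weight"]          # hypothetical feature names
instances_df = pd.DataFrame(instances, columns=feature_names)

shap_values = np.random.rand(5, 3, 1)              # single-output SHAP array, shape (N, F, 1)
if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
    shap_values = shap_values.squeeze(-1)          # squeeze to (N, F) for the summary plots

print(shap_values.shape, list(instances_df.columns))  # (5, 3) ['age', 'dose', 'weight']

Note also that `explainer_type` now defaults to 'kernel'; callers who want the faster DeepExplainer path for PyTorch models must pass explainer_type='deep' explicitly.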
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -235,7 +235,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = '
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
     ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +249,7 @@ def multi_target_shap_summary_plot(
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +260,7 @@ def multi_target_shap_summary_plot(
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -285,10 +285,9 @@ def multi_target_shap_summary_plot(
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )
 
         # Convert all data to numpy
ml_tools/ML_trainer.py
CHANGED
@@ -9,7 +9,7 @@ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
 from ._script_info import _script_info
-from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
+from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 
@@ -408,7 +408,7 @@ class MLTrainer:
                 n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
                 target_names: Optional[List[str]] = None,
-                explainer_type: Literal['deep', 'kernel'] = '
+                explainer_type: Literal['deep', 'kernel'] = 'kernel'):
         """
         Explains model predictions using SHAP and saves all artifacts.
 
@@ -422,11 +422,11 @@ class MLTrainer:
             explain_dataset (Dataset | None): A specific dataset to explain.
                 If None, the trainer's test dataset is used.
             n_samples (int): The number of samples to use for both background and explanation.
-            feature_names (list[str] | None): Feature names.
+            feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
             explainer_type (Literal['deep', 'kernel']): The explainer to use.
-                - 'deep':
+                - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
         """
         # Internal helper to create a dataloader and get a random sample
@@ -474,10 +474,10 @@ class MLTrainer:
         # attempt to get feature names
         if feature_names is None:
             # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
-            if hasattr(target_dataset,
+            if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
                 feature_names = target_dataset.feature_names # type: ignore
             else:
-                _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a
+                _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                 raise ValueError()
 
         # move model to device
@@ -498,7 +498,7 @@ class MLTrainer:
         # try to get target names
         if target_names is None:
             target_names = []
-            if hasattr(target_dataset,
+            if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
                 target_names = target_dataset.target_names # type: ignore
             else:
                 # Infer number of targets from the model's output layer
@@ -549,7 +549,7 @@ class MLTrainer:
             yield attention_weights
 
     def explain_attention(self, save_dir: Union[str, Path],
-                          feature_names: Optional[List[str]],
+                          feature_names: Optional[List[str]] = None,
                           explain_dataset: Optional[Dataset] = None,
                           plot_n_features: int = 10):
         """
@@ -559,7 +559,7 @@ class MLTrainer:
 
         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling. If
+            feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
             plot_n_features (int): Number of top features to plot.
         """
@@ -580,6 +580,14 @@ class MLTrainer:
             _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
             return
 
+        # Get feature names
+        if feature_names is None:
+            if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                feature_names = dataset_to_use.feature_names # type: ignore
+            else:
+                _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
+                raise ValueError()
+
         explain_loader = DataLoader(
             dataset=dataset_to_use, batch_size=32, shuffle=False,
             num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
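
explain_attention() now defaults `feature_names` to None and, like the SHAP explanation method, resolves names from the dataset's feature-name attribute (DatasetKeys.FEATURE_NAMES) when they are not supplied, raising ValueError if they cannot be found. A minimal call sketch, assuming an already-fitted MLTrainer instance named `trainer` whose test dataset was built by DatasetMaker (so it carries feature names); `trainer` and the output directories are hypothetical, and the SHAP method's name is not visible in the diff, so `trainer.explain(...)` below is an assumption:

# Attention-importance plot: feature names are now pulled from the dataset.
trainer.explain_attention(save_dir="attention_report", plot_n_features=10)

# SHAP explanation: the default explainer is now the slow KernelExplainer,
# so request DeepExplainer explicitly for PyTorch models (method name assumed).
trainer.explain(save_dir="shap_report", n_samples=300, explainer_type="deep")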