dragon-ml-toolbox 10.8.0__py3-none-any.whl → 10.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/RECORD +15 -15
- ml_tools/ML_datasetmaster.py +1 -1
- ml_tools/ML_evaluation.py +12 -10
- ml_tools/ML_models.py +28 -40
- ml_tools/ML_scaler.py +1 -1
- ml_tools/ML_trainer.py +14 -6
- ml_tools/SQL.py +4 -2
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/keys.py +7 -0
- ml_tools/utilities.py +119 -20
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-10.8.0.dist-info → dragon_ml_toolbox-10.10.0.dist-info}/RECORD
CHANGED
@@ -1,36 +1,36 @@
-dragon_ml_toolbox-10.…
-dragon_ml_toolbox-10.…
+dragon_ml_toolbox-10.10.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-10.10.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
 ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
 ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
 ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
 ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
 ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
-ml_tools/ML_datasetmaster.py,sha256=…
-ml_tools/ML_evaluation.py,sha256=…
+ml_tools/ML_datasetmaster.py,sha256=vqKZhCXsvN5yeRJdOKqMPh5OhY1xe6xlNjM3WoH5lys,30821
+ml_tools/ML_evaluation.py,sha256=6FB6S-aDDpFzQdrp3flBVECzEsHhMbQknYVGhHooEFs,16207
 ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
 ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
-ml_tools/ML_models.py,sha256=…
+ml_tools/ML_models.py,sha256=8UOMg9Qn8qtecUGfgnLRedX-lCWYwEs-C5RJ2m8mZM4,27544
 ml_tools/ML_optimization.py,sha256=a2Uxe1g-y4I-gFa8ENIM8QDS-Pz3hoPRRaVXAWMbyQA,13491
-ml_tools/ML_scaler.py,sha256=…
-ml_tools/ML_trainer.py,sha256=…
+ml_tools/ML_scaler.py,sha256=h2ymq5u953Lx60Qb38Y0mAWj85x9PbnP0xYNQ3pd8-w,7535
+ml_tools/ML_trainer.py,sha256=_g48w5Ak-wQr5fGHdJqlcpnzv3gWyL1ghkOhy9VOZbo,23930
 ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
 ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
-ml_tools/SQL.py,sha256=…
+ml_tools/SQL.py,sha256=givoz6CGWRUdqnBem3VGZxzGdo3ZbX00kyHNjzI8kWE,10803
 ml_tools/VIF_factor.py,sha256=MkMh_RIdsN2XUPzKNGRiEcmB17R_MmvGV4ezpL5zD2E,10403
 ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
 ml_tools/_logger.py,sha256=wcImAiXEZKPNcwM30qBh3t7HvoPURonJY0nrgMGF0sM,4719
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/custom_logger.py,sha256=ry43hk54K6xKo8jRAgq1sFxUpOA9T0LIJ7sw0so2BW0,5880
 ml_tools/data_exploration.py,sha256=4McT2BR9muK4JVVTKUfvRyThe0m_o2vpy9RJ1f_1FeY,28692
-ml_tools/ensemble_evaluation.py,sha256=…
+ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
 ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
 ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
 ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
-ml_tools/keys.py,sha256=…
+ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
 ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
 ml_tools/path_manager.py,sha256=wLJlz3Y9_1-LB9em4B2VYDCVuTOX2eOc7D6hbbebjgM,14990
-ml_tools/utilities.py,sha256=…
-dragon_ml_toolbox-10.…
-dragon_ml_toolbox-10.…
-dragon_ml_toolbox-10.…
-dragon_ml_toolbox-10.…
+ml_tools/utilities.py,sha256=30z0x1aDLyBGzF98_tgSaxwFafYwQS-GTFzXHopBSGc,29105
+dragon_ml_toolbox-10.10.0.dist-info/METADATA,sha256=hSrcYAuoE1H0uF77-8TClwrcdlQwg0f1BGixlh_Q0Wo,6969
+dragon_ml_toolbox-10.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-10.10.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-10.10.0.dist-info/RECORD,,
ml_tools/ML_datasetmaster.py
CHANGED
@@ -200,7 +200,7 @@ class _BaseDatasetMaker(ABC):
         filepath = save_path / filename
         self.scaler.save(filepath, verbose=False)
         if verbose:
-            _LOGGER.info(f"Scaler for dataset '{self.id}' saved …
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved as '{filepath.name}'.")
 
 
 # Single target dataset
ml_tools/ML_evaluation.py
CHANGED
@@ -22,6 +22,7 @@ from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from typing import Union, Optional, List
 from ._script_info import _script_info
+from .keys import SHAPKeys
 
 
 __all__ = [
@@ -333,7 +334,8 @@ def shap_summary_plot(model,
     plt.close()
 
     # Save Summary Data to CSV
-…
+    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+    summary_path = save_dir_path / shap_summary_filename
     # Ensure the array is 1D before creating the DataFrame
     mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
 
@@ -341,9 +343,9 @@ def shap_summary_plot(model,
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
-…
-…
-    }).sort_values(…
+        SHAPKeys.FEATURE_COLUMN: feature_names,
+        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
 
     summary_df.to_csv(summary_path, index=False)
 
@@ -351,7 +353,7 @@ def shap_summary_plot(model,
     plt.ion()
 
 
-def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
     """
     Aggregates attention weights and plots global feature importance.
 
@@ -362,6 +364,7 @@ def plot_attention_importance(…
         weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
         feature_names (List[str] | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save the plot and summary CSV.
+        top_n (int): The number of top features to display in the plot.
     """
     if not weights:
         _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
@@ -390,11 +393,10 @@ def plot_attention_importance(…
     summary_df.to_csv(summary_path, index=False)
     _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
 
-    # --- Step 3: Create and save the plot ---
-…
+    # --- Step 3: Create and save the plot for top N features ---
+    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-…
-    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+    plt.figure(figsize=(10, 8), dpi=100)
 
     # Create horizontal bar plot with error bars
     plt.barh(
@@ -408,7 +410,7 @@ def plot_attention_importance(…
         color='cornflowerblue'
     )
 
-    plt.title('…
+    plt.title('Top Features by Attention')
     plt.xlabel('Average Attention Weight')
     plt.ylabel('Feature')
     plt.grid(axis='x', linestyle='--', alpha=0.6)
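For orientation, a minimal sketch of calling the updated plot_attention_importance with its new top_n parameter; the tensors, feature names, and output directory below are illustrative, not taken from the package:

import torch
from ml_tools.ML_evaluation import plot_attention_importance

# Three illustrative batches of attention weights, shape (batch_size, n_features).
batch_weights = [torch.rand(32, 20) for _ in range(3)]
names = [f"feature_{i}" for i in range(20)]

# The summary CSV still covers all features; the bar plot now shows only the top 10.
plot_attention_importance(weights=batch_weights, feature_names=names,
                          save_dir="attention_report", top_n=10)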
ml_tools/ML_models.py
CHANGED
@@ -43,7 +43,7 @@ class _ArchitectureHandlerMixin:
             json.dump(config, f, indent=4)
 
         if verbose:
-            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved …
+            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")
 
     @classmethod
     def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
@@ -147,6 +147,30 @@ class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
         return f"{name}(arch: {arch_str})"
 
 
+class _BaseAttention(_BaseMLP):
+    """
+    Abstract base class for MLP models that incorporate an attention mechanism
+    before the main MLP layers.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # By default, models inheriting this do not have the flag.
+        self.has_interpretable_attention = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Defines the standard forward pass."""
+        logits, _attention_weights = self.forward_attention(x)
+        return logits
+
+    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Returns logits and attention weights."""
+        # This logic is now shared and defined in one place
+        x, attention_weights = self.attention(x)
+        x = self.mlp(x)
+        logits = self.output_layer(x)
+        return logits, attention_weights
+
+
 class MultilayerPerceptron(_BaseMLP):
     """
     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
@@ -184,7 +208,7 @@ class MultilayerPerceptron(_BaseMLP):
         return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)
 
 
-class AttentionMLP(…
+class AttentionMLP(_BaseAttention):
     """
     A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.
 
@@ -205,25 +229,7 @@ class AttentionMLP(_BaseMLP):
         super().__init__(in_features, out_targets, hidden_layers, drop_out)
         # Attention
         self.attention = _AttentionLayer(in_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Defines the standard forward pass.
-        """
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP
-        logits = self.mlp(x)
-
-        return logits, attention_weights
+        self.has_interpretable_attention = True
 
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
@@ -238,7 +244,7 @@ class AttentionMLP(_BaseMLP):
         return self._repr_helper(name="AttentionMLP", mlp_layers=arch)
 
 
-class MultiHeadAttentionMLP(…
+class MultiHeadAttentionMLP(_BaseAttention):
     """
     An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
     the input features.
@@ -267,24 +273,6 @@ class MultiHeadAttentionMLP(_BaseMLP):
             dropout=attention_dropout
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Defines the standard forward pass of the model."""
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights.
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP and prediction head
-        x = self.mlp(x)
-        logits = self.output_layer(x)
-
-        return logits, attention_weights
-
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
         config = super().get_architecture_config()
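The net effect of the refactor: AttentionMLP and MultiHeadAttentionMLP now share one forward/forward_attention implementation in _BaseAttention and differ only in the has_interpretable_attention flag. A short sketch, assuming AttentionMLP accepts the constructor arguments shown in the diff (in_features, out_targets, hidden_layers, drop_out) unchanged:

import torch
from ml_tools.ML_models import AttentionMLP

# Hypothetical architecture values, for illustration only.
model = AttentionMLP(in_features=20, out_targets=1, hidden_layers=[64, 32], drop_out=0.2)
print(model.has_interpretable_attention)  # True, set in AttentionMLP.__init__

x = torch.rand(8, 20)
logits, attn_weights = model.forward_attention(x)  # shared implementation from _BaseAttention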
ml_tools/ML_scaler.py
CHANGED
@@ -164,7 +164,7 @@ class PytorchScaler:
         }
         torch.save(state, path_obj)
         if verbose:
-            _LOGGER.info(f"PytorchScaler state saved …
+            _LOGGER.info(f"PytorchScaler state saved as '{path_obj.name}'.")
 
     @staticmethod
     def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
ml_tools/ML_trainer.py
CHANGED
@@ -472,23 +472,30 @@ class MLTrainer:
 
         yield attention_weights
 
-    def explain_attention(self, save_dir: Union[str, Path], …
+    def explain_attention(self, save_dir: Union[str, Path],
+                          feature_names: Optional[List[str]],
+                          explain_dataset: Optional[Dataset] = None,
+                          plot_n_features: int = 10):
         """
         Generates and saves a feature importance plot based on attention weights.
 
-        This method only works for models with …
+        This method only works for models with 'has_interpretable_attention'.
 
         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling.
+            feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
+            plot_n_features (int): Number of top features to plot.
         """
 
         print("\n--- Attention Analysis ---")
 
         # --- Step 1: Check if the model supports this explanation ---
-        if not …
-        _LOGGER.…
+        if not getattr(self.model, 'has_interpretable_attention', False):
+            _LOGGER.warning(
+                "Model is not flagged for interpretable attention analysis. "
+                "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
+            )
             return
 
         # --- Step 2: Set up the dataloader ---
@@ -514,7 +521,8 @@ class MLTrainer:
             plot_attention_importance(
                 weights=all_weights,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                top_n=plot_n_features
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
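A hedged usage sketch of the extended explain_attention; `trainer` stands for an already-constructed MLTrainer, and the path below is illustrative:

# Assumes trainer.model has has_interpretable_attention == True (e.g. AttentionMLP);
# for MultiHeadAttentionMLP the method now logs a warning and returns early.
trainer.explain_attention(
    save_dir="results/attention",
    feature_names=None,        # generic names will be used
    explain_dataset=None,      # falls back to the trainer's test dataset
    plot_n_features=10,        # forwarded to plot_attention_importance as top_n
)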
ml_tools/SQL.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, Any, Optional, List, Literal
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 
 
 __all__ = [
@@ -94,11 +94,13 @@ class DatabaseManager:
         if not self.cursor:
             _LOGGER.error("Database connection is not open.")
             raise sqlite3.Error()
+
+        sanitized_table_name = sanitize_filename(table_name)
 
         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
         exists_clause = "IF NOT EXISTS" if if_not_exists else ""
 
-        query = f"CREATE TABLE {exists_clause} {…
+        query = f"CREATE TABLE {exists_clause} {sanitized_table_name} ({columns_def})"
 
         _LOGGER.info(f"➡️ Executing: {query}")
         self.cursor.execute(query)
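A sketch of what the sanitization change guards against. The method name create_table is an assumption inferred from the body above (only table_name, schema, and if_not_exists are visible in the diff), and `db` stands for an open DatabaseManager:

# A raw name with unsafe characters is normalized by sanitize_filename()
# before being interpolated into the CREATE TABLE statement.
db.create_table(
    table_name="lab results (2024)",
    schema={"id": "INTEGER PRIMARY KEY", "value": "REAL"},
    if_not_exists=True,
)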
ml_tools/ensemble_evaluation.py
CHANGED
@@ -25,6 +25,7 @@ from typing import Union, Optional, Literal
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import SHAPKeys
 
 
 __all__ = [
@@ -472,7 +473,7 @@ def get_shap_values(
         save_dir: Directory to save visualizations.
     """
     sanitized_target_name = sanitize_filename(target_name)
-    global_save_path = make_fullpath(save_dir, make=True)
+    global_save_path = make_fullpath(save_dir, make=True, enforce="directory")
 
     def _apply_plot_style():
         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
@@ -539,6 +540,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name} (Class {class_name})"
             )
+
+            # Save the summary data for the current class
+            summary_save_path = global_save_path / f"SHAP_{sanitized_target_name}_{class_name}.csv"
+            _save_summary_csv(
+                shap_values_for_summary=class_shap,
+                feature_names=feature_names,
+                save_path=summary_save_path
+            )
+
     else:
         values = shap_values[1] if isinstance(shap_values, list) else shap_values
         for plot_type in ["bar", "dot"]:
@@ -549,6 +559,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data for the positive class
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
 
     def _plot_for_regression(shap_values):
         for plot_type in ["bar", "dot"]:
@@ -559,6 +578,34 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data to a CSV file
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=shap_values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
+
+    def _save_summary_csv(shap_values_for_summary: np.ndarray, feature_names: list[str], save_path: Path):
+        """Calculates and saves the SHAP summary data to a CSV file."""
+        mean_abs_shap = np.abs(shap_values_for_summary).mean(axis=0)
+
+        # Create default feature names if none are provided
+        current_feature_names = feature_names
+        if current_feature_names is None:
+            current_feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: current_feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(save_path, index=False)
+        # print(f"📝 SHAP summary data saved as '{save_path.name}'")
+
+
 #START_O
 
 explainer = shap.TreeExplainer(model)
ml_tools/keys.py
CHANGED
@@ -61,6 +61,13 @@ class DatasetKeys:
     SCALER_PREFIX = "scaler_"
 
 
+class SHAPKeys:
+    """Keys for SHAP functions"""
+    FEATURE_COLUMN = "feature"
+    SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+    SAVENAME = "shap_summary"
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
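Because ML_evaluation, ensemble_evaluation, and utilities all write and read the summary through these constants, a downstream script can reload a summary CSV without hard-coding column names. A small sketch with an illustrative path:

import pandas as pd
from ml_tools.keys import SHAPKeys

# SHAPKeys.SAVENAME + ".csv" resolves to "shap_summary.csv".
summary = pd.read_csv(f"model_dir/{SHAPKeys.SAVENAME}.csv")
top10 = summary.sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False).head(10)
print(top10[SHAPKeys.FEATURE_COLUMN].tolist())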
ml_tools/utilities.py
CHANGED
@@ -9,7 +9,7 @@ from joblib.externals.loky.process_executor import TerminatedWorkerError
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys
 
 
 # Keep track of available tools
@@ -26,7 +26,8 @@ __all__ = [
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
     "train_dataset_yielder",
-    "find_model_artifacts"
+    "find_model_artifacts",
+    "select_features_by_shap"
 ]
 
 
@@ -34,6 +35,7 @@ __all__ = [
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -44,7 +46,8 @@
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
-…
+    use_columns: Optional[list[str]] = None,
+    kind: Literal["polars"] = "polars",
     all_strings: bool = False,
     verbose: bool = True
 ) -> Tuple[pl.DataFrame, str]:
@@ -52,6 +55,7 @@
 
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas", "polars"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -60,11 +64,13 @@
     Load a CSV file into a DataFrame and extract its base name.
 
     Can load data as either a pandas or a polars DataFrame. Allows for loading all
-    columns as string types to prevent type inference errors.
+    columns or a subset of columns as string types to prevent type inference errors.
 
     Args:
         df_path (str, Path):
             The path to the CSV file.
+        use_columns (list[str] | None):
+            If provided, only these columns will be loaded from the CSV.
         kind ("pandas", "polars"):
             The type of DataFrame to load. Defaults to "pandas".
         all_strings (bool):
@@ -78,28 +84,43 @@
 
     Raises:
         FileNotFoundError: If the file does not exist at the given path.
-        ValueError: If the DataFrame is empty …
+        ValueError: If the DataFrame is empty, an invalid 'kind' is provided, or a column in 'use_columns' is not found in the file.
     """
     path = make_fullpath(df_path)
 
     df_name = path.stem
 
-…
-    if …
-…
-…
-…
-…
-…
-…
-…
+    try:
+        if kind == "pandas":
+            pd_kwargs: dict[str,Any]
+            pd_kwargs = {'encoding': 'utf-8'}
+            if use_columns:
+                pd_kwargs['usecols'] = use_columns
+            if all_strings:
+                pd_kwargs['dtype'] = str
+
+            df = pd.read_csv(path, **pd_kwargs)
+
+        elif kind == "polars":
+            pl_kwargs: dict[str,Any]
+            pl_kwargs = {}
+            if use_columns:
+                pl_kwargs['columns'] = use_columns
+
+            if all_strings:
+                pl_kwargs['infer_schema'] = False
+            else:
+                pl_kwargs['infer_schema_length'] = 1000
+
+            df = pl.read_csv(path, **pl_kwargs)
+
         else:
-…
-…
+            _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+            raise ValueError()
 
-…
-        _LOGGER.error(f"…
-        raise …
+    except (ValueError, pl.exceptions.ColumnNotFoundError) as e:
+        _LOGGER.error(f"Failed to load '{df_name}'. A specified column may not exist in the file.")
+        raise e
 
     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
@@ -111,7 +132,6 @@
 
     return df, df_name  # type: ignore
 
-
 def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
     """
     Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
@@ -683,5 +703,84 @@ def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, v…
     return all_artifacts
 
 
+def select_features_by_shap(
+    root_directory: Union[str, Path],
+    shap_threshold: float = 1.0,
+    verbose: bool = True) -> list[str]:
+    """
+    Scans subdirectories to find SHAP summary CSVs, then extracts feature
+    names whose mean absolute SHAP value meets a specified threshold.
+
+    This function is useful for automated feature selection based on feature
+    importance scores aggregated from multiple models.
+
+    Args:
+        root_directory (Union[str, Path]):
+            The path to the root directory that contains model subdirectories.
+        shap_threshold (float):
+            The minimum mean absolute SHAP value for a feature to be included
+            in the final list.
+
+    Returns:
+        list[str]:
+            A single, sorted list of unique feature names that meet the
+            threshold criteria across all found files.
+    """
+    if verbose:
+        _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
+    root_path = make_fullpath(root_directory, enforce="directory")
+
+    # --- Step 2: Directory and File Discovery ---
+    subdirectories = list_subdirectories(root_dir=root_path, verbose=False)
+
+    shap_filename = SHAPKeys.SAVENAME + ".csv"
+
+    valid_csv_paths = []
+    for dir_name, dir_path in subdirectories.items():
+        expected_path = dir_path / shap_filename
+        if expected_path.is_file():
+            valid_csv_paths.append(expected_path)
+        else:
+            _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
+
+    if not valid_csv_paths:
+        _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
+        return []
+
+    if verbose:
+        _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
+
+    # --- Step 3: Data Processing and Feature Extraction ---
+    master_feature_set = set()
+    for csv_path in valid_csv_paths:
+        try:
+            df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
+
+            # Validate required columns
+            required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
+            if not required_cols.issubset(df.columns):
+                _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
+                continue
+
+            # Filter by threshold and extract features
+            filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
+            features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
+            master_feature_set.update(features)
+
+        except (ValueError, pd.errors.EmptyDataError):
+            _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
+            continue
+        except Exception as e:
+            _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
+            continue
+
+    # --- Step 4: Finalize and Return ---
+    final_features = sorted(list(master_feature_set))
+    if verbose:
+        _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
+
+    return final_features
+
+
 def info():
     _script_info(__all__)
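Taken together, the two utilities.py additions compose into a simple feature-selection loop. A hedged sketch with illustrative paths, assuming each model subdirectory under results/ already contains a shap_summary.csv and that the raw CSV has a hypothetical "target" column:

from ml_tools.utilities import load_dataframe, select_features_by_shap

# Union of features whose mean |SHAP| is >= 1.0 in any model's summary.
selected = select_features_by_shap("results/", shap_threshold=1.0)

# Reload only those columns (plus the target) from the raw data.
df, df_name = load_dataframe("data/train.csv",
                             use_columns=selected + ["target"],
                             kind="pandas")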
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|