dragon-ml-toolbox 10.8.0__py3-none-any.whl → 10.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dragon-ml-toolbox has been flagged as potentially problematic.

dragon_ml_toolbox-10.10.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 10.8.0
+Version: 10.10.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
dragon_ml_toolbox-10.10.0.dist-info/RECORD CHANGED
@@ -1,36 +1,36 @@
-dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-10.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
+dragon_ml_toolbox-10.10.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-10.10.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
 ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
 ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
 ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
 ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
 ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
-ml_tools/ML_datasetmaster.py,sha256=BMmdCVAZ-HSnnSPLzKla2TdZKvHkHj4t9A0V1Ba3i-I,30821
-ml_tools/ML_evaluation.py,sha256=28JJ2M71p4pxniwav2Hv3b1a5dsvaoIYNLm-UJQuXvY,16002
+ml_tools/ML_datasetmaster.py,sha256=vqKZhCXsvN5yeRJdOKqMPh5OhY1xe6xlNjM3WoH5lys,30821
+ml_tools/ML_evaluation.py,sha256=6FB6S-aDDpFzQdrp3flBVECzEsHhMbQknYVGhHooEFs,16207
 ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
 ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
-ml_tools/ML_models.py,sha256=FliuqGhxP7AWHCweTLlfssXFOjwvFhIYJsgj_w_-EI4,27901
+ml_tools/ML_models.py,sha256=8UOMg9Qn8qtecUGfgnLRedX-lCWYwEs-C5RJ2m8mZM4,27544
 ml_tools/ML_optimization.py,sha256=a2Uxe1g-y4I-gFa8ENIM8QDS-Pz3hoPRRaVXAWMbyQA,13491
-ml_tools/ML_scaler.py,sha256=IrZsAr1xjvuLi8s5IKR-qbk2mS_awl3mn_xoXg5TJyA,7535
-ml_tools/ML_trainer.py,sha256=xw1zMgYpdqwsTt604xe3GTQNvpg6z6Ze-avmitGBFeU,23539
+ml_tools/ML_scaler.py,sha256=h2ymq5u953Lx60Qb38Y0mAWj85x9PbnP0xYNQ3pd8-w,7535
+ml_tools/ML_trainer.py,sha256=_g48w5Ak-wQr5fGHdJqlcpnzv3gWyL1ghkOhy9VOZbo,23930
 ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
 ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
-ml_tools/SQL.py,sha256=WDgdZUYuLBUpv-4Am9XjVY_Aq_jxBWdLrbcgAIEwefI,10704
+ml_tools/SQL.py,sha256=givoz6CGWRUdqnBem3VGZxzGdo3ZbX00kyHNjzI8kWE,10803
 ml_tools/VIF_factor.py,sha256=MkMh_RIdsN2XUPzKNGRiEcmB17R_MmvGV4ezpL5zD2E,10403
 ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
 ml_tools/_logger.py,sha256=wcImAiXEZKPNcwM30qBh3t7HvoPURonJY0nrgMGF0sM,4719
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/custom_logger.py,sha256=ry43hk54K6xKo8jRAgq1sFxUpOA9T0LIJ7sw0so2BW0,5880
 ml_tools/data_exploration.py,sha256=4McT2BR9muK4JVVTKUfvRyThe0m_o2vpy9RJ1f_1FeY,28692
-ml_tools/ensemble_evaluation.py,sha256=xMEMfXJ5MjTkTfr1LkFOeD7iUtnVDCW3S9lm3zT-6tY,24778
+ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
 ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
 ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
 ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
-ml_tools/keys.py,sha256=sZANLHvp_93pPigviMOz7AhampGlpokcop_llzsjWBw,1689
+ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
 ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
 ml_tools/path_manager.py,sha256=wLJlz3Y9_1-LB9em4B2VYDCVuTOX2eOc7D6hbbebjgM,14990
-ml_tools/utilities.py,sha256=xddY0uASKQWSuUsYJEcfDUkeC-ccbYlkycqHKdkPnhk,25105
-dragon_ml_toolbox-10.8.0.dist-info/METADATA,sha256=Ly11G7vOgCFbYwEYXQXa8RBgvWof9thiBxVjlk9DZu4,6968
-dragon_ml_toolbox-10.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-10.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-10.8.0.dist-info/RECORD,,
+ml_tools/utilities.py,sha256=30z0x1aDLyBGzF98_tgSaxwFafYwQS-GTFzXHopBSGc,29105
+dragon_ml_toolbox-10.10.0.dist-info/METADATA,sha256=hSrcYAuoE1H0uF77-8TClwrcdlQwg0f1BGixlh_Q0Wo,6969
+dragon_ml_toolbox-10.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-10.10.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-10.10.0.dist-info/RECORD,,
ml_tools/ML_datasetmaster.py CHANGED
@@ -200,7 +200,7 @@ class _BaseDatasetMaker(ABC):
         filepath = save_path / filename
         self.scaler.save(filepath, verbose=False)
         if verbose:
-            _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved as '{filepath.name}'.")


 # Single target dataset
ml_tools/ML_evaluation.py CHANGED
@@ -22,6 +22,7 @@ from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from typing import Union, Optional, List
 from ._script_info import _script_info
+from .keys import SHAPKeys


 __all__ = [
@@ -333,7 +334,8 @@ def shap_summary_plot(model,
     plt.close()

     # Save Summary Data to CSV
-    summary_path = save_dir_path / "shap_summary.csv"
+    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+    summary_path = save_dir_path / shap_summary_filename
     # Ensure the array is 1D before creating the DataFrame
     mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()

@@ -341,9 +343,9 @@ def shap_summary_plot(model,
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]

     summary_df = pd.DataFrame({
-        'feature': feature_names,
-        'mean_abs_shap_value': mean_abs_shap
-    }).sort_values('mean_abs_shap_value', ascending=False)
+        SHAPKeys.FEATURE_COLUMN: feature_names,
+        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

     summary_df.to_csv(summary_path, index=False)

@@ -351,7 +353,7 @@ def shap_summary_plot(model,
     plt.ion()


-def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
     """
     Aggregates attention weights and plots global feature importance.

@@ -362,6 +364,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
         feature_names (List[str] | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save the plot and summary CSV.
+        top_n (int): The number of top features to display in the plot.
     """
     if not weights:
         _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
@@ -390,11 +393,10 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     summary_df.to_csv(summary_path, index=False)
     _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

-    # --- Step 3: Create and save the plot ---
-    plt.figure(figsize=(10, 8), dpi=100)
+    # --- Step 3: Create and save the plot for top N features ---
+    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

-    # Sort for plotting
-    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+    plt.figure(figsize=(10, 8), dpi=100)

     # Create horizontal bar plot with error bars
     plt.barh(
@@ -408,7 +410,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         color='cornflowerblue'
     )

-    plt.title('Global Feature Importance')
+    plt.title('Top Features by Attention')
     plt.xlabel('Average Attention Weight')
     plt.ylabel('Feature')
     plt.grid(axis='x', linestyle='--', alpha=0.6)
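
The new top_n parameter caps how many features appear in the attention plot; the summary CSV still contains every feature. A minimal sketch of the updated call, assuming each weight tensor is shaped (batch_size, n_features) and using random tensors as stand-ins for real per-batch attention weights:

    import torch
    from ml_tools.ML_evaluation import plot_attention_importance

    # Five fake batches of attention weights over 20 features.
    weights = [torch.rand(32, 20) for _ in range(5)]
    names = [f"sensor_{i}" for i in range(20)]

    # Writes the full summary CSV, but the bar plot shows only the top 8 features.
    plot_attention_importance(weights=weights, feature_names=names,
                              save_dir="reports/attention", top_n=8)
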
ml_tools/ML_models.py CHANGED
@@ -43,7 +43,7 @@ class _ArchitectureHandlerMixin:
             json.dump(config, f, indent=4)

         if verbose:
-            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved to '{path_dir.name}'")
+            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")

     @classmethod
     def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
@@ -147,6 +147,30 @@ class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
         return f"{name}(arch: {arch_str})"


+class _BaseAttention(_BaseMLP):
+    """
+    Abstract base class for MLP models that incorporate an attention mechanism
+    before the main MLP layers.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # By default, models inheriting this do not have the flag.
+        self.has_interpretable_attention = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Defines the standard forward pass."""
+        logits, _attention_weights = self.forward_attention(x)
+        return logits
+
+    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Returns logits and attention weights."""
+        # This logic is now shared and defined in one place
+        x, attention_weights = self.attention(x)
+        x = self.mlp(x)
+        logits = self.output_layer(x)
+        return logits, attention_weights
+
+
 class MultilayerPerceptron(_BaseMLP):
     """
     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
@@ -184,7 +208,7 @@ class MultilayerPerceptron(_BaseMLP):
         return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)


-class AttentionMLP(_BaseMLP):
+class AttentionMLP(_BaseAttention):
     """
     A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.

@@ -205,25 +229,7 @@ class AttentionMLP(_BaseMLP):
         super().__init__(in_features, out_targets, hidden_layers, drop_out)
         # Attention
         self.attention = _AttentionLayer(in_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Defines the standard forward pass.
-        """
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP
-        logits = self.mlp(x)
-
-        return logits, attention_weights
+        self.has_interpretable_attention = True

     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
@@ -238,7 +244,7 @@ class AttentionMLP(_BaseMLP):
         return self._repr_helper(name="AttentionMLP", mlp_layers=arch)


-class MultiHeadAttentionMLP(_BaseMLP):
+class MultiHeadAttentionMLP(_BaseAttention):
     """
     An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
     the input features.
@@ -267,24 +273,6 @@ class MultiHeadAttentionMLP(_BaseMLP):
             dropout=attention_dropout
         )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Defines the standard forward pass of the model."""
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights.
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP and prediction head
-        x = self.mlp(x)
-        logits = self.output_layer(x)
-
-        return logits, attention_weights
-
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
         config = super().get_architecture_config()
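
The duplicated forward/forward_attention pair from AttentionMLP and MultiHeadAttentionMLP now lives once in _BaseAttention, and has_interpretable_attention marks which subclasses expose weights that map cleanly onto input features. A sketch of the shared API; the constructor values are illustrative and hidden_layers is assumed to take a list of layer widths:

    import torch
    from ml_tools.ML_models import AttentionMLP

    model = AttentionMLP(in_features=20, out_targets=1,
                         hidden_layers=[64, 32], drop_out=0.2)
    x = torch.rand(8, 20)

    logits = model(x)                          # standard forward pass
    logits, attn = model.forward_attention(x)  # both inherited from _BaseAttention

    # True for AttentionMLP; MultiHeadAttentionMLP leaves the flag False.
    print(model.has_interpretable_attention)
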
ml_tools/ML_scaler.py CHANGED
@@ -164,7 +164,7 @@ class PytorchScaler:
         }
         torch.save(state, path_obj)
         if verbose:
-            _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
+            _LOGGER.info(f"PytorchScaler state saved as '{path_obj.name}'.")

     @staticmethod
     def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
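
Only the log wording changes here, but the save/load round trip it belongs to is worth seeing. A sketch, assuming `scaler` is an already-fitted PytorchScaler instance:

    from ml_tools.ML_scaler import PytorchScaler

    scaler.save("artifacts/scaler.pth", verbose=True)   # logs "... saved as 'scaler.pth'."
    restored = PytorchScaler.load("artifacts/scaler.pth")
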
ml_tools/ML_trainer.py CHANGED
@@ -472,23 +472,30 @@ class MLTrainer:

            yield attention_weights

-    def explain_attention(self, save_dir: Union[str, Path], feature_names: Optional[List[str]], explain_dataset: Optional[Dataset] = None):
+    def explain_attention(self, save_dir: Union[str, Path],
+                          feature_names: Optional[List[str]],
+                          explain_dataset: Optional[Dataset] = None,
+                          plot_n_features: int = 10):
         """
         Generates and saves a feature importance plot based on attention weights.

-        This method only works for models with a `forward_attention` method.
+        This method only works for models flagged with 'has_interpretable_attention'.

         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling.
+            feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
+            plot_n_features (int): Number of top features to plot.
         """

         print("\n--- Attention Analysis ---")

         # --- Step 1: Check if the model supports this explanation ---
-        if not hasattr(self.model, 'forward_attention'):
-            _LOGGER.error("Model does not have a `forward_attention` method. Skipping attention explanation.")
+        if not getattr(self.model, 'has_interpretable_attention', False):
+            _LOGGER.warning(
+                "Model is not flagged for interpretable attention analysis. "
+                "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
+            )
             return

         # --- Step 2: Set up the dataloader ---
@@ -514,7 +521,8 @@ class MLTrainer:
             plot_attention_importance(
                 weights=all_weights,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                top_n=plot_n_features
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
ml_tools/SQL.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, Any, Optional, List, Literal
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename


 __all__ = [
@@ -94,11 +94,13 @@ class DatabaseManager:
         if not self.cursor:
             _LOGGER.error("Database connection is not open.")
             raise sqlite3.Error()
+
+        sanitized_table_name = sanitize_filename(table_name)

         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
         exists_clause = "IF NOT EXISTS" if if_not_exists else ""

-        query = f"CREATE TABLE {exists_clause} {table_name} ({columns_def})"
+        query = f"CREATE TABLE {exists_clause} {sanitized_table_name} ({columns_def})"

         _LOGGER.info(f"➡️ Executing: {query}")
         self.cursor.execute(query)
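
Table names are now passed through sanitize_filename before being interpolated into the CREATE TABLE statement, so stray characters cannot corrupt the query. A hedged sketch; the enclosing method's name is not shown in this diff (create_table is assumed here), and `db` is assumed to be an open DatabaseManager:

    schema = {"id": "INTEGER PRIMARY KEY", "name": "TEXT", "score": "REAL"}

    # "my results!" is sanitized before it reaches the SQL string.
    db.create_table("my results!", schema, if_not_exists=True)
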
ml_tools/ensemble_evaluation.py CHANGED
@@ -25,6 +25,7 @@ from typing import Union, Optional, Literal
 from .path_manager import sanitize_filename, make_fullpath
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from .keys import SHAPKeys


 __all__ = [
@@ -472,7 +473,7 @@ def get_shap_values(
         save_dir: Directory to save visualizations.
     """
     sanitized_target_name = sanitize_filename(target_name)
-    global_save_path = make_fullpath(save_dir, make=True)
+    global_save_path = make_fullpath(save_dir, make=True, enforce="directory")

     def _apply_plot_style():
         styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
@@ -539,6 +540,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name} (Class {class_name})"
             )
+
+            # Save the summary data for the current class
+            summary_save_path = global_save_path / f"SHAP_{sanitized_target_name}_{class_name}.csv"
+            _save_summary_csv(
+                shap_values_for_summary=class_shap,
+                feature_names=feature_names,
+                save_path=summary_save_path
+            )
+
     else:
         values = shap_values[1] if isinstance(shap_values, list) else shap_values
         for plot_type in ["bar", "dot"]:
@@ -549,6 +559,15 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data for the positive class
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )

     def _plot_for_regression(shap_values):
         for plot_type in ["bar", "dot"]:
@@ -559,6 +578,34 @@ def get_shap_values(
                 plot_type=plot_type,
                 title=f"{model_name} - {target_name}"
             )
+
+        # Save the summary data to a CSV file
+        shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+        summary_save_path = global_save_path / shap_summary_filename
+        _save_summary_csv(
+            shap_values_for_summary=shap_values,
+            feature_names=feature_names,
+            save_path=summary_save_path
+        )
+
+    def _save_summary_csv(shap_values_for_summary: np.ndarray, feature_names: list[str], save_path: Path):
+        """Calculates and saves the SHAP summary data to a CSV file."""
+        mean_abs_shap = np.abs(shap_values_for_summary).mean(axis=0)
+
+        # Create default feature names if none are provided
+        current_feature_names = feature_names
+        if current_feature_names is None:
+            current_feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: current_feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(save_path, index=False)
+        # print(f"📝 SHAP summary data saved as '{save_path.name}'")
+
+
     #START_O

     explainer = shap.TreeExplainer(model)
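
With these additions, get_shap_values persists the numbers behind its plots: multiclass targets get one CSV per class, named SHAP_<target>_<class>.csv, while binary and regression targets get the shared shap_summary.csv. An illustrative read-back, assuming a target called "species" with classes 0 through 2:

    import pandas as pd
    from ml_tools.keys import SHAPKeys

    for class_name in [0, 1, 2]:
        df = pd.read_csv(f"shap_reports/SHAP_species_{class_name}.csv")
        print(class_name, df[SHAPKeys.FEATURE_COLUMN].head(3).tolist())
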
ml_tools/keys.py CHANGED
@@ -61,6 +61,13 @@ class DatasetKeys:
     SCALER_PREFIX = "scaler_"


+class SHAPKeys:
+    """Keys for SHAP functions"""
+    FEATURE_COLUMN = "feature"
+    SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+    SAVENAME = "shap_summary"
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
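
These constants pin down the CSV schema shared by the writers (shap_summary_plot, get_shap_values) and the reader (select_features_by_shap below), so a renamed column can only break things in one place:

    from ml_tools.keys import SHAPKeys

    print(SHAPKeys.SAVENAME + ".csv")      # shap_summary.csv
    print(SHAPKeys.FEATURE_COLUMN)         # feature
    print(SHAPKeys.SHAP_VALUE_COLUMN)      # mean_abs_shap_value
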
ml_tools/utilities.py CHANGED
@@ -9,7 +9,7 @@ from joblib.externals.loky.process_executor import TerminatedWorkerError
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths, list_files_by_extension, list_subdirectories
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys


 # Keep track of available tools
@@ -26,7 +26,8 @@ __all__ = [
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
     "train_dataset_yielder",
-    "find_model_artifacts"
+    "find_model_artifacts",
+    "select_features_by_shap"
 ]


@@ -34,6 +35,7 @@ __all__ = [
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -44,7 +46,8 @@
 @overload
 def load_dataframe(
     df_path: Union[str, Path],
-    kind: Literal["polars"],
+    use_columns: Optional[list[str]] = None,
+    kind: Literal["polars"] = "polars",
     all_strings: bool = False,
     verbose: bool = True
 ) -> Tuple[pl.DataFrame, str]:
@@ -52,6 +55,7 @@

 def load_dataframe(
     df_path: Union[str, Path],
+    use_columns: Optional[list[str]] = None,
     kind: Literal["pandas", "polars"] = "pandas",
     all_strings: bool = False,
     verbose: bool = True
@@ -60,11 +64,13 @@
     Load a CSV file into a DataFrame and extract its base name.

     Can load data as either a pandas or a polars DataFrame. Allows for loading all
-    columns as string types to prevent type inference errors.
+    columns or a subset of columns as string types to prevent type inference errors.

     Args:
         df_path (str, Path):
             The path to the CSV file.
+        use_columns (list[str] | None):
+            If provided, only these columns will be loaded from the CSV.
         kind ("pandas", "polars"):
             The type of DataFrame to load. Defaults to "pandas".
         all_strings (bool):
@@ -78,28 +84,43 @@

     Raises:
         FileNotFoundError: If the file does not exist at the given path.
-        ValueError: If the DataFrame is empty or an invalid 'kind' is provided.
+        ValueError: If the DataFrame is empty, an invalid 'kind' is provided, or a column in 'use_columns' is not found in the file.
     """
     path = make_fullpath(df_path)

     df_name = path.stem

-    if kind == "pandas":
-        if all_strings:
-            df = pd.read_csv(path, encoding='utf-8', dtype=str)
-        else:
-            df = pd.read_csv(path, encoding='utf-8')
-
-    elif kind == "polars":
-        if all_strings:
-            df = pl.read_csv(path, infer_schema=False)
+    try:
+        if kind == "pandas":
+            pd_kwargs: dict[str,Any]
+            pd_kwargs = {'encoding': 'utf-8'}
+            if use_columns:
+                pd_kwargs['usecols'] = use_columns
+            if all_strings:
+                pd_kwargs['dtype'] = str
+
+            df = pd.read_csv(path, **pd_kwargs)
+
+        elif kind == "polars":
+            pl_kwargs: dict[str,Any]
+            pl_kwargs = {}
+            if use_columns:
+                pl_kwargs['columns'] = use_columns
+
+            if all_strings:
+                pl_kwargs['infer_schema'] = False
+            else:
+                pl_kwargs['infer_schema_length'] = 1000
+
+            df = pl.read_csv(path, **pl_kwargs)
+
         else:
-            # Default behavior: infer the schema.
-            df = pl.read_csv(path, infer_schema_length=1000)
+            _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+            raise ValueError()

-    else:
-        _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
-        raise ValueError()
+    except (ValueError, pl.exceptions.ColumnNotFoundError) as e:
+        _LOGGER.error(f"Failed to load '{df_name}'. A specified column may not exist in the file.")
+        raise e

     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
@@ -111,7 +132,6 @@

     return df, df_name # type: ignore

-
 def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
     """
     Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
@@ -683,5 +703,84 @@ def find_model_artifacts(target_directory: Union[str,Path], load_scaler: bool, v
     return all_artifacts


+def select_features_by_shap(
+    root_directory: Union[str, Path],
+    shap_threshold: float = 1.0,
+    verbose: bool = True) -> list[str]:
+    """
+    Scans subdirectories to find SHAP summary CSVs, then extracts feature
+    names whose mean absolute SHAP value meets a specified threshold.
+
+    This function is useful for automated feature selection based on feature
+    importance scores aggregated from multiple models.
+
+    Args:
+        root_directory (Union[str, Path]):
+            The path to the root directory that contains model subdirectories.
+        shap_threshold (float):
+            The minimum mean absolute SHAP value for a feature to be included
+            in the final list.
+
+    Returns:
+        list[str]:
+            A single, sorted list of unique feature names that meet the
+            threshold criteria across all found files.
+    """
+    if verbose:
+        _LOGGER.info(f"Starting feature selection with SHAP threshold >= {shap_threshold}")
+    root_path = make_fullpath(root_directory, enforce="directory")
+
+    # --- Step 2: Directory and File Discovery ---
+    subdirectories = list_subdirectories(root_dir=root_path, verbose=False)
+
+    shap_filename = SHAPKeys.SAVENAME + ".csv"
+
+    valid_csv_paths = []
+    for dir_name, dir_path in subdirectories.items():
+        expected_path = dir_path / shap_filename
+        if expected_path.is_file():
+            valid_csv_paths.append(expected_path)
+        else:
+            _LOGGER.warning(f"No '{shap_filename}' found in subdirectory '{dir_name}'.")
+
+    if not valid_csv_paths:
+        _LOGGER.error(f"Process halted: No '{shap_filename}' files were found in any subdirectory.")
+        return []
+
+    if verbose:
+        _LOGGER.info(f"Found {len(valid_csv_paths)} SHAP summary files to process.")
+
+    # --- Step 3: Data Processing and Feature Extraction ---
+    master_feature_set = set()
+    for csv_path in valid_csv_paths:
+        try:
+            df, _ = load_dataframe(csv_path, kind="pandas", verbose=False)
+
+            # Validate required columns
+            required_cols = {SHAPKeys.FEATURE_COLUMN, SHAPKeys.SHAP_VALUE_COLUMN}
+            if not required_cols.issubset(df.columns):
+                _LOGGER.warning(f"Skipping '{csv_path}': missing required columns.")
+                continue
+
+            # Filter by threshold and extract features
+            filtered_df = df[df[SHAPKeys.SHAP_VALUE_COLUMN] >= shap_threshold]
+            features = filtered_df[SHAPKeys.FEATURE_COLUMN].tolist()
+            master_feature_set.update(features)
+
+        except (ValueError, pd.errors.EmptyDataError):
+            _LOGGER.warning(f"Skipping '{csv_path}' because it is empty or malformed.")
+            continue
+        except Exception as e:
+            _LOGGER.error(f"An unexpected error occurred while processing '{csv_path}': {e}")
+            continue
+
+    # --- Step 4: Finalize and Return ---
+    final_features = sorted(list(master_feature_set))
+    if verbose:
+        _LOGGER.info(f"Selected {len(final_features)} unique features across all files.")
+
+    return final_features
+
+
 def info():
     _script_info(__all__)
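
Taken together, the two new utilities form a small feature-selection loop: aggregate per-model SHAP summaries, keep the features that clear the threshold, then reload only those columns. An end-to-end sketch with illustrative paths and an assumed "target" label column:

    from ml_tools.utilities import select_features_by_shap, load_dataframe

    # Each subdirectory of models/ is expected to hold a shap_summary.csv.
    selected = select_features_by_shap("models/", shap_threshold=1.0)

    df, name = load_dataframe(
        "data/train.csv",
        use_columns=selected + ["target"],  # keep the label column too
        kind="pandas",
    )
    print(f"{name}: reduced to {df.shape[1]} columns")
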