dragon-ml-toolbox 10.9.0__tar.gz → 10.10.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic.

Files changed (41)
  1. {dragon_ml_toolbox-10.9.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-10.10.0}/PKG-INFO +1 -1
  2. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
  3. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_datasetmaster.py +1 -1
  4. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_evaluation.py +6 -6
  5. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_models.py +28 -40
  6. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_scaler.py +1 -1
  7. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_trainer.py +14 -6
  8. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/pyproject.toml +1 -1
  9. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/LICENSE +0 -0
  10. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/LICENSE-THIRD-PARTY.md +0 -0
  11. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/README.md +0 -0
  12. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
  13. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
  14. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
  15. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
  16. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ETL_cleaning.py +0 -0
  17. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ETL_engineering.py +0 -0
  18. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/GUI_tools.py +0 -0
  19. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/MICE_imputation.py +0 -0
  20. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_callbacks.py +0 -0
  21. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_evaluation_multi.py +0 -0
  22. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_inference.py +0 -0
  23. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_optimization.py +0 -0
  24. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/PSO_optimization.py +0 -0
  25. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/RNN_forecast.py +0 -0
  26. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/SQL.py +0 -0
  27. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/VIF_factor.py +0 -0
  28. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/__init__.py +0 -0
  29. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/_logger.py +0 -0
  30. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/_script_info.py +0 -0
  31. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/custom_logger.py +0 -0
  32. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/data_exploration.py +0 -0
  33. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_evaluation.py +0 -0
  34. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_inference.py +0 -0
  35. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_learning.py +0 -0
  36. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/handle_excel.py +0 -0
  37. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/keys.py +0 -0
  38. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/optimization_tools.py +0 -0
  39. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/path_manager.py +0 -0
  40. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/utilities.py +0 -0
  41. {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/setup.cfg +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 10.9.0
+Version: 10.10.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT

dragon_ml_toolbox.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 10.9.0
+Version: 10.10.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT

ml_tools/ML_datasetmaster.py

@@ -200,7 +200,7 @@ class _BaseDatasetMaker(ABC):
         filepath = save_path / filename
         self.scaler.save(filepath, verbose=False)
         if verbose:
-            _LOGGER.info(f"Scaler for dataset '{self.id}' saved to '{filepath.name}'.")
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved as '{filepath.name}'.")
 
 
 # Single target dataset

ml_tools/ML_evaluation.py

@@ -353,7 +353,7 @@ def shap_summary_plot(model,
     plt.ion()
 
 
-def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
     """
     Aggregates attention weights and plots global feature importance.
 
@@ -364,6 +364,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
         feature_names (List[str] | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save the plot and summary CSV.
+        top_n (int): The number of top features to display in the plot.
     """
     if not weights:
         _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
@@ -392,11 +393,10 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     summary_df.to_csv(summary_path, index=False)
     _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
 
-    # --- Step 3: Create and save the plot ---
-    plt.figure(figsize=(10, 8), dpi=100)
+    # --- Step 3: Create and save the plot for top N features ---
+    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    # Sort for plotting
-    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+    plt.figure(figsize=(10, 8), dpi=100)
 
     # Create horizontal bar plot with error bars
     plt.barh(
@@ -410,7 +410,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         color='cornflowerblue'
     )
 
-    plt.title('Global Feature Importance')
+    plt.title('Top Features by Attention')
     plt.xlabel('Average Attention Weight')
     plt.ylabel('Feature')
     plt.grid(axis='x', linestyle='--', alpha=0.6)

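
For reference, the updated plot_attention_importance signature can be exercised on its own. The sketch below is a hedged usage example rather than code from the package: the random tensors stand in for per-batch attention weights (assumed shape (batch_size, n_features)), and the import path follows the ml_tools package layout shown in the file list. Note that the summary CSV still contains every feature; only the plot is limited by top_n.

import torch
from pathlib import Path
from ml_tools.ML_evaluation import plot_attention_importance

# Five fake batches of per-sample attention weights over 12 features.
fake_weights = [torch.rand(32, 12) for _ in range(5)]

plot_attention_importance(
    weights=fake_weights,
    feature_names=[f"feat_{i}" for i in range(12)],
    save_dir=Path("results/attention"),
    top_n=8,  # new in 10.10.0: the CSV keeps all features, the plot keeps only the top 8
)
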
ml_tools/ML_models.py

@@ -43,7 +43,7 @@ class _ArchitectureHandlerMixin:
             json.dump(config, f, indent=4)
 
         if verbose:
-            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved to '{path_dir.name}'")
+            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")
 
     @classmethod
     def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
@@ -147,6 +147,30 @@ class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
         return f"{name}(arch: {arch_str})"
 
 
+class _BaseAttention(_BaseMLP):
+    """
+    Abstract base class for MLP models that incorporate an attention mechanism
+    before the main MLP layers.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # By default, models inheriting this do not have the flag.
+        self.has_interpretable_attention = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Defines the standard forward pass."""
+        logits, _attention_weights = self.forward_attention(x)
+        return logits
+
+    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Returns logits and attention weights."""
+        # This logic is now shared and defined in one place
+        x, attention_weights = self.attention(x)
+        x = self.mlp(x)
+        logits = self.output_layer(x)
+        return logits, attention_weights
+
+
 class MultilayerPerceptron(_BaseMLP):
     """
     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
@@ -184,7 +208,7 @@ class MultilayerPerceptron(_BaseMLP):
         return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)
 
 
-class AttentionMLP(_BaseMLP):
+class AttentionMLP(_BaseAttention):
     """
     A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.
 
@@ -205,25 +229,7 @@ class AttentionMLP(_BaseMLP):
         super().__init__(in_features, out_targets, hidden_layers, drop_out)
         # Attention
         self.attention = _AttentionLayer(in_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Defines the standard forward pass.
-        """
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP
-        logits = self.mlp(x)
-
-        return logits, attention_weights
+        self.has_interpretable_attention = True
 
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
@@ -238,7 +244,7 @@ class AttentionMLP(_BaseMLP):
         return self._repr_helper(name="AttentionMLP", mlp_layers=arch)
 
 
-class MultiHeadAttentionMLP(_BaseMLP):
+class MultiHeadAttentionMLP(_BaseAttention):
     """
     An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
     the input features.
@@ -267,24 +273,6 @@ class MultiHeadAttentionMLP(_BaseMLP):
             dropout=attention_dropout
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Defines the standard forward pass of the model."""
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights.
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP and prediction head
-        x = self.mlp(x)
-        logits = self.output_layer(x)
-
-        return logits, attention_weights
-
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
         config = super().get_architecture_config()

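
The net effect of the ML_models.py changes is that the duplicated forward/forward_attention logic moves into a single _BaseAttention base class; AttentionMLP and MultiHeadAttentionMLP now only assign their attention module, and AttentionMLP additionally opts into has_interpretable_attention. A minimal toy sketch of that inheritance pattern follows; the class names and layer sizes are stand-ins, not the package classes.

import torch
import torch.nn as nn
from typing import Tuple

class ToyFeatureAttention(nn.Module):
    """Per-feature gate returning (weighted_input, weights), mirroring the attention-layer contract in the diff."""
    def __init__(self, in_features: int):
        super().__init__()
        self.scorer = nn.Sequential(nn.Linear(in_features, in_features), nn.Softmax(dim=-1))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        weights = self.scorer(x)
        return x * weights, weights

class ToyBaseAttention(nn.Module):
    """Shared forward logic; subclasses only provide `self.attention` and may flip the flag."""
    def __init__(self, in_features: int, out_targets: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_features, 32), nn.ReLU())
        self.output_layer = nn.Linear(32, out_targets)
        self.has_interpretable_attention = False  # off by default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits, _ = self.forward_attention(x)
        return logits

    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, attention_weights = self.attention(x)
        return self.output_layer(self.mlp(x)), attention_weights

class ToyAttentionMLP(ToyBaseAttention):
    def __init__(self, in_features: int, out_targets: int):
        super().__init__(in_features, out_targets)
        self.attention = ToyFeatureAttention(in_features)
        self.has_interpretable_attention = True  # per-feature weights are meaningful here

logits, weights = ToyAttentionMLP(12, 1).forward_attention(torch.rand(4, 12))
print(logits.shape, weights.shape)  # torch.Size([4, 1]) torch.Size([4, 12])
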
ml_tools/ML_scaler.py

@@ -164,7 +164,7 @@ class PytorchScaler:
         }
         torch.save(state, path_obj)
         if verbose:
-            _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
+            _LOGGER.info(f"PytorchScaler state saved as '{path_obj.name}'.")
 
     @staticmethod
     def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':

ml_tools/ML_trainer.py

@@ -472,23 +472,30 @@ class MLTrainer:
 
                 yield attention_weights
 
-    def explain_attention(self, save_dir: Union[str, Path], feature_names: Optional[List[str]], explain_dataset: Optional[Dataset] = None):
+    def explain_attention(self, save_dir: Union[str, Path],
+                          feature_names: Optional[List[str]],
+                          explain_dataset: Optional[Dataset] = None,
+                          plot_n_features: int = 10):
         """
         Generates and saves a feature importance plot based on attention weights.
 
-        This method only works for models with a `forward_attention` method.
+        This method only works for models with models with 'has_interpretable_attention'.
 
         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling.
+            feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
+            plot_n_features (int): Number of top features to plot.
         """
 
         print("\n--- Attention Analysis ---")
 
         # --- Step 1: Check if the model supports this explanation ---
-        if not hasattr(self.model, 'forward_attention'):
-            _LOGGER.error("Model does not have a `forward_attention` method. Skipping attention explanation.")
+        if not getattr(self.model, 'has_interpretable_attention', False):
+            _LOGGER.warning(
+                "Model is not flagged for interpretable attention analysis. "
+                "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
+            )
             return
 
         # --- Step 2: Set up the dataloader ---
@@ -514,7 +521,8 @@ class MLTrainer:
             plot_attention_importance(
                 weights=all_weights,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                top_n=plot_n_features
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dragon-ml-toolbox"
-version = "10.9.0"
+version = "10.10.0"
 description = "A collection of tools for data science and machine learning projects."
 authors = [
     { name = "Karl Loza", email = "luigiloza@gmail.com" }