dragon-ml-toolbox 10.9.0__tar.gz → 10.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox has been flagged as potentially problematic.
- {dragon_ml_toolbox-10.9.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-10.10.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_datasetmaster.py +1 -1
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_evaluation.py +6 -6
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_models.py +28 -40
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_scaler.py +1 -1
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_trainer.py +14 -6
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/LICENSE +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/README.md +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_callbacks.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_evaluation_multi.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_inference.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/custom_logger.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/keys.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-10.9.0 → dragon_ml_toolbox-10.10.0}/setup.cfg +0 -0
ml_tools/ML_datasetmaster.py

@@ -200,7 +200,7 @@ class _BaseDatasetMaker(ABC):
         filepath = save_path / filename
         self.scaler.save(filepath, verbose=False)
         if verbose:
-            _LOGGER.info(f"Scaler for dataset '{self.id}' saved
+            _LOGGER.info(f"Scaler for dataset '{self.id}' saved as '{filepath.name}'.")


 # Single target dataset
ml_tools/ML_evaluation.py

@@ -353,7 +353,7 @@ def shap_summary_plot(model,
     plt.ion()


-def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
     """
     Aggregates attention weights and plots global feature importance.

@@ -364,6 +364,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
         feature_names (List[str] | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save the plot and summary CSV.
+        top_n (int): The number of top features to display in the plot.
     """
     if not weights:
         _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
@@ -392,11 +393,10 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     summary_df.to_csv(summary_path, index=False)
     _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

-    # --- Step 3: Create and save the plot ---
-
+    # --- Step 3: Create and save the plot for top N features ---
+    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

-
-    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+    plt.figure(figsize=(10, 8), dpi=100)

     # Create horizontal bar plot with error bars
     plt.barh(
@@ -410,7 +410,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
         color='cornflowerblue'
     )

-    plt.title('
+    plt.title('Top Features by Attention')
     plt.xlabel('Average Attention Weight')
     plt.ylabel('Feature')
     plt.grid(axis='x', linestyle='--', alpha=0.6)
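In practice, the new top_n argument caps the horizontal bar chart at the N highest-ranked features while the full ranking is still written to the summary CSV. A minimal usage sketch, assuming the function is importable from ml_tools.ML_evaluation and that each collected tensor has shape (batch_size, n_features); the file names and shapes here are illustrative, not taken from the package:

    import torch
    from ml_tools.ML_evaluation import plot_attention_importance  # assumed import path

    # Hypothetical attention weights collected over 4 batches of 32 samples x 15 features
    collected_weights = [torch.rand(32, 15) for _ in range(4)]

    plot_attention_importance(
        weights=collected_weights,
        feature_names=[f"feature_{i}" for i in range(15)],
        save_dir="reports/attention",
        top_n=10,  # new in 10.10.0: plot only the 10 highest mean-attention features
    )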
ml_tools/ML_models.py

@@ -43,7 +43,7 @@ class _ArchitectureHandlerMixin:
             json.dump(config, f, indent=4)

         if verbose:
-            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved
+            _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")

     @classmethod
     def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
@@ -147,6 +147,30 @@ class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
         return f"{name}(arch: {arch_str})"


+class _BaseAttention(_BaseMLP):
+    """
+    Abstract base class for MLP models that incorporate an attention mechanism
+    before the main MLP layers.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # By default, models inheriting this do not have the flag.
+        self.has_interpretable_attention = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Defines the standard forward pass."""
+        logits, _attention_weights = self.forward_attention(x)
+        return logits
+
+    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Returns logits and attention weights."""
+        # This logic is now shared and defined in one place
+        x, attention_weights = self.attention(x)
+        x = self.mlp(x)
+        logits = self.output_layer(x)
+        return logits, attention_weights
+
+
 class MultilayerPerceptron(_BaseMLP):
     """
     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
@@ -184,7 +208,7 @@ class MultilayerPerceptron(_BaseMLP):
         return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)


-class AttentionMLP(_BaseMLP):
+class AttentionMLP(_BaseAttention):
     """
     A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.

@@ -205,25 +229,7 @@ class AttentionMLP(_BaseMLP):
         super().__init__(in_features, out_targets, hidden_layers, drop_out)
         # Attention
         self.attention = _AttentionLayer(in_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Defines the standard forward pass.
-        """
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP
-        logits = self.mlp(x)
-
-        return logits, attention_weights
+        self.has_interpretable_attention = True

     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
@@ -238,7 +244,7 @@ class AttentionMLP(_BaseMLP):
         return self._repr_helper(name="AttentionMLP", mlp_layers=arch)


-class MultiHeadAttentionMLP(_BaseMLP):
+class MultiHeadAttentionMLP(_BaseAttention):
     """
     An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
     the input features.
@@ -267,24 +273,6 @@ class MultiHeadAttentionMLP(_BaseMLP):
             dropout=attention_dropout
         )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Defines the standard forward pass of the model."""
-        logits, _attention_weights = self.forward_attention(x)
-        return logits
-
-    def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns logits and attention weights.
-        """
-        # The attention layer returns the processed x and the weights
-        x, attention_weights = self.attention(x)
-
-        # Pass the attention-modified tensor through the MLP and prediction head
-        x = self.mlp(x)
-        logits = self.output_layer(x)
-
-        return logits, attention_weights
-
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
         config = super().get_architecture_config()
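The net effect of the ML_models.py changes is that the previously duplicated forward/forward_attention logic now lives once in _BaseAttention, and only AttentionMLP opts into the has_interpretable_attention flag. A minimal sketch of the resulting behavior, assuming the import path below and that hidden_layers takes a list and drop_out a float (inferred from the constructor call visible in this diff):

    import torch
    from ml_tools.ML_models import AttentionMLP  # assumed import path

    # Hypothetical model: 15 input features, 1 regression target
    model = AttentionMLP(in_features=15, out_targets=1, hidden_layers=[64, 32], drop_out=0.2)
    x = torch.rand(8, 15)

    # Inherited from _BaseAttention: forward() returns logits only,
    # forward_attention() also returns the per-feature attention weights.
    logits = model(x)
    logits, attention_weights = model.forward_attention(x)

    # AttentionMLP sets this flag to True; MultiHeadAttentionMLP keeps the default False,
    # so MLTrainer.explain_attention() will skip it.
    print(model.has_interpretable_attention)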
ml_tools/ML_scaler.py

@@ -164,7 +164,7 @@ class PytorchScaler:
         }
         torch.save(state, path_obj)
         if verbose:
-            _LOGGER.info(f"PytorchScaler state saved
+            _LOGGER.info(f"PytorchScaler state saved as '{path_obj.name}'.")

     @staticmethod
     def load(filepath: Union[str, Path], verbose: bool=True) -> 'PytorchScaler':
ml_tools/ML_trainer.py

@@ -472,23 +472,30 @@ class MLTrainer:

             yield attention_weights

-    def explain_attention(self, save_dir: Union[str, Path],
+    def explain_attention(self, save_dir: Union[str, Path],
+                          feature_names: Optional[List[str]],
+                          explain_dataset: Optional[Dataset] = None,
+                          plot_n_features: int = 10):
         """
         Generates and saves a feature importance plot based on attention weights.

-        This method only works for models with
+        This method only works for models with 'has_interpretable_attention'.

         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling.
+            feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
+            plot_n_features (int): Number of top features to plot.
         """

         print("\n--- Attention Analysis ---")

         # --- Step 1: Check if the model supports this explanation ---
-        if not
-        _LOGGER.
+        if not getattr(self.model, 'has_interpretable_attention', False):
+            _LOGGER.warning(
+                "Model is not flagged for interpretable attention analysis. "
+                "Skipping. This is the correct behavior for models like MultiHeadAttentionMLP."
+            )
             return

         # --- Step 2: Set up the dataloader ---
@@ -514,7 +521,8 @@ class MLTrainer:
             plot_attention_importance(
                 weights=all_weights,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                top_n=plot_n_features
             )
         else:
             _LOGGER.error("No attention weights were collected from the model.")
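Putting the trainer-side change together, a hedged usage sketch of the updated call; this assumes `trainer` is an already-configured MLTrainer instance (its construction is not part of this diff) and the feature names are hypothetical:

    feature_names = ["age", "dose", "temperature"]  # hypothetical column names

    trainer.explain_attention(
        save_dir="reports/attention",
        feature_names=feature_names,   # None falls back to generic names
        explain_dataset=None,          # None -> the trainer's test dataset is used
        plot_n_features=10,            # forwarded to plot_attention_importance(top_n=...)
    )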
All remaining files listed above are renamed from dragon_ml_toolbox-10.9.0 to dragon_ml_toolbox-10.10.0 without content changes.