neural-feature-importance 0.5.0__tar.gz → 0.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. neural_feature_importance-0.9.1/LICENSE +21 -0
  2. neural_feature_importance-0.9.1/PKG-INFO +158 -0
  3. neural_feature_importance-0.9.1/README.md +143 -0
  4. neural_feature_importance-0.9.1/conv_visualization_example.py +60 -0
  5. neural_feature_importance-0.9.1/conv_viz_utils.py +123 -0
  6. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance/__init__.py +9 -0
  7. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance/callbacks.py +26 -5
  8. neural_feature_importance-0.9.1/neural_feature_importance/conv_callbacks.py +104 -0
  9. neural_feature_importance-0.9.1/neural_feature_importance/embedding_callbacks.py +85 -0
  10. neural_feature_importance-0.9.1/neural_feature_importance.egg-info/PKG-INFO +158 -0
  11. neural_feature_importance-0.9.1/neural_feature_importance.egg-info/SOURCES.txt +27 -0
  12. neural_feature_importance-0.9.1/notebooks/conv_visualization_example.ipynb +1 -0
  13. neural_feature_importance-0.9.1/notebooks/token_importance_topk_example.ipynb +9 -0
  14. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/pyproject.toml +5 -0
  15. neural_feature_importance-0.9.1/scripts/conv_visualization_example.py +114 -0
  16. neural_feature_importance-0.9.1/scripts/token_importance_topk_example.py +50 -0
  17. neural_feature_importance-0.9.1/text_classification_example.py +66 -0
  18. neural_feature_importance-0.9.1/token_topk_utils.py +124 -0
  19. neural_feature_importance-0.5.0/PKG-INFO +0 -10
  20. neural_feature_importance-0.5.0/README.md +0 -95
  21. neural_feature_importance-0.5.0/neural_feature_importance.egg-info/PKG-INFO +0 -10
  22. neural_feature_importance-0.5.0/neural_feature_importance.egg-info/SOURCES.txt +0 -16
  23. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/.github/workflows/python-publish.yml +0 -0
  24. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/AGENTS.md +0 -0
  25. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance/utils/__init__.py +0 -0
  26. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance/utils/monitors.py +0 -0
  27. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/dependency_links.txt +0 -0
  28. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/requires.txt +0 -0
  29. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/top_level.txt +0 -0
  30. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1/notebooks}/variance-based feature importance in artificial neural networks.ipynb +0 -0
  31. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1/scripts}/compare_feature_importance.py +0 -0
  32. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1/scripts}/full_experiment.py +0 -0
  33. {neural_feature_importance-0.5.0 → neural_feature_importance-0.9.1}/setup.cfg +0 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 CR de Sá
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,158 @@
1
+ Metadata-Version: 2.4
2
+ Name: neural-feature-importance
3
+ Version: 0.9.1
4
+ Summary: Variance-based feature importance for Neural Networks using callbacks for Keras and PyTorch
5
+ Author: CR de Sá
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: numpy
10
+ Provides-Extra: tensorflow
11
+ Requires-Dist: tensorflow; extra == "tensorflow"
12
+ Provides-Extra: torch
13
+ Requires-Dist: torch; extra == "torch"
14
+ Dynamic: license-file
15
+
16
+ # neural-feature-importance
17
+
18
+ [![PyPI version](https://img.shields.io/pypi/v/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
19
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/)
20
+ [![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)
21
+
22
+ Variance-based feature importance for deep learning models.
23
+
24
+ `neural-feature-importance` implements the method described in
25
+ [CR de Sá, *Variance-based Feature Importance in Neural Networks*](https://doi.org/10.1007/978-3-030-33778-0_24).
26
+ It tracks the variance of the first trainable layer's weights using Welford's algorithm
27
+ and produces normalized importance scores for each feature.
28
+
29
+ ## Features
30
+
31
+ - `VarianceImportanceKeras` — drop-in callback for TensorFlow/Keras models
32
+ - `VarianceImportanceTorch` — helper class for PyTorch training loops
33
+ - `MetricThreshold` — early-stopping callback based on a monitored metric
34
+ - Example scripts to reproduce the experiments from the paper
35
+
36
+ ## Installation
37
+
38
+ ```bash
39
+ pip install "neural-feature-importance[tensorflow]" # for Keras
40
+ pip install "neural-feature-importance[torch]" # for PyTorch
41
+ ```
42
+
43
+ Retrieve the package version via:
44
+
45
+ ```python
46
+ from neural_feature_importance import __version__
47
+ print(__version__)
48
+ ```
49
+
50
+ ## Quick start
51
+
52
+ ### Keras
53
+
54
+ ```python
55
+ from neural_feature_importance import VarianceImportanceKeras
56
+ from neural_feature_importance.utils import MetricThreshold
57
+
58
+ viann = VarianceImportanceKeras()
59
+ monitor = MetricThreshold(monitor="val_accuracy", threshold=0.95)
60
+ model.fit(X, y, validation_split=0.05, epochs=30, callbacks=[viann, monitor])
61
+ print(viann.feature_importances_)
62
+ ```
63
+
64
+ ### PyTorch
65
+
66
+ ```python
67
+ from neural_feature_importance import VarianceImportanceTorch
68
+
69
+ tracker = VarianceImportanceTorch(model)
70
+ tracker.on_train_begin()
71
+ for epoch in range(num_epochs):
72
+ train_one_epoch(model, optimizer, dataloader)
73
+ tracker.on_epoch_end()
74
+ tracker.on_train_end()
75
+ print(tracker.feature_importances_)
76
+ ```
77
+
78
+ ## Example scripts
79
+
80
+ Run `scripts/compare_feature_importance.py` to train a small network on the Iris dataset
81
+ and compare the scores with a random forest baseline:
82
+
83
+ ```bash
84
+ python scripts/compare_feature_importance.py
85
+ ```
86
+
87
+ Run `scripts/full_experiment.py` to reproduce the experiments from the paper:
88
+
89
+ ```bash
90
+ python scripts/full_experiment.py
91
+ ```
92
+
93
+ ### Convolutional models
94
+
95
+ To compute importances for convolutional networks, use
96
+ `ConvVarianceImportanceKeras` from `neural_feature_importance.conv_callbacks`.
97
+ `scripts/conv_visualization_example.py` trains small Conv2D models on the MNIST
98
+ and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent
99
+ notebook is available in ``notebooks/conv_visualization_example.ipynb``:
100
+
101
+ ```bash
102
+ python scripts/conv_visualization_example.py
103
+ ```
104
+
105
+ ### Embedding layers
106
+
107
+ To compute token importances from embedding weights, use
108
+ `EmbeddingVarianceImportanceKeras` or `EmbeddingVarianceImportanceTorch` from
109
+ `neural_feature_importance.embedding_callbacks`.
110
+ Run `scripts/token_importance_topk_example.py` to train a small text classifier
111
+ on IMDB and display the most important tokens. A matching notebook lives in
112
+ ``notebooks/token_importance_topk_example.ipynb``:
113
+
114
+ ```bash
115
+ python scripts/token_importance_topk_example.py
116
+ ```
117
+
118
+ ## Development
119
+
120
+ After making changes, run the following checks:
121
+
122
+ ```bash
123
+ python -m py_compile neural_feature_importance/callbacks.py
124
+ python -m py_compile "notebooks/variance-based feature importance in artificial neural networks.ipynb" 2>&1 | head
125
+ jupyter nbconvert --to script "notebooks/variance-based feature importance in artificial neural networks.ipynb" --stdout | head
126
+ ```
127
+
128
+ ## Citation
129
+
130
+ If you use this package in your research, please cite:
131
+
132
+ ```bibtex
133
+ @inproceedings{DBLP:conf/dis/Sa19,
134
+ author = {Cl{\'a}udio Rebelo de S{\'a}},
135
+ editor = {Petra Kralj Novak and
136
+ Tomislav Smuc and
137
+ Saso Dzeroski},
138
+ title = {Variance-Based Feature Importance in Neural Networks},
139
+ booktitle = {Discovery Science - 22nd International Conference, {DS} 2019, Split,
140
+ Croatia, October 28-30, 2019, Proceedings},
141
+ series = {Lecture Notes in Computer Science},
142
+ volume = {11828},
143
+ pages = {306--315},
144
+ publisher = {Springer},
145
+ year = {2019},
146
+ url = {https://doi.org/10.1007/978-3-030-33778-0\_24},
147
+ doi = {10.1007/978-3-030-33778-0\_24},
148
+ timestamp = {Thu, 07 Nov 2019 09:20:36 +0100},
149
+ biburl = {https://dblp.org/rec/conf/dis/Sa19.bib},
150
+ bibsource = {dblp computer science bibliography, https://dblp.org}
151
+ }
152
+ ```
153
+
154
+ We appreciate citations as they help the community discover this work.
155
+
156
+ ## License
157
+
158
+ This project is licensed under the [MIT License](LICENSE).
@@ -0,0 +1,143 @@
1
+ # neural-feature-importance
2
+
3
+ [![PyPI version](https://img.shields.io/pypi/v/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
4
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/)
5
+ [![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)
6
+
7
+ Variance-based feature importance for deep learning models.
8
+
9
+ `neural-feature-importance` implements the method described in
10
+ [CR de Sá, *Variance-based Feature Importance in Neural Networks*](https://doi.org/10.1007/978-3-030-33778-0_24).
11
+ It tracks the variance of the first trainable layer's weights using Welford's algorithm
12
+ and produces normalized importance scores for each feature.
13
+
14
+ ## Features
15
+
16
+ - `VarianceImportanceKeras` — drop-in callback for TensorFlow/Keras models
17
+ - `VarianceImportanceTorch` — helper class for PyTorch training loops
18
+ - `MetricThreshold` — early-stopping callback based on a monitored metric
19
+ - Example scripts to reproduce the experiments from the paper
20
+
21
+ ## Installation
22
+
23
+ ```bash
24
+ pip install "neural-feature-importance[tensorflow]" # for Keras
25
+ pip install "neural-feature-importance[torch]" # for PyTorch
26
+ ```
27
+
28
+ Retrieve the package version via:
29
+
30
+ ```python
31
+ from neural_feature_importance import __version__
32
+ print(__version__)
33
+ ```
34
+
35
+ ## Quick start
36
+
37
+ ### Keras
38
+
39
+ ```python
40
+ from neural_feature_importance import VarianceImportanceKeras
41
+ from neural_feature_importance.utils import MetricThreshold
42
+
43
+ viann = VarianceImportanceKeras()
44
+ monitor = MetricThreshold(monitor="val_accuracy", threshold=0.95)
45
+ model.fit(X, y, validation_split=0.05, epochs=30, callbacks=[viann, monitor])
46
+ print(viann.feature_importances_)
47
+ ```
48
+
49
+ ### PyTorch
50
+
51
+ ```python
52
+ from neural_feature_importance import VarianceImportanceTorch
53
+
54
+ tracker = VarianceImportanceTorch(model)
55
+ tracker.on_train_begin()
56
+ for epoch in range(num_epochs):
57
+ train_one_epoch(model, optimizer, dataloader)
58
+ tracker.on_epoch_end()
59
+ tracker.on_train_end()
60
+ print(tracker.feature_importances_)
61
+ ```
62
+
63
+ ## Example scripts
64
+
65
+ Run `scripts/compare_feature_importance.py` to train a small network on the Iris dataset
66
+ and compare the scores with a random forest baseline:
67
+
68
+ ```bash
69
+ python scripts/compare_feature_importance.py
70
+ ```
71
+
72
+ Run `scripts/full_experiment.py` to reproduce the experiments from the paper:
73
+
74
+ ```bash
75
+ python scripts/full_experiment.py
76
+ ```
77
+
78
+ ### Convolutional models
79
+
80
+ To compute importances for convolutional networks, use
81
+ `ConvVarianceImportanceKeras` from `neural_feature_importance.conv_callbacks`.
82
+ `scripts/conv_visualization_example.py` trains small Conv2D models on the MNIST
83
+ and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent
84
+ notebook is available in ``notebooks/conv_visualization_example.ipynb``:
85
+
86
+ ```bash
87
+ python scripts/conv_visualization_example.py
88
+ ```
89
+
90
+ ### Embedding layers
91
+
92
+ To compute token importances from embedding weights, use
93
+ `EmbeddingVarianceImportanceKeras` or `EmbeddingVarianceImportanceTorch` from
94
+ `neural_feature_importance.embedding_callbacks`.
95
+ Run `scripts/token_importance_topk_example.py` to train a small text classifier
96
+ on IMDB and display the most important tokens. A matching notebook lives in
97
+ ``notebooks/token_importance_topk_example.ipynb``:
98
+
99
+ ```bash
100
+ python scripts/token_importance_topk_example.py
101
+ ```
102
+
103
+ ## Development
104
+
105
+ After making changes, run the following checks:
106
+
107
+ ```bash
108
+ python -m py_compile neural_feature_importance/callbacks.py
109
+ python -m py_compile "notebooks/variance-based feature importance in artificial neural networks.ipynb" 2>&1 | head
110
+ jupyter nbconvert --to script "notebooks/variance-based feature importance in artificial neural networks.ipynb" --stdout | head
111
+ ```
112
+
113
+ ## Citation
114
+
115
+ If you use this package in your research, please cite:
116
+
117
+ ```bibtex
118
+ @inproceedings{DBLP:conf/dis/Sa19,
119
+ author = {Cl{\'a}udio Rebelo de S{\'a}},
120
+ editor = {Petra Kralj Novak and
121
+ Tomislav Smuc and
122
+ Saso Dzeroski},
123
+ title = {Variance-Based Feature Importance in Neural Networks},
124
+ booktitle = {Discovery Science - 22nd International Conference, {DS} 2019, Split,
125
+ Croatia, October 28-30, 2019, Proceedings},
126
+ series = {Lecture Notes in Computer Science},
127
+ volume = {11828},
128
+ pages = {306--315},
129
+ publisher = {Springer},
130
+ year = {2019},
131
+ url = {https://doi.org/10.1007/978-3-030-33778-0\_24},
132
+ doi = {10.1007/978-3-030-33778-0\_24},
133
+ timestamp = {Thu, 07 Nov 2019 09:20:36 +0100},
134
+ biburl = {https://dblp.org/rec/conf/dis/Sa19.bib},
135
+ bibsource = {dblp computer science bibliography, https://dblp.org}
136
+ }
137
+ ```
138
+
139
+ We appreciate citations as they help the community discover this work.
140
+
141
+ ## License
142
+
143
+ This project is licensed under the [MIT License](LICENSE).
@@ -0,0 +1,60 @@
1
+ """Example of variance-based importance with a Conv2D model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ import matplotlib.pyplot as plt
8
+ from tensorflow.keras.datasets import mnist
9
+ from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
10
+ from tensorflow.keras.models import Sequential
11
+ from tensorflow.keras.utils import to_categorical
12
+
13
+ from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def build_model() -> Sequential:
20
+ """Return a minimal Conv2D model."""
21
+ model = Sequential(
22
+ [
23
+ Conv2D(8, (3, 3), activation="relu", input_shape=(28, 28, 1)),
24
+ MaxPooling2D((2, 2)),
25
+ Flatten(),
26
+ Dense(10, activation="softmax"),
27
+ ]
28
+ )
29
+ model.compile(
30
+ optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
31
+ )
32
+ return model
33
+
34
+
35
+ def main() -> None:
36
+ """Train model on MNIST and display a heatmap of importances."""
37
+ (x_train, y_train), _ = mnist.load_data()
38
+ x_train = x_train.astype("float32") / 255.0
39
+ x_train = x_train[..., None]
40
+ y_train = to_categorical(y_train, 10)
41
+
42
+ model = build_model()
43
+ callback = ConvVarianceImportanceKeras()
44
+ model.fit(x_train, y_train, epochs=2, batch_size=128, callbacks=[callback], verbose=0)
45
+
46
+ scores = callback.feature_importances_
47
+ if scores is None:
48
+ logger.warning("No importance scores computed.")
49
+ return
50
+
51
+ weights = model.layers[0].get_weights()[0]
52
+ heatmap = scores.reshape(weights.shape[0], weights.shape[1], weights.shape[2]).mean(axis=-1)
53
+ plt.imshow(heatmap, cmap="hot")
54
+ plt.colorbar()
55
+ plt.title("Feature importance heatmap")
56
+ plt.show()
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
@@ -0,0 +1,123 @@
1
+ """Helper utilities for convolutional importance visualizations.
2
+
3
+ These functions build small convolutional models, compute filter scores using
4
+ variance-based importances, and plot the resulting weight maps and activations.
5
+ They are intended for exploratory analysis and example scripts.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from typing import Iterable
12
+
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.colors as mcolors
15
+ import numpy as np
16
+ from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
17
+ from tensorflow.keras.models import Sequential
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ CMAP = mcolors.LinearSegmentedColormap.from_list("white_blue_black", ["white", "blue", "black"])
22
+ WEIGHT_CMAP = plt.cm.seismic
23
+
24
+
25
+ def build_model(input_shape: tuple[int, int, int], kernel_size: tuple[int, int]) -> Sequential:
26
+ """Return a simple Conv2D model for visualization experiments.
27
+
28
+ Parameters
29
+ ----------
30
+ input_shape:
31
+ Shape of the input images, e.g. ``(28, 28, 1)``.
32
+ kernel_size:
33
+ Size of the convolution kernel.
34
+ """
35
+ model = Sequential(
36
+ [
37
+ Conv2D(8, kernel_size, activation="relu", input_shape=input_shape),
38
+ MaxPooling2D((2, 2)),
39
+ Flatten(),
40
+ Dense(10, activation="softmax"),
41
+ ]
42
+ )
43
+ model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
44
+ return model
45
+
46
+
47
+ def _threshold_filters(weights: np.ndarray, threshold: float) -> np.ndarray:
48
+ """Zero out filter weights whose magnitude is below the threshold; keep the rest unchanged."""
49
+ mask = np.abs(weights) >= threshold
50
+ return np.where(mask, weights, 0.0)
51
+
52
+
53
+ def compute_filter_scores(
54
+ weights: np.ndarray, heatmap: np.ndarray, threshold: float
55
+ ) -> tuple[np.ndarray, np.ndarray]:
56
+ """Compute filter importance scores.
57
+
58
+ Filters are thresholded and multiplied by the variance heatmap before
59
+ summing over all spatial dimensions.
60
+ """
61
+ thr_weights = _threshold_filters(weights, threshold)
62
+ scores = np.sum(np.abs(thr_weights) * heatmap[..., None], axis=(0, 1, 2))
63
+ return scores.astype(float), thr_weights
64
+
65
+
66
+ def rank_filters(weights: np.ndarray, heatmap: np.ndarray, threshold: float) -> np.ndarray:
67
+ """Return filter indices sorted by descending importance."""
68
+ scores, _ = compute_filter_scores(weights, heatmap, threshold)
69
+ order = np.argsort(scores)[::-1]
70
+ logger.info("Filter scores: %s", scores.tolist())
71
+ return order
72
+
73
+
74
+ def accuracy_with_filters(
75
+ model: Sequential,
76
+ x: np.ndarray,
77
+ y: np.ndarray,
78
+ indices: Iterable[int],
79
+ ) -> float:
80
+ """Return accuracy when only selected filters are active."""
81
+ conv = model.layers[0]
82
+ original = conv.get_weights()
83
+ weights = original[0].copy()
84
+ bias = original[1].copy()
85
+ mask = np.zeros(weights.shape[-1], dtype=bool)
86
+ mask[list(indices)] = True
87
+ weights[..., ~mask] = 0.0
88
+ bias[~mask] = 0.0
89
+ conv.set_weights([weights, bias])
90
+ preds = np.argmax(model.predict(x, verbose=0), axis=1)
91
+ acc = float(np.mean(preds == np.argmax(y, axis=1)))
92
+ conv.set_weights(original)
93
+ return acc
94
+
95
+
96
+ def plot_filters(
97
+ weights: np.ndarray,
98
+ heatmap: np.ndarray,
99
+ example_out: np.ndarray,
100
+ order: Iterable[int],
101
+ ) -> None:
102
+ """Display filter weights, importances, and outputs."""
103
+ n_filters = weights.shape[-1]
104
+ vmax = float(np.max(np.abs(weights)))
105
+ fig, axes = plt.subplots(n_filters, 3, figsize=(9, 3 * n_filters))
106
+ for row, idx in enumerate(order):
107
+ ax_w = axes[row, 0]
108
+ ax_i = axes[row, 1]
109
+ ax_o = axes[row, 2]
110
+ im_w = ax_w.imshow(weights[:, :, 0, idx], cmap=WEIGHT_CMAP, vmin=-vmax, vmax=vmax)
111
+ ax_w.set_title(f"Filter {idx} weights")
112
+ ax_w.axis("off")
113
+ fig.colorbar(im_w, ax=ax_w)
114
+ im_i = ax_i.imshow(heatmap[:, :, 0], cmap=CMAP, vmin=0.0, vmax=1.0)
115
+ ax_i.set_title("Importance")
116
+ ax_i.axis("off")
117
+ fig.colorbar(im_i, ax=ax_i)
118
+ im_o = ax_o.imshow(example_out[:, :, idx], cmap="gray", vmin=0.0, vmax=np.max(example_out))
119
+ ax_o.set_title("Filter output")
120
+ ax_o.axis("off")
121
+ fig.colorbar(im_o, ax=ax_o)
122
+ plt.tight_layout()
123
+ plt.show()
@@ -7,6 +7,11 @@ from .callbacks import (
7
7
  VarianceImportanceKeras,
8
8
  VarianceImportanceTorch,
9
9
  )
10
+ from .conv_callbacks import ConvVarianceImportanceKeras, ConvVarianceImportanceTorch
11
+ from .embedding_callbacks import (
12
+ EmbeddingVarianceImportanceKeras,
13
+ EmbeddingVarianceImportanceTorch,
14
+ )
10
15
  from .utils import MetricThreshold
11
16
 
12
17
  try:
@@ -19,4 +24,8 @@ __all__ = [
19
24
  "VarianceImportanceKeras",
20
25
  "VarianceImportanceTorch",
21
26
  "MetricThreshold",
27
+ "ConvVarianceImportanceKeras",
28
+ "ConvVarianceImportanceTorch",
29
+ "EmbeddingVarianceImportanceKeras",
30
+ "EmbeddingVarianceImportanceTorch",
22
31
  ]
@@ -1,4 +1,10 @@
1
- """Variance-based feature importance utilities."""
1
+ """Utilities for computing variance-based feature importances.
2
+
3
+ These classes track the weights of the first trainable layer during training
4
+ and estimate feature importances by accumulating the variance of each weight
5
+ value. After training, the variances are combined with the last observed
6
+ weights to produce a normalized importance score for every input feature.
7
+ """
2
8
 
3
9
  from __future__ import annotations
4
10
 
@@ -13,7 +19,13 @@ logger = logging.getLogger(__name__)
13
19
 
14
20
 
15
21
  class VarianceImportanceBase:
16
- """Compute feature importance using Welford's algorithm."""
22
+ """Compute feature importances using running variance statistics.
23
+
24
+ The class implements Welford's algorithm to accumulate the variance of
25
+ weight values over training iterations. Feature importances are derived by
26
+ combining the final variance estimates with the absolute value of the last
27
+ observed weights.
28
+ """
17
29
 
18
30
  def __init__(self) -> None:
19
31
  self._n = 0
@@ -23,13 +35,21 @@ class VarianceImportanceBase:
23
35
  self.var_scores: np.ndarray | None = None
24
36
 
25
37
  def start(self, weights: np.ndarray) -> None:
26
- """Initialize statistics for the given weight matrix."""
38
+ """Initialize running statistics.
39
+
40
+ Parameters
41
+ ----------
42
+ weights:
43
+ Initial weight matrix of shape ``(features, outputs)``. The values
44
+ are converted to ``float64`` for numerical stability and the running
45
+ mean and variance buffers are reset.
46
+ """
27
47
  self._mean = weights.astype(np.float64)
28
48
  self._m2 = np.zeros_like(self._mean)
29
49
  self._n = 0
30
50
 
31
51
  def update(self, weights: np.ndarray) -> None:
32
- """Update running statistics with new weights."""
52
+ """Update running mean and variance using new weights."""
33
53
  if self._mean is None or self._m2 is None:
34
54
  return
35
55
  self._n += 1
@@ -40,7 +60,7 @@ class VarianceImportanceBase:
40
60
  self._last_weights = weights
41
61
 
42
62
  def finalize(self) -> None:
43
- """Finalize statistics and compute normalized scores."""
63
+ """Compute normalized importance scores from accumulated statistics."""
44
64
  if self._last_weights is None or self._m2 is None:
45
65
  logger.warning(
46
66
  "%s was not fully initialized; no scores computed", self.__class__.__name__
@@ -67,6 +87,7 @@ class VarianceImportanceBase:
67
87
  return self.var_scores
68
88
 
69
89
 
90
+
70
91
  class VarianceImportanceKeras(Callback, VarianceImportanceBase):
71
92
  """Keras callback implementing variance-based feature importance."""
72
93