neural-feature-importance 0.5.2__tar.gz → 0.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. neural_feature_importance-0.9.1/LICENSE +21 -0
  2. {neural_feature_importance-0.5.2/neural_feature_importance.egg-info → neural_feature_importance-0.9.1}/PKG-INFO +36 -4
  3. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/README.md +33 -3
  4. neural_feature_importance-0.9.1/conv_visualization_example.py +60 -0
  5. neural_feature_importance-0.9.1/conv_viz_utils.py +123 -0
  6. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance/__init__.py +9 -0
  7. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance/callbacks.py +26 -5
  8. neural_feature_importance-0.9.1/neural_feature_importance/conv_callbacks.py +104 -0
  9. neural_feature_importance-0.9.1/neural_feature_importance/embedding_callbacks.py +85 -0
  10. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1/neural_feature_importance.egg-info}/PKG-INFO +36 -4
  11. neural_feature_importance-0.9.1/neural_feature_importance.egg-info/SOURCES.txt +27 -0
  12. neural_feature_importance-0.9.1/notebooks/conv_visualization_example.ipynb +1 -0
  13. neural_feature_importance-0.9.1/notebooks/token_importance_topk_example.ipynb +9 -0
  14. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/pyproject.toml +3 -0
  15. neural_feature_importance-0.9.1/scripts/conv_visualization_example.py +114 -0
  16. neural_feature_importance-0.9.1/scripts/token_importance_topk_example.py +50 -0
  17. neural_feature_importance-0.9.1/text_classification_example.py +66 -0
  18. neural_feature_importance-0.9.1/token_topk_utils.py +124 -0
  19. neural_feature_importance-0.5.2/neural_feature_importance.egg-info/SOURCES.txt +0 -16
  20. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/.github/workflows/python-publish.yml +0 -0
  21. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/AGENTS.md +0 -0
  22. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance/utils/__init__.py +0 -0
  23. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance/utils/monitors.py +0 -0
  24. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/dependency_links.txt +0 -0
  25. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/requires.txt +0 -0
  26. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/neural_feature_importance.egg-info/top_level.txt +0 -0
  27. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1/notebooks}/variance-based feature importance in artificial neural networks.ipynb +0 -0
  28. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1/scripts}/compare_feature_importance.py +0 -0
  29. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1/scripts}/full_experiment.py +0 -0
  30. {neural_feature_importance-0.5.2 → neural_feature_importance-0.9.1}/setup.cfg +0 -0
neural_feature_importance-0.9.1/LICENSE
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 CR de Sá
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
neural_feature_importance-0.9.1/PKG-INFO
@@ -1,20 +1,23 @@
  Metadata-Version: 2.4
  Name: neural-feature-importance
- Version: 0.5.2
+ Version: 0.9.1
  Summary: Variance-based feature importance for Neural Networks using callbacks for Keras and PyTorch
  Author: CR de Sá
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
+ License-File: LICENSE
  Requires-Dist: numpy
  Provides-Extra: tensorflow
  Requires-Dist: tensorflow; extra == "tensorflow"
  Provides-Extra: torch
  Requires-Dist: torch; extra == "torch"
+ Dynamic: license-file

  # neural-feature-importance

  [![PyPI version](https://img.shields.io/pypi/v/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
- [![Python versions](https://img.shields.io/pypi/pyversions/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)

  Variance-based feature importance for deep learning models.

@@ -74,19 +77,44 @@ print(tracker.feature_importances_)

  ## Example scripts

- Run `compare_feature_importance.py` to train a small network on the Iris dataset
+ Run `scripts/compare_feature_importance.py` to train a small network on the Iris dataset
  and compare the scores with a random forest baseline:

  ```bash
  python compare_feature_importance.py
  ```

- Run `full_experiment.py` to reproduce the experiments from the paper:
+ Run `scripts/full_experiment.py` to reproduce the experiments from the paper:

  ```bash
  python full_experiment.py
  ```

+ ### Convolutional models
+
+ To compute importances for convolutional networks, use
+ `ConvVarianceImportanceKeras` from `neural_feature_importance.conv_callbacks`.
+ `scripts/conv_visualization_example.py` trains small Conv2D models on the MNIST
+ and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent
+ notebook is available in ``notebooks/conv_visualization_example.ipynb``:
+
+ ```bash
+ python scripts/conv_visualization_example.py
+ ```
+
+ ### Embedding layers
+
+ To compute token importances from embedding weights, use
+ `EmbeddingVarianceImportanceKeras` or `EmbeddingVarianceImportanceTorch` from
+ `neural_feature_importance.embedding_callbacks`.
+ Run `scripts/token_importance_topk_example.py` to train a small text classifier
+ on IMDB and display the most important tokens. A matching notebook lives in
+ ``notebooks/token_importance_topk_example.ipynb``:
+
+ ```bash
+ python scripts/token_importance_topk_example.py
+ ```
+
  ## Development

  After making changes, run the following checks:
@@ -124,3 +152,7 @@ If you use this package in your research, please cite:
  ```

  We appreciate citations as they help the community discover this work.
+
+ ## License
+
+ This project is licensed under the [MIT License](LICENSE).
neural_feature_importance-0.9.1/README.md
@@ -1,7 +1,8 @@
  # neural-feature-importance

  [![PyPI version](https://img.shields.io/pypi/v/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
- [![Python versions](https://img.shields.io/pypi/pyversions/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)

  Variance-based feature importance for deep learning models.

@@ -61,19 +62,44 @@ print(tracker.feature_importances_)

  ## Example scripts

- Run `compare_feature_importance.py` to train a small network on the Iris dataset
+ Run `scripts/compare_feature_importance.py` to train a small network on the Iris dataset
  and compare the scores with a random forest baseline:

  ```bash
  python compare_feature_importance.py
  ```

- Run `full_experiment.py` to reproduce the experiments from the paper:
+ Run `scripts/full_experiment.py` to reproduce the experiments from the paper:

  ```bash
  python full_experiment.py
  ```

+ ### Convolutional models
+
+ To compute importances for convolutional networks, use
+ `ConvVarianceImportanceKeras` from `neural_feature_importance.conv_callbacks`.
+ `scripts/conv_visualization_example.py` trains small Conv2D models on the MNIST
+ and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent
+ notebook is available in ``notebooks/conv_visualization_example.ipynb``:
+
+ ```bash
+ python scripts/conv_visualization_example.py
+ ```
+
+ ### Embedding layers
+
+ To compute token importances from embedding weights, use
+ `EmbeddingVarianceImportanceKeras` or `EmbeddingVarianceImportanceTorch` from
+ `neural_feature_importance.embedding_callbacks`.
+ Run `scripts/token_importance_topk_example.py` to train a small text classifier
+ on IMDB and display the most important tokens. A matching notebook lives in
+ ``notebooks/token_importance_topk_example.ipynb``:
+
+ ```bash
+ python scripts/token_importance_topk_example.py
+ ```
+
  ## Development

  After making changes, run the following checks:
@@ -111,3 +137,7 @@ If you use this package in your research, please cite:
  ```

  We appreciate citations as they help the community discover this work.
+
+ ## License
+
+ This project is licensed under the [MIT License](LICENSE).
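The two new README sections above describe the workflow only in prose. As a minimal sketch of the convolutional path, condensed from the `conv_visualization_example.py` added in this release (random arrays stand in for MNIST so the snippet is self-contained; shapes are illustrative):

```python
import numpy as np
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential

from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras

# Stand-in data with MNIST-like shapes; the real example calls mnist.load_data().
x_train = np.random.rand(256, 28, 28, 1).astype("float32")
y_train = np.eye(10)[np.random.randint(0, 10, 256)]

model = Sequential([
    Conv2D(8, (3, 3), activation="relu", input_shape=(28, 28, 1)),
    Flatten(),
    Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")

callback = ConvVarianceImportanceKeras()
model.fit(x_train, y_train, epochs=2, batch_size=64, callbacks=[callback], verbose=0)

# One normalized score per (row, col, in_channel) position of the first kernel,
# reshaped into a spatial heatmap as the example scripts below do.
scores = callback.feature_importances_
heatmap = scores.reshape(3, 3, 1).mean(axis=-1)
```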
neural_feature_importance-0.9.1/conv_visualization_example.py
@@ -0,0 +1,60 @@
+ """Example of variance-based importance with a Conv2D model."""
+
+ from __future__ import annotations
+
+ import logging
+
+ import matplotlib.pyplot as plt
+ from tensorflow.keras.datasets import mnist
+ from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
+ from tensorflow.keras.models import Sequential
+ from tensorflow.keras.utils import to_categorical
+
+ from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def build_model() -> Sequential:
+     """Return a minimal Conv2D model."""
+     model = Sequential(
+         [
+             Conv2D(8, (3, 3), activation="relu", input_shape=(28, 28, 1)),
+             MaxPooling2D((2, 2)),
+             Flatten(),
+             Dense(10, activation="softmax"),
+         ]
+     )
+     model.compile(
+         optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
+     )
+     return model
+
+
+ def main() -> None:
+     """Train model on MNIST and display a heatmap of importances."""
+     (x_train, y_train), _ = mnist.load_data()
+     x_train = x_train.astype("float32") / 255.0
+     x_train = x_train[..., None]
+     y_train = to_categorical(y_train, 10)
+
+     model = build_model()
+     callback = ConvVarianceImportanceKeras()
+     model.fit(x_train, y_train, epochs=2, batch_size=128, callbacks=[callback], verbose=0)
+
+     scores = callback.feature_importances_
+     if scores is None:
+         logger.warning("No importance scores computed.")
+         return
+
+     weights = model.layers[0].get_weights()[0]
+     heatmap = scores.reshape(weights.shape[0], weights.shape[1], weights.shape[2]).mean(axis=-1)
+     plt.imshow(heatmap, cmap="hot")
+     plt.colorbar()
+     plt.title("Feature importance heatmap")
+     plt.show()
+
+
+ if __name__ == "__main__":
+     main()
neural_feature_importance-0.9.1/conv_viz_utils.py
@@ -0,0 +1,123 @@
+ """Helper utilities for convolutional importance visualizations.
+
+ These functions build small convolutional models, compute filter scores using
+ variance-based importances, and plot the resulting weight maps and activations.
+ They are intended for exploratory analysis and example scripts.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Iterable
+
+ import matplotlib.pyplot as plt
+ import matplotlib.colors as mcolors
+ import numpy as np
+ from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
+ from tensorflow.keras.models import Sequential
+
+ logger = logging.getLogger(__name__)
+
+ CMAP = mcolors.LinearSegmentedColormap.from_list("white_blue_black", ["white", "blue", "black"])
+ WEIGHT_CMAP = plt.cm.seismic
+
+
+ def build_model(input_shape: tuple[int, int, int], kernel_size: tuple[int, int]) -> Sequential:
+     """Return a simple Conv2D model for visualization experiments.
+
+     Parameters
+     ----------
+     input_shape:
+         Shape of the input images, e.g. ``(28, 28, 1)``.
+     kernel_size:
+         Size of the convolution kernel.
+     """
+     model = Sequential(
+         [
+             Conv2D(8, kernel_size, activation="relu", input_shape=input_shape),
+             MaxPooling2D((2, 2)),
+             Flatten(),
+             Dense(10, activation="softmax"),
+         ]
+     )
+     model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
+     return model
+
+
+ def _threshold_filters(weights: np.ndarray, threshold: float) -> np.ndarray:
+     """Zero out weights whose magnitude falls below the threshold."""
+     mask = np.abs(weights) >= threshold
+     return np.where(mask, weights, 0.0)
+
+
+ def compute_filter_scores(
+     weights: np.ndarray, heatmap: np.ndarray, threshold: float
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Compute filter importance scores.
+
+     Filters are thresholded and multiplied by the variance heatmap before
+     summing over all spatial dimensions.
+     """
+     thr_weights = _threshold_filters(weights, threshold)
+     scores = np.sum(np.abs(thr_weights) * heatmap[..., None], axis=(0, 1, 2))
+     return scores.astype(float), thr_weights
+
+
+ def rank_filters(weights: np.ndarray, heatmap: np.ndarray, threshold: float) -> np.ndarray:
+     """Return filter indices sorted by descending importance."""
+     scores, _ = compute_filter_scores(weights, heatmap, threshold)
+     order = np.argsort(scores)[::-1]
+     logger.info("Filter scores: %s", scores.tolist())
+     return order
+
+
+ def accuracy_with_filters(
+     model: Sequential,
+     x: np.ndarray,
+     y: np.ndarray,
+     indices: Iterable[int],
+ ) -> float:
+     """Return accuracy when only selected filters are active."""
+     conv = model.layers[0]
+     original = conv.get_weights()
+     weights = original[0].copy()
+     bias = original[1].copy()
+     mask = np.zeros(weights.shape[-1], dtype=bool)
+     mask[list(indices)] = True
+     weights[..., ~mask] = 0.0
+     bias[~mask] = 0.0
+     conv.set_weights([weights, bias])
+     preds = np.argmax(model.predict(x, verbose=0), axis=1)
+     acc = float(np.mean(preds == np.argmax(y, axis=1)))
+     conv.set_weights(original)
+     return acc
+
+
+ def plot_filters(
+     weights: np.ndarray,
+     heatmap: np.ndarray,
+     example_out: np.ndarray,
+     order: Iterable[int],
+ ) -> None:
+     """Display filter weights, importances, and outputs."""
+     n_filters = weights.shape[-1]
+     vmax = float(np.max(np.abs(weights)))
+     fig, axes = plt.subplots(n_filters, 3, figsize=(9, 3 * n_filters))
+     for row, idx in enumerate(order):
+         ax_w = axes[row, 0]
+         ax_i = axes[row, 1]
+         ax_o = axes[row, 2]
+         im_w = ax_w.imshow(weights[:, :, 0, idx], cmap=WEIGHT_CMAP, vmin=-vmax, vmax=vmax)
+         ax_w.set_title(f"Filter {idx} weights")
+         ax_w.axis("off")
+         fig.colorbar(im_w, ax=ax_w)
+         im_i = ax_i.imshow(heatmap[:, :, 0], cmap=CMAP, vmin=0.0, vmax=1.0)
+         ax_i.set_title("Importance")
+         ax_i.axis("off")
+         fig.colorbar(im_i, ax=ax_i)
+         im_o = ax_o.imshow(example_out[:, :, idx], cmap="gray", vmin=0.0, vmax=np.max(example_out))
+         ax_o.set_title("Filter output")
+         ax_o.axis("off")
+         fig.colorbar(im_o, ax=ax_o)
+     plt.tight_layout()
+     plt.show()
neural_feature_importance-0.9.1/neural_feature_importance/__init__.py
@@ -7,6 +7,11 @@ from .callbacks import (
      VarianceImportanceKeras,
      VarianceImportanceTorch,
  )
+ from .conv_callbacks import ConvVarianceImportanceKeras, ConvVarianceImportanceTorch
+ from .embedding_callbacks import (
+     EmbeddingVarianceImportanceKeras,
+     EmbeddingVarianceImportanceTorch,
+ )
  from .utils import MetricThreshold

  try:
@@ -19,4 +24,8 @@ __all__ = [
      "VarianceImportanceKeras",
      "VarianceImportanceTorch",
      "MetricThreshold",
+     "ConvVarianceImportanceKeras",
+     "ConvVarianceImportanceTorch",
+     "EmbeddingVarianceImportanceKeras",
+     "EmbeddingVarianceImportanceTorch",
  ]
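With the expanded `__all__`, the new callbacks resolve from the package root as well as from their submodules. A quick check, assuming the optional TensorFlow dependency is installed:

```python
# Both import paths refer to the same class object after this change.
from neural_feature_importance import ConvVarianceImportanceKeras
from neural_feature_importance.conv_callbacks import (
    ConvVarianceImportanceKeras as ConvCallbackDirect,
)

assert ConvVarianceImportanceKeras is ConvCallbackDirect
```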
neural_feature_importance-0.9.1/neural_feature_importance/callbacks.py
@@ -1,4 +1,10 @@
- """Variance-based feature importance utilities."""
+ """Utilities for computing variance-based feature importances.
+
+ These classes track the weights of the first trainable layer during training
+ and estimate feature importances by accumulating the variance of each weight
+ value. After training, the variances are combined with the last observed
+ weights to produce a normalized importance score for every input feature.
+ """

  from __future__ import annotations

@@ -13,7 +19,13 @@ logger = logging.getLogger(__name__)


  class VarianceImportanceBase:
-     """Compute feature importance using Welford's algorithm."""
+     """Compute feature importances using running variance statistics.
+
+     The class implements Welford's algorithm to accumulate the variance of
+     weight values over training iterations. Feature importances are derived by
+     combining the final variance estimates with the absolute value of the last
+     observed weights.
+     """

      def __init__(self) -> None:
          self._n = 0
@@ -23,13 +35,21 @@ class VarianceImportanceBase:
          self.var_scores: np.ndarray | None = None

      def start(self, weights: np.ndarray) -> None:
-         """Initialize statistics for the given weight matrix."""
+         """Initialize running statistics.
+
+         Parameters
+         ----------
+         weights:
+             Initial weight matrix of shape ``(features, outputs)``. The values
+             are converted to ``float64`` for numerical stability and the running
+             mean and variance buffers are reset.
+         """
          self._mean = weights.astype(np.float64)
          self._m2 = np.zeros_like(self._mean)
          self._n = 0

      def update(self, weights: np.ndarray) -> None:
-         """Update running statistics with new weights."""
+         """Update running mean and variance using new weights."""
          if self._mean is None or self._m2 is None:
              return
          self._n += 1
@@ -40,7 +60,7 @@ class VarianceImportanceBase:
          self._last_weights = weights

      def finalize(self) -> None:
-         """Finalize statistics and compute normalized scores."""
+         """Compute normalized importance scores from accumulated statistics."""
          if self._last_weights is None or self._m2 is None:
              logger.warning(
                  "%s was not fully initialized; no scores computed", self.__class__.__name__
@@ -67,6 +87,7 @@ class VarianceImportanceBase:
          return self.var_scores


+
  class VarianceImportanceKeras(Callback, VarianceImportanceBase):
      """Keras callback implementing variance-based feature importance."""

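The docstrings above lean on Welford's algorithm without restating it. For reference, a self-contained sketch of the running-variance update that `start`/`update`/`finalize` perform on weight snapshots (the names here are illustrative, not the package's internals):

```python
import numpy as np

def welford_update(n: int, mean: np.ndarray, m2: np.ndarray, w: np.ndarray):
    """One Welford step: fold a new weight snapshot into the running stats."""
    n += 1
    delta = w - mean
    mean = mean + delta / n
    m2 = m2 + delta * (w - mean)  # uses the already-updated mean
    return n, mean, m2

rng = np.random.default_rng(0)
n, mean, m2 = 0, np.zeros((4, 3)), np.zeros((4, 3))
for _ in range(1000):
    n, mean, m2 = welford_update(n, mean, m2, rng.normal(size=(4, 3)))

variance = m2 / (n - 1)   # sample variance per weight, as in finalize()
print(variance.round(2))  # close to 1.0 everywhere for unit-normal draws
```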
neural_feature_importance-0.9.1/neural_feature_importance/conv_callbacks.py
@@ -0,0 +1,104 @@
+ """Callbacks that extend variance tracking to convolutional layers."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Optional
+
+ import numpy as np
+
+ from .callbacks import VarianceImportanceKeras, VarianceImportanceTorch
+
+ logger = logging.getLogger(__name__)
+
+
+ def _flatten_weights(weights: np.ndarray, outputs_last: bool) -> np.ndarray:
+     """Return a two-dimensional view of convolutional kernels.
+
+     Parameters
+     ----------
+     weights:
+         Weight tensor from a convolutional layer. Expected shape is
+         ``(H, W, in_channels, out_channels)`` when ``outputs_last`` is ``True``
+         and ``(out_channels, in_channels, H, W)`` otherwise.
+     outputs_last:
+         Whether the output dimension is the last axis of ``weights``.
+
+     Returns
+     -------
+     np.ndarray
+         Array of shape ``(features, outputs)`` suitable for variance tracking.
+     """
+     if weights.ndim > 2:
+         if outputs_last:
+             return weights.reshape(-1, weights.shape[-1])
+         return weights.reshape(weights.shape[0], -1).T
+     return weights
+
+
+ class ConvVarianceImportanceKeras(VarianceImportanceKeras):
+     """Keras callback that tracks convolutional kernels.
+
+     The first trainable layer is inspected and, if its weights have more than
+     two dimensions, they are flattened so that each spatial location and input
+     channel is treated as a separate feature. Variances are accumulated during
+     training and converted to per-filter importance scores.
+     """
+     def on_train_begin(self, logs: Optional[dict] = None) -> None:
+         self._layer = None
+         for layer in self.model.layers:
+             has_vars = bool(layer.trainable_weights)
+             has_data = bool(layer.get_weights())
+             if has_vars and has_data:
+                 self._layer = layer
+                 break
+         if self._layer is None:
+             raise ValueError("Model does not contain trainable weights.")
+         weights = self._layer.get_weights()[0]
+         weights = _flatten_weights(weights, outputs_last=True)
+         logger.info(
+             "Tracking variance for layer '%s' with %d features",
+             self._layer.name,
+             weights.shape[0],
+         )
+         self.start(weights)
+
+     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
+         if self._layer is None:
+             return
+         weights = self._layer.get_weights()[0]
+         weights = _flatten_weights(weights, outputs_last=True)
+         self.update(weights)
+
+
+ class ConvVarianceImportanceTorch(VarianceImportanceTorch):
+     """PyTorch helper with convolutional support.
+
+     Works analogously to :class:`ConvVarianceImportanceKeras` but for models
+     built with :mod:`torch.nn`. The first trainable parameter with two or more
+     dimensions is flattened so each spatial position becomes a tracked feature.
+     """
+     def on_train_begin(self) -> None:
+         from torch import nn
+
+         for name, param in self.model.named_parameters():
+             if param.requires_grad and param.dim() >= 2:
+                 self._param = param
+                 weights = param.detach().cpu().numpy()
+                 weights = _flatten_weights(weights, outputs_last=False)
+                 logger.info(
+                     "Tracking variance for parameter '%s' with %d features",
+                     name,
+                     weights.shape[0],
+                 )
+                 self.start(weights)
+                 break
+         if self._param is None:
+             raise ValueError("Model does not contain trainable parameters")
+
+     def on_epoch_end(self) -> None:
+         if self._param is None:
+             return
+         weights = self._param.detach().cpu().numpy()
+         weights = _flatten_weights(weights, outputs_last=False)
+         self.update(weights)
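The two kernel layouts handled by `_flatten_weights` are easy to confuse; a quick shape check makes the mapping concrete (the kernel sizes are arbitrary):

```python
import numpy as np

# Keras Conv2D kernels put outputs last: (H, W, in_channels, out_channels).
keras_kernel = np.zeros((3, 3, 1, 8))
assert keras_kernel.reshape(-1, keras_kernel.shape[-1]).shape == (9, 8)

# PyTorch Conv2d kernels put outputs first: (out_channels, in_channels, H, W),
# so flatten to (out, features) and transpose to the common (features, out).
torch_kernel = np.zeros((8, 1, 3, 3))
assert torch_kernel.reshape(torch_kernel.shape[0], -1).T.shape == (9, 8)
```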
neural_feature_importance-0.9.1/neural_feature_importance/embedding_callbacks.py
@@ -0,0 +1,85 @@
+ """Callbacks that compute variance-based importance for embedding layers.
+
+ These callbacks extend :class:`~neural_feature_importance.callbacks.VarianceImportanceBase`
+ to operate on 2-D embedding matrices. The variance of each embedding vector is
+ accumulated over training and the resulting per-token scores are normalized
+ between 0 and 1.
+ """
+
+ from __future__ import annotations
+
+ import logging
+
+ import numpy as np
+
+ from .callbacks import VarianceImportanceKeras, VarianceImportanceTorch
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingVarianceImportanceKeras(VarianceImportanceKeras):
+     """Variance-based importance callback for Keras embedding layers.
+
+     During training this callback monitors the weights of the first trainable
+     layer (expected to be an :class:`~tensorflow.keras.layers.Embedding`) and
+     accumulates the running variance of each embedding vector. After training the
+     variances are summed across the embedding dimension to yield a single score
+     per token.
+     """
+
+     def finalize(self) -> None:  # type: ignore[override]
+         if self._last_weights is None or self._m2 is None:
+             logger.warning(
+                 "%s was not fully initialized; no scores computed",
+                 self.__class__.__name__,
+             )
+             return
+
+         if self._n < 2:
+             variance = np.full_like(self._m2, np.nan)
+         else:
+             variance = self._m2 / (self._n - 1)
+
+         scores = np.sum(variance, axis=1)
+         min_val = float(np.nanmin(scores))
+         max_val = float(np.nanmax(scores))
+         denom = max_val - min_val if max_val != min_val else 1.0
+         self.var_scores = (scores - min_val) / denom
+
+         top = np.argsort(self.var_scores)[-10:][::-1]
+         logger.info("Most important tokens: %s", top)
+
+
+ class EmbeddingVarianceImportanceTorch(VarianceImportanceTorch):
+     """Variance-based importance for PyTorch embedding layers.
+
+     Parameters
+     ----------
+     model:
+         Neural network containing an :class:`torch.nn.Embedding` layer whose
+         weights will be monitored.
+     """
+
+     def finalize(self) -> None:  # type: ignore[override]
+         if self._last_weights is None or self._m2 is None:
+             logger.warning(
+                 "%s was not fully initialized; no scores computed",
+                 self.__class__.__name__,
+             )
+             return
+
+         if self._n < 2:
+             variance = np.full_like(self._m2, np.nan)
+         else:
+             variance = self._m2 / (self._n - 1)
+
+         scores = np.sum(variance, axis=1)
+         min_val = float(np.nanmin(scores))
+         max_val = float(np.nanmax(scores))
+         denom = max_val - min_val if max_val != min_val else 1.0
+         self.var_scores = (scores - min_val) / denom
+
+         top = np.argsort(self.var_scores)[-10:][::-1]
+         logger.info("Most important tokens: %s", top)
+
+
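The overridden `finalize` differs from the base class only in how it reduces the variance matrix: one score per token rather than per weight. A toy trace of that reduction (the array sizes are illustrative):

```python
import numpy as np

# Toy Welford accumulators for a 4-token vocabulary with 3-dim embeddings:
# m2 is the sum of squared deviations, n the number of epoch-end updates.
n = 11
m2 = np.array([[0.0, 0.0, 0.0],
               [1.0, 2.0, 3.0],
               [0.5, 0.5, 0.5],
               [4.0, 4.0, 4.0]])

variance = m2 / (n - 1)        # per-weight sample variance
scores = variance.sum(axis=1)  # collapse the embedding dim: one score per token
scores = (scores - scores.min()) / (scores.max() - scores.min())
print(scores)  # [0.    0.5   0.125 1.   ] -- token 3 moved most during training
```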
neural_feature_importance-0.9.1/neural_feature_importance.egg-info/PKG-INFO
@@ -1,20 +1,23 @@
  Metadata-Version: 2.4
  Name: neural-feature-importance
- Version: 0.5.2
+ Version: 0.9.1
  Summary: Variance-based feature importance for Neural Networks using callbacks for Keras and PyTorch
  Author: CR de Sá
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
+ License-File: LICENSE
  Requires-Dist: numpy
  Provides-Extra: tensorflow
  Requires-Dist: tensorflow; extra == "tensorflow"
  Provides-Extra: torch
  Requires-Dist: torch; extra == "torch"
+ Dynamic: license-file

  # neural-feature-importance

  [![PyPI version](https://img.shields.io/pypi/v/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
- [![Python versions](https://img.shields.io/pypi/pyversions/neural-feature-importance.svg)](https://pypi.org/project/neural-feature-importance/)
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE)

  Variance-based feature importance for deep learning models.

@@ -74,19 +77,44 @@ print(tracker.feature_importances_)

  ## Example scripts

- Run `compare_feature_importance.py` to train a small network on the Iris dataset
+ Run `scripts/compare_feature_importance.py` to train a small network on the Iris dataset
  and compare the scores with a random forest baseline:

  ```bash
  python compare_feature_importance.py
  ```

- Run `full_experiment.py` to reproduce the experiments from the paper:
+ Run `scripts/full_experiment.py` to reproduce the experiments from the paper:

  ```bash
  python full_experiment.py
  ```

+ ### Convolutional models
+
+ To compute importances for convolutional networks, use
+ `ConvVarianceImportanceKeras` from `neural_feature_importance.conv_callbacks`.
+ `scripts/conv_visualization_example.py` trains small Conv2D models on the MNIST
+ and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent
+ notebook is available in ``notebooks/conv_visualization_example.ipynb``:
+
+ ```bash
+ python scripts/conv_visualization_example.py
+ ```
+
+ ### Embedding layers
+
+ To compute token importances from embedding weights, use
+ `EmbeddingVarianceImportanceKeras` or `EmbeddingVarianceImportanceTorch` from
+ `neural_feature_importance.embedding_callbacks`.
+ Run `scripts/token_importance_topk_example.py` to train a small text classifier
+ on IMDB and display the most important tokens. A matching notebook lives in
+ ``notebooks/token_importance_topk_example.ipynb``:
+
+ ```bash
+ python scripts/token_importance_topk_example.py
+ ```
+
  ## Development

  After making changes, run the following checks:
@@ -124,3 +152,7 @@ If you use this package in your research, please cite:
  ```

  We appreciate citations as they help the community discover this work.
+
+ ## License
+
+ This project is licensed under the [MIT License](LICENSE).
neural_feature_importance-0.9.1/neural_feature_importance.egg-info/SOURCES.txt
@@ -0,0 +1,27 @@
+ AGENTS.md
+ LICENSE
+ README.md
+ conv_visualization_example.py
+ conv_viz_utils.py
+ pyproject.toml
+ text_classification_example.py
+ token_topk_utils.py
+ .github/workflows/python-publish.yml
+ neural_feature_importance/__init__.py
+ neural_feature_importance/callbacks.py
+ neural_feature_importance/conv_callbacks.py
+ neural_feature_importance/embedding_callbacks.py
+ neural_feature_importance.egg-info/PKG-INFO
+ neural_feature_importance.egg-info/SOURCES.txt
+ neural_feature_importance.egg-info/dependency_links.txt
+ neural_feature_importance.egg-info/requires.txt
+ neural_feature_importance.egg-info/top_level.txt
+ neural_feature_importance/utils/__init__.py
+ neural_feature_importance/utils/monitors.py
+ notebooks/conv_visualization_example.ipynb
+ notebooks/token_importance_topk_example.ipynb
+ notebooks/variance-based feature importance in artificial neural networks.ipynb
+ scripts/compare_feature_importance.py
+ scripts/conv_visualization_example.py
+ scripts/full_experiment.py
+ scripts/token_importance_topk_example.py
neural_feature_importance-0.9.1/notebooks/conv_visualization_example.ipynb
@@ -0,0 +1 @@
+ {"cells": [{"cell_type": "markdown", "metadata": {}, "source": "# Convolutional Filter Visualization"}, {"cell_type": "code", "metadata": {}, "source": "\nimport numpy as np\nfrom tensorflow.keras.datasets import mnist\nfrom tensorflow.keras.utils import to_categorical\nfrom sklearn.datasets import load_digits\nfrom sklearn.model_selection import train_test_split\nfrom tensorflow.keras.models import Sequential\n\nfrom neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras\nfrom conv_viz_utils import build_model, rank_filters, plot_filters, accuracy_with_filters\n\ndef load_mnist():\n    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n    x_train = x_train.astype(\"float32\") / 255.0\n    x_test = x_test.astype(\"float32\") / 255.0\n    x_train = x_train[..., None]\n    x_test = x_test[..., None]\n    y_train = to_categorical(y_train, 10)\n    y_test = to_categorical(y_test, 10)\n    return (x_train, y_train), (x_test, y_test), (28, 28, 1), (8, 8)\n\ndef load_digits_data():\n    digits = load_digits()\n    x = digits.images[..., None] / 16.0\n    y = to_categorical(digits.target, 10)\n    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)\n    return (x_train, y_train), (x_test, y_test), (8, 8, 1), (3, 3)\n\nDATASETS = {\n    'mnist': load_mnist,\n    'digits': load_digits_data,\n}\n\ndef run_dataset(loader):\n    (x_train, y_train), (x_test, y_test), input_shape, kernel_size = loader()\n    model = build_model(input_shape, kernel_size)\n    callback = ConvVarianceImportanceKeras()\n    model.fit(x_train, y_train, epochs=5, batch_size=32, callbacks=[callback], verbose=1)\n\n    scores = callback.feature_importances_\n    weights = model.layers[0].get_weights()[0]\n    heatmap = scores.reshape(weights.shape[:3])\n    order = rank_filters(weights, heatmap, 0.0)\n\n    conv_model = Sequential([model.layers[0]])\n    example_out = conv_model.predict(x_test[:1], verbose=0)[0]\n    plot_filters(weights, heatmap, example_out, order)\n\n    for k in (2, 4, 6):\n        acc = accuracy_with_filters(model, x_test, y_test, order[:k])\n        print(f'Accuracy with top {k} filters:', acc)\n\nfor name, loader in DATASETS.items():\n    print('Running', name)\n    run_dataset(loader)\n"}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
neural_feature_importance-0.9.1/notebooks/token_importance_topk_example.ipynb
@@ -0,0 +1,9 @@
+ {
+ "cells": [
+ {"cell_type": "markdown", "metadata": {}, "source": "# Token Importance Top-k Example"},
+ {"cell_type": "code", "metadata": {}, "source": "from tensorflow.keras.datasets import imdb\nfrom tensorflow.keras.preprocessing.sequence import pad_sequences\n\nfrom neural_feature_importance.embedding_callbacks import EmbeddingVarianceImportanceKeras\nfrom token_topk_utils import analyze_samples, build_model\n\nMAX_FEATURES = 5000\nMAX_LEN = 400\nTOP_K = 5\nNUM_SAMPLES = 5\n\n(x_train, y_train), _ = imdb.load_data(num_words=MAX_FEATURES)\nx_train = pad_sequences(x_train, maxlen=MAX_LEN)\n\nmodel = build_model()\ncallback = EmbeddingVarianceImportanceKeras()\nmodel.fit(x_train, y_train, epochs=2, batch_size=128, callbacks=[callback], verbose=0)\n\nscores = callback.feature_importances_\nword_index = imdb.get_word_index()\nindex_word = {v + 3: k for k, v in word_index.items()}\nindex_word[0] = \"<PAD>\"\nindex_word[1] = \"<START>\"\nindex_word[2] = \"<UNK>\"\nindex_word[3] = \"<UNUSED>\"\n\nanalyze_samples(x_train, y_train, scores, index_word, NUM_SAMPLES, TOP_K)"}
+ ],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
neural_feature_importance-0.9.1/pyproject.toml
@@ -15,3 +15,6 @@ requires-python = ">=3.10"
  tensorflow = ["tensorflow"]
  torch = ["torch"]
  [tool.setuptools_scm]
+
+ [tool.setuptools]
+ packages = ["neural_feature_importance"]
neural_feature_importance-0.9.1/scripts/conv_visualization_example.py
@@ -0,0 +1,114 @@
+ """Visualize variance-based filter importances on several datasets."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Callable, Tuple
+
+ import numpy as np
+ from tensorflow.keras.datasets import mnist
+ from tensorflow.keras.utils import to_categorical
+ from sklearn.datasets import load_digits
+ from sklearn.model_selection import train_test_split
+ from tensorflow.keras.models import Sequential
+
+ from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras
+ from conv_viz_utils import build_model, rank_filters, plot_filters, accuracy_with_filters
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ DatasetLoader = Callable[
+     [],
+     Tuple[
+         Tuple[np.ndarray, np.ndarray],
+         Tuple[np.ndarray, np.ndarray],
+         tuple[int, int, int],
+         tuple[int, int],
+     ],
+ ]
+
+
+ def load_mnist() -> Tuple[
+     Tuple[np.ndarray, np.ndarray],
+     Tuple[np.ndarray, np.ndarray],
+     tuple[int, int, int],
+     tuple[int, int],
+ ]:
+     """Return MNIST data and model parameters."""
+     (x_train, y_train), (x_test, y_test) = mnist.load_data()
+     x_train = x_train.astype("float32") / 255.0
+     x_test = x_test.astype("float32") / 255.0
+     x_train = x_train[..., None]
+     x_test = x_test[..., None]
+     y_train = to_categorical(y_train, 10)
+     y_test = to_categorical(y_test, 10)
+     return (x_train, y_train), (x_test, y_test), (28, 28, 1), (8, 8)
+
+
+ def load_digits_data() -> Tuple[
+     Tuple[np.ndarray, np.ndarray],
+     Tuple[np.ndarray, np.ndarray],
+     tuple[int, int, int],
+     tuple[int, int],
+ ]:
+     """Return scikit-learn digits data and model parameters."""
+     digits = load_digits()
+     x = digits.images[..., None] / 16.0
+     y = to_categorical(digits.target, 10)
+     x_train, x_test, y_train, y_test = train_test_split(
+         x, y, test_size=0.2, random_state=42
+     )
+     return (x_train, y_train), (x_test, y_test), (8, 8, 1), (3, 3)
+
+
+ DATASETS: dict[str, DatasetLoader] = {
+     "mnist": load_mnist,
+     "digits": load_digits_data,
+ }
+
+
+ def run_dataset(name: str, loader: DatasetLoader, threshold: float = 0.0) -> None:
+     """Train a model on the given dataset and display filter importances."""
+     (x_train, y_train), (x_test, y_test), input_shape, kernel_size = loader()
+
+     model = build_model(input_shape, kernel_size)
+     callback = ConvVarianceImportanceKeras()
+     model.fit(
+         x_train,
+         y_train,
+         epochs=5,
+         batch_size=32,
+         callbacks=[callback],
+         verbose=1,
+     )
+
+     scores = callback.feature_importances_
+     if scores is None:
+         logger.warning("No importance scores computed for %s", name)
+         return
+
+     weights = model.layers[0].get_weights()[0]
+     heatmap = scores.reshape(weights.shape[:3])
+     order = rank_filters(weights, heatmap, threshold)
+
+     conv_model = Sequential([model.layers[0]])
+     example_out = conv_model.predict(x_test[:1], verbose=0)[0]
+     plot_filters(weights, heatmap, example_out, order)
+
+     results = {}
+     for k in (2, 4, 6):
+         acc = accuracy_with_filters(model, x_test, y_test, order[:k])
+         results[k] = acc
+     logger.info("Accuracy with top filters on %s: %s", name, results)
+
+
+ def main() -> None:
+     """Run visualization on all configured datasets."""
+     for name, loader in DATASETS.items():
+         logger.info("Running visualization for %s", name)
+         run_dataset(name, loader)
+
+
+ if __name__ == "__main__":
+     main()
neural_feature_importance-0.9.1/scripts/token_importance_topk_example.py
@@ -0,0 +1,50 @@
+ """Display top-k token importances for several IMDB samples."""
+
+ from __future__ import annotations
+
+ import logging
+
+ from tensorflow.keras.datasets import imdb
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ from neural_feature_importance.embedding_callbacks import (
+     EmbeddingVarianceImportanceKeras,
+ )
+
+ from token_topk_utils import analyze_samples, build_model
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ MAX_FEATURES = 5000
+ MAX_LEN = 400
+ TOP_K = 5
+ NUM_SAMPLES = 5
+
+
+ def main() -> None:
+     """Train a model and log top tokens for a few samples."""
+     (x_train, y_train), _ = imdb.load_data(num_words=MAX_FEATURES)
+     x_train = pad_sequences(x_train, maxlen=MAX_LEN)
+
+     model = build_model()
+     callback = EmbeddingVarianceImportanceKeras()
+     model.fit(x_train, y_train, epochs=2, batch_size=128, callbacks=[callback], verbose=0)
+
+     scores = callback.feature_importances_
+     if scores is None:
+         logger.warning("No importance scores computed.")
+         return
+
+     word_index = imdb.get_word_index()
+     index_word = {v + 3: k for k, v in word_index.items()}
+     index_word[0] = "<PAD>"
+     index_word[1] = "<START>"
+     index_word[2] = "<UNK>"
+     index_word[3] = "<UNUSED>"
+
+     analyze_samples(x_train, y_train, scores, index_word, NUM_SAMPLES, TOP_K)
+
+
+ if __name__ == "__main__":
+     main()
neural_feature_importance-0.9.1/text_classification_example.py
@@ -0,0 +1,66 @@
+ """Example of variance-based importance with text classification."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Tuple
+
+ import matplotlib.pyplot as plt
+ from tensorflow.keras.datasets import imdb
+ from tensorflow.keras.layers import Conv1D, Dense, Embedding, GlobalMaxPooling1D
+ from tensorflow.keras.models import Sequential
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ MAX_FEATURES = 5000
+ MAX_LEN = 400
+
+
+ def load_data() -> Tuple[tuple, tuple]:
+     """Return padded IMDB data."""
+     (x_train, y_train), _ = imdb.load_data(num_words=MAX_FEATURES)
+     x_train = pad_sequences(x_train, maxlen=MAX_LEN)
+     return (x_train, y_train), _
+
+
+ def build_model() -> Sequential:
+     """Return a small Conv1D model."""
+     model = Sequential(
+         [
+             Embedding(MAX_FEATURES, 32, input_length=MAX_LEN, trainable=False),
+             Conv1D(16, 5, activation="relu"),
+             GlobalMaxPooling1D(),
+             Dense(1, activation="sigmoid"),
+         ]
+     )
+     model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
+     return model
+
+
+ def main() -> None:
+     """Train the model and plot a heatmap of importances."""
+     (x_train, y_train), _ = load_data()
+     model = build_model()
+     callback = ConvVarianceImportanceKeras()
+     model.fit(x_train, y_train, epochs=2, batch_size=128, callbacks=[callback], verbose=0)
+
+     scores = callback.feature_importances_
+     if scores is None:
+         logger.warning("No importance scores computed.")
+         return
+
+     weights = model.layers[1].get_weights()[0]
+     heatmap = scores.reshape(weights.shape[0], weights.shape[1])
+     plt.imshow(heatmap, aspect="auto", cmap="hot")
+     plt.colorbar()
+     plt.title("Conv1D feature importance")
+     plt.show()
+
+
+ if __name__ == "__main__":
+     main()
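One subtlety in the script above: because the `Embedding` layer is frozen (`trainable=False`), the callback's layer search skips it and tracks the `Conv1D` kernel instead, which is why the plot indexes `model.layers[1]` and reshapes to a (kernel position, embedding dimension) grid. A quick shape check with the dimensions taken from the script:

```python
import numpy as np

# Conv1D(16, 5) over 32-dim embeddings has a kernel of shape (5, 32, 16).
kernel = np.zeros((5, 32, 16))
flat = kernel.reshape(-1, kernel.shape[-1])  # (160, 16): 160 tracked features
scores = np.zeros(flat.shape[0])             # shape of feature_importances_
heatmap = scores.reshape(5, 32)              # kernel position x embedding dim
```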
neural_feature_importance-0.9.1/token_topk_utils.py
@@ -0,0 +1,124 @@
+ """Utilities for analyzing top-k token importances.
+
+ These helpers build a small text classification model and provide functions to
+ decode token sequences and print tables of the most important tokens according
+ to variance-based scores.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from collections import Counter
+ from typing import Iterable, List
+
+ import numpy as np
+ from tensorflow.keras.layers import (
+     Conv1D,
+     Dense,
+     Embedding,
+     GlobalMaxPooling1D,
+ )
+ from tensorflow.keras.models import Sequential
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ MAX_FEATURES = 5000
+ MAX_LEN = 400
+
+
+ def build_model() -> Sequential:
+     """Return a small text classification model.
+
+     The model consists of an embedding layer followed by a single convolution
+     and a global max pooling operation. It is intentionally tiny so it can be
+     trained quickly on the IMDB dataset for demonstration purposes.
+     """
+     model = Sequential(
+         [
+             Embedding(MAX_FEATURES, 32, input_length=MAX_LEN, trainable=True),
+             Conv1D(1, 5, activation="relu"),
+             GlobalMaxPooling1D(),
+             Dense(1, activation="sigmoid"),
+         ]
+     )
+     model.compile(
+         optimizer="adam",
+         loss="binary_crossentropy",
+         metrics=["accuracy"],
+     )
+     model.build((None, MAX_LEN))
+     return model
+
+
+ def decode_review(tokens: Iterable[int], index_word: dict[int, str]) -> str:
+     """Return a readable string for the given token sequence."""
+     words = [index_word.get(t, "?") for t in tokens if t]
+     return " ".join(words)
+
+
+ def summarize_top_tokens(
+     tokens: Iterable[int],
+     scores: np.ndarray,
+     index_word: dict[int, str],
+     k: int,
+ ) -> str:
+     """Return a table with the top ``k`` unique tokens and their counts.
+
+     Parameters
+     ----------
+     tokens:
+         Token ids representing a review.
+     scores:
+         Array of token importance scores obtained from the embedding callback.
+     index_word:
+         Mapping from token id to the corresponding word.
+     k:
+         Number of tokens to include in the table.
+     """
+     ignore = {"<PAD>", "<START>", "<UNK>", "<UNUSED>"}
+     totals: dict[int, float] = {}
+     counts: Counter[int] = Counter()
+     for t in tokens:
+         if index_word.get(t) in ignore:
+             continue
+         if t < len(scores):
+             totals[t] = totals.get(t, 0.0) + float(scores[t])
+             counts[t] += 1
+     ordered = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:k]
+     headers = ["Token", "Count", "Score"]
+     rows: List[tuple[str, int, str]] = []
+     for token_id, score in ordered:
+         token = index_word.get(token_id, "?")
+         if token in ignore:
+             continue
+         rows.append((token, counts[token_id], f"{score:.3f}"))
+     table_lines = [" | ".join(headers)]
+     table_lines.append("-|-".join("-" * len(h) for h in headers))
+     for row in rows:
+         table_lines.append(" | ".join(str(x) for x in row))
+     return "\n".join(table_lines)
+
+
+ def analyze_samples(
+     x: np.ndarray,
+     y: np.ndarray,
+     scores: np.ndarray,
+     index_word: dict[int, str],
+     num_samples: int = 5,
+     k: int = 5,
+ ) -> None:
+     """Log original text and the most important tokens for several samples.
+
+     Each selected sample is decoded to text, its label is printed, and a table
+     of the top ``k`` tokens by importance score is logged.
+     """
+     for i in range(min(num_samples, len(x))):
+         tokens = x[i].tolist()
+         label = "positive" if y[i] == 1 else "negative"
+         text = decode_review(tokens, index_word)
+         table = summarize_top_tokens(tokens, scores, index_word, k)
+         logger.info("Example %d - class: %s", i, label)
+         logger.info("%s", text)
+         logger.info("Top tokens:\n%s", table)
+
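To see the table format `summarize_top_tokens` emits, a toy call with a hypothetical six-word vocabulary (ids 0-3 are the reserved specials it ignores; the words and scores are made up for illustration):

```python
import numpy as np
from token_topk_utils import summarize_top_tokens

index_word = {0: "<PAD>", 1: "<START>", 2: "<UNK>", 3: "<UNUSED>",
              4: "great", 5: "boring"}
scores = np.array([0.0, 0.0, 0.0, 0.0, 0.9, 0.4])
tokens = [1, 4, 5, 4, 0, 0]  # "great" appears twice, so its score sums to 1.8

print(summarize_top_tokens(tokens, scores, index_word, k=2))
# Token | Count | Score
# ------|-------|------
# great | 2 | 1.800
# boring | 1 | 0.400
```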
neural_feature_importance-0.5.2/neural_feature_importance.egg-info/SOURCES.txt
@@ -1,16 +0,0 @@
- AGENTS.md
- README.md
- compare_feature_importance.py
- full_experiment.py
- pyproject.toml
- variance-based feature importance in artificial neural networks.ipynb
- .github/workflows/python-publish.yml
- neural_feature_importance/__init__.py
- neural_feature_importance/callbacks.py
- neural_feature_importance.egg-info/PKG-INFO
- neural_feature_importance.egg-info/SOURCES.txt
- neural_feature_importance.egg-info/dependency_links.txt
- neural_feature_importance.egg-info/requires.txt
- neural_feature_importance.egg-info/top_level.txt
- neural_feature_importance/utils/__init__.py
- neural_feature_importance/utils/monitors.py