neural-feature-importance 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
+ """Utilities for variance-based feature importance in neural networks."""
2
+
3
+ from importlib import metadata
4
+
5
+ from .callbacks import (
6
+ VarianceImportanceBase,
7
+ VarianceImportanceKeras,
8
+ VarianceImportanceTorch,
9
+ )
10
+ from .utils import MetricThreshold
11
+
12
try:
    __version__ = metadata.version("neural-feature-importance")
except metadata.PackageNotFoundError:  # pragma: no cover - package not installed
    # No installed distribution metadata (e.g. running from a source checkout):
    # fall back to a development placeholder instead of failing at import time.
    __version__ = "0.0.dev0"

# Names exported by ``from <package> import *``; all are imported above.
__all__ = [
    "VarianceImportanceBase", "VarianceImportanceKeras", "VarianceImportanceTorch",
    "MetricThreshold",
]
@@ -0,0 +1,142 @@
1
+ """Variance-based feature importance utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ from tensorflow.keras.callbacks import Callback
10
+ from tensorflow.keras.layers import Layer
11
+
12
logger = logging.getLogger(__name__)


class VarianceImportanceBase:
    """Compute variance-based feature importance using Welford's algorithm.

    Lifecycle: :meth:`start` seeds the running statistics with an initial
    weight array, :meth:`update` is called once per observation (e.g. once per
    epoch) with a snapshot of the same array, and :meth:`finalize` turns the
    accumulated statistics into normalized scores stored in ``var_scores``.

    Rows of the weight array (axis 0) are treated as features. Each feature's
    raw score is ``sum(sample_variance * |last_weights|)`` over axis 1, and
    the scores are then min-max scaled to ``[0, 1]``.
    """

    def __init__(self) -> None:
        self._n = 0  # number of snapshots passed to update() since start()
        self._mean: np.ndarray | None = None  # running mean (float64)
        self._m2: np.ndarray | None = None  # running sum of squared deviations
        self._last_weights: np.ndarray | None = None  # private copy of the latest snapshot
        self.var_scores: np.ndarray | None = None  # set by finalize()

    def start(self, weights: np.ndarray) -> None:
        """Initialize statistics for the given weight matrix.

        The initial array only seeds the running mean; it is not counted as an
        observation (the count is reset to 0), so the first :meth:`update`
        replaces the mean with its own snapshot. Results from a previous run
        are cleared so a restarted run never reuses stale data.
        """
        self._mean = weights.astype(np.float64)
        self._m2 = np.zeros_like(self._mean)
        self._n = 0
        self._last_weights = None
        self.var_scores = None

    def update(self, weights: np.ndarray) -> None:
        """Update running statistics with a new weight snapshot.

        Calls made before :meth:`start` are ignored.
        """
        if self._mean is None or self._m2 is None:
            return
        # Take a float64 copy. The caller's array may share memory with a live
        # framework parameter (a NumPy view of a CPU tensor, for instance), and
        # later in-place changes must not alter what was recorded.
        sample = np.array(weights, dtype=np.float64)
        self._n += 1
        delta = sample - self._mean
        self._mean += delta / self._n
        delta2 = sample - self._mean
        self._m2 += delta * delta2
        self._last_weights = sample

    def finalize(self) -> None:
        """Finalize statistics and compute normalized scores.

        Sets ``var_scores``. If fewer than two snapshots were recorded, the
        sample variance is undefined and every score is NaN. If no snapshot
        was recorded, a warning is logged and ``var_scores`` is left unchanged.
        """
        if self._last_weights is None or self._m2 is None:
            logger.warning(
                "%s was not fully initialized; no scores computed", self.__class__.__name__
            )
            return

        if self._n < 2:
            logger.warning(
                "%s recorded %d update(s); at least 2 are needed for a variance "
                "estimate, so all scores are NaN",
                self.__class__.__name__,
                self._n,
            )
            variance = np.full_like(self._m2, np.nan)
        else:
            # Unbiased sample variance from Welford's accumulator.
            variance = self._m2 / (self._n - 1)

        scores = np.sum(variance * np.abs(self._last_weights), axis=1)
        min_val = float(np.min(scores))
        max_val = float(np.max(scores))
        # Avoid division by zero when all scores are equal (result is all zeros).
        denom = max_val - min_val if max_val != min_val else 1.0
        self.var_scores = (scores - min_val) / denom

        top = np.argsort(self.var_scores)[-10:][::-1]
        logger.info("Most important variables: %s", top)

    @property
    def feature_importances_(self) -> np.ndarray | None:
        """Normalized importance scores for each input feature (``None`` before finalize)."""
        return self.var_scores
68
+
69
+
70
class VarianceImportanceKeras(Callback, VarianceImportanceBase):
    """Keras callback implementing variance-based feature importance.

    In ``on_train_begin`` the callback picks the first layer in
    ``self.model.layers`` that has trainable weights and whose first weight
    array is at least 2-D. It then snapshots that array after every epoch.
    Rows of the array (axis 0) are treated as features; for a ``Dense`` kernel
    of shape ``(input_dim, units)`` this gives one score per input column.
    Scores are available via ``feature_importances_`` once training ends.
    """

    def __init__(self) -> None:
        Callback.__init__(self)
        VarianceImportanceBase.__init__(self)
        self._layer: Optional[Layer] = None  # layer whose first weight array is tracked

    def on_train_begin(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        """Select the layer to track and seed the running statistics.

        Raises
        ------
        ValueError
            If no layer has trainable weights whose first array is at least 2-D.
        """
        self._layer = None
        weights = None
        for layer in self.model.layers:
            # Skip frozen or weightless layers (e.g. Normalization, whose weights
            # are non-trainable statistics). Also skip layers whose first weight
            # is a vector, such as a BatchNormalization scale: ``finalize`` reduces
            # over axis 1 and would fail on a 1-D array.
            if not layer.trainable_weights:
                continue
            candidate = layer.get_weights()[0]
            if candidate.ndim >= 2:
                self._layer = layer
                weights = candidate
                break
        if self._layer is None or weights is None:
            raise ValueError("Model does not contain trainable weights of rank >= 2.")
        logger.info(
            "Tracking variance for layer '%s' with %d features",
            self._layer.name,
            weights.shape[0],
        )
        self.start(weights)

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        """Record a snapshot of the tracked layer's first weight array."""
        if self._layer is None:
            return
        weights = self._layer.get_weights()[0]
        self.update(weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        """Compute the final importance scores."""
        self.finalize()

    def get_config(self) -> dict:
        """Return configuration for serialization (the callback has no options)."""
        return {}
106
+
107
+
108
class VarianceImportanceTorch(VarianceImportanceBase):
    """Track variance-based feature importance for PyTorch models.

    Call :meth:`on_train_begin` once before training, :meth:`on_epoch_end`
    after every epoch and :meth:`on_train_end` when training finishes. The
    scores are then available as ``feature_importances_``.

    The first parameter that requires gradients and has at least two
    dimensions is tracked. ``nn.Linear`` and ``nn.ConvNd`` store weights as
    ``(out_features, in_features, ...)``, while the base class scores rows.
    Each snapshot is therefore rearranged to ``(in_features, -1)`` so that one
    score is produced per *input* feature.

    NOTE(review): layers with a different layout (e.g. ``nn.Embedding``,
    transposed convolutions) would have the other axis scored. This module
    also imports ``tensorflow`` at top level, so installing only the ``torch``
    extra may not be enough to import it. Confirm both.
    """

    def __init__(self, model: "nn.Module") -> None:
        from torch import nn  # Local import to avoid hard dependency (fails fast if missing)

        super().__init__()
        self.model = model
        self._param: nn.Parameter | None = None  # parameter being tracked

    def _snapshot(self) -> np.ndarray:
        """Return a private copy of the tracked weights laid out as ``(in_features, -1)``."""
        # ``.numpy()`` on a CPU tensor shares memory with the live parameter.
        # Copy it so optimizer steps cannot rewrite already-recorded snapshots.
        raw = np.array(self._param.detach().cpu().numpy(), copy=True)
        # Move the input-feature axis (dim 1) to the front and flatten the rest.
        return np.moveaxis(raw, 1, 0).reshape(raw.shape[1], -1)

    def on_train_begin(self) -> None:
        """Select the parameter to track and seed the running statistics.

        Raises
        ------
        ValueError
            If no parameter requires gradients and has at least two dimensions.
        """
        self._param = None
        tracked_name = ""
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.dim() >= 2:
                self._param = param
                tracked_name = name
                break
        if self._param is None:
            raise ValueError("Model does not contain trainable parameters")
        weights = self._snapshot()
        logger.info(
            "Tracking variance for parameter '%s' with %d features",
            tracked_name,
            weights.shape[0],
        )
        self.start(weights)

    def on_epoch_end(self) -> None:
        """Record a snapshot of the tracked parameter."""
        if self._param is None:
            return
        self.update(self._snapshot())

    def on_train_end(self) -> None:
        """Compute the final importance scores."""
        self.finalize()
@@ -0,0 +1,5 @@
1
+ """Utility callbacks for model training."""
2
+
3
+ from .monitors import MetricThreshold
4
+
5
# Public names re-exported by this subpackage.
__all__ = [
    "MetricThreshold",
]
@@ -0,0 +1,54 @@
1
+ """Training utilities for keras models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Optional
7
+
8
+ from tensorflow.keras.callbacks import Callback
9
+
10
logger = logging.getLogger(__name__)


class MetricThreshold(Callback):
    """Stop training once a monitored metric reaches a given threshold.

    Parameters
    ----------
    monitor:
        Name of the metric to monitor (e.g. ``"val_accuracy"`` or ``"loss"``).
    threshold:
        Value that the metric must reach to trigger early stopping. When
        ``None``, the callback never stops training.
    min_epochs:
        Minimum number of completed epochs before stopping is allowed.
    mode:
        ``"max"`` (default) stops when the metric is ``>= threshold``.
        ``"min"`` stops when it is ``<= threshold``, which is usually what a
        loss-like metric needs.

    Attributes
    ----------
    stopped_epoch:
        1-based epoch at which the most recent training run was stopped, or
        ``0`` if the threshold was not reached.
    """

    def __init__(
        self,
        monitor: str = "val_accuracy",
        threshold: float | None = None,
        min_epochs: int = 5,
        mode: str = "max",
    ) -> None:
        super().__init__()
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.monitor = monitor
        self.threshold = threshold
        self.min_epochs = min_epochs
        self.mode = mode
        self.stopped_epoch = 0

    def _reached(self, value: float) -> bool:
        """Return whether *value* satisfies the threshold in the configured mode."""
        if self.mode == "min":
            return value <= self.threshold
        return value >= self.threshold

    def on_train_begin(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        # Reset so a callback reused across ``fit`` calls does not report a
        # stale stop epoch from an earlier run.
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        logs = logs or {}
        metric = logs.get(self.monitor)
        if metric is None or self.threshold is None:
            return
        # ``epoch`` is 0-based, while ``min_epochs`` counts completed epochs.
        if epoch + 1 < self.min_epochs or not self._reached(metric):
            return
        self.stopped_epoch = epoch + 1
        self.model.stop_training = True
        logger.info(
            "MetricThreshold: stopped at epoch %d with %s=%.4f (threshold %.4f)",
            self.stopped_epoch,
            self.monitor,
            metric,
            self.threshold,
        )

    def on_train_end(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        if self.stopped_epoch:
            logger.info("Training stopped at epoch %d", self.stopped_epoch)
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: neural-feature-importance
3
+ Version: 0.5.0
4
+ Summary: Variance-based feature importance for Neural Networks using callbacks for Keras and PyTorch
5
+ Author: CR de Sá
6
+ Requires-Dist: numpy
7
+ Provides-Extra: tensorflow
8
+ Requires-Dist: tensorflow; extra == "tensorflow"
9
+ Provides-Extra: torch
10
+ Requires-Dist: torch; extra == "torch"
@@ -0,0 +1,8 @@
1
+ neural_feature_importance/__init__.py,sha256=z3Rve0a7QTAhEpCesDhSdbkOwfSsNZgiwDVep1Is_c0,566
2
+ neural_feature_importance/callbacks.py,sha256=HMHsmVaqZOzy5NSbxN-8CWvq82vzgZgZD53zqp2nAz0,4811
3
+ neural_feature_importance/utils/__init__.py,sha256=dMjBUCx8DCoJKAEAnjj_daXfEu9Q5va1k8XupmWdZiE,114
4
+ neural_feature_importance/utils/monitors.py,sha256=LTz7oE0-WgZ50DHyHDnTwfzWSSWMnjWd0xlwt7BWKuU,1763
5
+ neural_feature_importance-0.5.0.dist-info/METADATA,sha256=Uo5R11NK7P-XXpeu3VWIl7I1w5K0l_30y0gp_J5GR-k,346
6
+ neural_feature_importance-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
7
+ neural_feature_importance-0.5.0.dist-info/top_level.txt,sha256=yP0Q-BG7hDLLu1H1_x5bGEKwkCso5NxxvScnlmICb-o,26
8
+ neural_feature_importance-0.5.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ neural_feature_importance