neural_feature_importance-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neural_feature_importance/__init__.py +22 -0
- neural_feature_importance/callbacks.py +142 -0
- neural_feature_importance/utils/__init__.py +5 -0
- neural_feature_importance/utils/monitors.py +54 -0
- neural_feature_importance-0.5.0.dist-info/METADATA +10 -0
- neural_feature_importance-0.5.0.dist-info/RECORD +8 -0
- neural_feature_importance-0.5.0.dist-info/WHEEL +5 -0
- neural_feature_importance-0.5.0.dist-info/top_level.txt +1 -0
neural_feature_importance/__init__.py
@@ -0,0 +1,22 @@
"""Utilities for variance-based feature importance in neural networks."""

from importlib import metadata

from .callbacks import (
    VarianceImportanceBase,
    VarianceImportanceKeras,
    VarianceImportanceTorch,
)
from .utils import MetricThreshold

try:
    __version__ = metadata.version("neural-feature-importance")
except metadata.PackageNotFoundError:  # pragma: no cover - package not installed
    __version__ = "0.0.dev0"

__all__ = [
    "VarianceImportanceBase",
    "VarianceImportanceKeras",
    "VarianceImportanceTorch",
    "MetricThreshold",
]
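For orientation, a minimal sketch of the import surface exported above (assumes the wheel is installed; note that importing the package pulls in callbacks.py, which imports TensorFlow at module level):

import neural_feature_importance as nfi
from neural_feature_importance import (
    MetricThreshold,
    VarianceImportanceKeras,
    VarianceImportanceTorch,
)

print(nfi.__version__)  # "0.5.0" for this wheel; "0.0.dev0" if metadata is missing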
neural_feature_importance/callbacks.py
@@ -0,0 +1,142 @@
"""Variance-based feature importance utilities."""

from __future__ import annotations

import logging
from typing import Optional

import numpy as np
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Layer

logger = logging.getLogger(__name__)


class VarianceImportanceBase:
    """Compute feature importance using Welford's algorithm."""

    def __init__(self) -> None:
        self._n = 0
        self._mean: np.ndarray | None = None
        self._m2: np.ndarray | None = None
        self._last_weights: np.ndarray | None = None
        self.var_scores: np.ndarray | None = None

    def start(self, weights: np.ndarray) -> None:
        """Initialize statistics for the given weight matrix."""
        self._mean = weights.astype(np.float64)
        self._m2 = np.zeros_like(self._mean)
        self._n = 0

    def update(self, weights: np.ndarray) -> None:
        """Update running statistics with new weights."""
        if self._mean is None or self._m2 is None:
            return
        self._n += 1
        delta = weights - self._mean
        self._mean += delta / self._n
        delta2 = weights - self._mean
        self._m2 += delta * delta2
        self._last_weights = weights

    def finalize(self) -> None:
        """Finalize statistics and compute normalized scores."""
        if self._last_weights is None or self._m2 is None:
            logger.warning(
                "%s was not fully initialized; no scores computed", self.__class__.__name__
            )
            return

        if self._n < 2:
            variance = np.full_like(self._m2, np.nan)
        else:
            variance = self._m2 / (self._n - 1)

        scores = np.sum(variance * np.abs(self._last_weights), axis=1)
        min_val = float(np.min(scores))
        max_val = float(np.max(scores))
        denom = max_val - min_val if max_val != min_val else 1.0
        self.var_scores = (scores - min_val) / denom

        top = np.argsort(self.var_scores)[-10:][::-1]
        logger.info("Most important variables: %s", top)

    @property
    def feature_importances_(self) -> np.ndarray | None:
        """Normalized importance scores for each input feature."""
        return self.var_scores


class VarianceImportanceKeras(Callback, VarianceImportanceBase):
    """Keras callback implementing variance-based feature importance."""

    def __init__(self) -> None:
        Callback.__init__(self)
        VarianceImportanceBase.__init__(self)
        self._layer: Optional[Layer] = None

    def on_train_begin(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        self._layer = None
        for layer in self.model.layers:
            if layer.get_weights():
                self._layer = layer
                break
        if self._layer is None:
            raise ValueError("Model does not contain trainable weights.")
        weights = self._layer.get_weights()[0]
        logger.info(
            "Tracking variance for layer '%s' with %d features",
            self._layer.name,
            weights.shape[0],
        )
        self.start(weights)

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        if self._layer is None:
            return
        weights = self._layer.get_weights()[0]
        self.update(weights)

    def on_train_end(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        self.finalize()

    def get_config(self) -> dict[str, int]:
        """Return configuration for serialization."""
        return {}


class VarianceImportanceTorch(VarianceImportanceBase):
    """Track variance-based feature importance for PyTorch models."""

    def __init__(self, model: "nn.Module") -> None:
        from torch import nn  # Local import to avoid hard dependency

        super().__init__()
        self.model = model
        self._param: nn.Parameter | None = None

    def on_train_begin(self) -> None:
        from torch import nn

        for name, param in self.model.named_parameters():
            if param.requires_grad and param.dim() >= 2:
                self._param = param
                weights = param.detach().cpu().numpy()
                logger.info(
                    "Tracking variance for parameter '%s' with %d features",
                    name,
                    weights.shape[1],
                )
                self.start(weights)
                break
        if self._param is None:
            raise ValueError("Model does not contain trainable parameters")

    def on_epoch_end(self) -> None:
        if self._param is None:
            return
        weights = self._param.detach().cpu().numpy()
        self.update(weights)

    def on_train_end(self) -> None:
        self.finalize()
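A hedged usage sketch for the two trackers; the data, model, and hyperparameters below are illustrative placeholders, while the callback and hook names come from the file above. One caveat worth flagging: finalize() reduces over axis=1, which matches the Keras kernel layout (features, units), whereas a torch nn.Linear stores its weight as (out_features, in_features), so the PyTorch score vector follows the first axis of the tracked parameter.

# Keras: attach the callback to model.fit(); data and model are made up.
import numpy as np
from tensorflow import keras

from neural_feature_importance import VarianceImportanceKeras

X = np.random.rand(256, 20)
y = (X[:, 0] > 0.5).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(20,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

tracker = VarianceImportanceKeras()
model.fit(X, y, epochs=10, callbacks=[tracker], verbose=0)
print(tracker.feature_importances_)  # one score per input feature, scaled to [0, 1]

The PyTorch variant is not wired into any trainer; its hooks are called by hand around a standard loop:

# Illustrative training loop; net, data, and epochs are placeholders.
import torch
from torch import nn

from neural_feature_importance import VarianceImportanceTorch

net = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 1))
tracker = VarianceImportanceTorch(net)
optimizer = torch.optim.Adam(net.parameters())
loss_fn = nn.BCEWithLogitsLoss()

Xt = torch.rand(256, 20)
yt = (Xt[:, 0] > 0.5).float().unsqueeze(1)

tracker.on_train_begin()          # picks the first parameter with dim() >= 2
for _ in range(10):
    optimizer.zero_grad()
    loss_fn(net(Xt), yt).backward()
    optimizer.step()
    tracker.on_epoch_end()        # snapshot the tracked weights once per epoch
tracker.on_train_end()
print(tracker.feature_importances_)  # length follows the tracked weight's first axis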
neural_feature_importance/utils/monitors.py
@@ -0,0 +1,54 @@
"""Training utilities for keras models."""

from __future__ import annotations

import logging
from typing import Optional

from tensorflow.keras.callbacks import Callback

logger = logging.getLogger(__name__)


class MetricThreshold(Callback):
    """Stop training when a metric exceeds a given threshold.

    Parameters
    ----------
    monitor:
        Name of the metric to monitor (e.g. ``"val_accuracy"`` or ``"loss"``).
    threshold:
        Value that the metric must reach to trigger early stopping.
    min_epochs:
        Minimum number of epochs before stopping is allowed.
    """

    def __init__(self, monitor: str = "val_accuracy", threshold: float | None = None, min_epochs: int = 5) -> None:
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.min_epochs = min_epochs
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        logs = logs or {}
        metric = logs.get(self.monitor)
        if (
            metric is not None
            and self.threshold is not None
            and epoch + 1 >= self.min_epochs
            and metric >= self.threshold
        ):
            self.stopped_epoch = epoch + 1
            self.model.stop_training = True
            logger.info(
                "MetricThreshold: stopped at epoch %d with %s=%.4f (threshold %.4f)",
                self.stopped_epoch,
                self.monitor,
                metric,
                self.threshold,
            )

    def on_train_end(self, logs: Optional[dict] = None) -> None:  # type: ignore[override]
        if self.stopped_epoch:
            logger.info("Training stopped at epoch %d", self.stopped_epoch)
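A hedged sketch of MetricThreshold in use; the model and data are placeholders, and the monitored name must match a key that model.fit() actually logs (with the default threshold=None the callback never fires):

# Illustrative only: stop once training accuracy reaches 0.95, but not before epoch 3.
import numpy as np
from tensorflow import keras

from neural_feature_importance import MetricThreshold

X = np.random.rand(512, 8)
y = (X.sum(axis=1) > 4.0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

stopper = MetricThreshold(monitor="accuracy", threshold=0.95, min_epochs=3)
model.fit(X, y, epochs=50, callbacks=[stopper], verbose=0)
print(stopper.stopped_epoch)  # 0 if the threshold was never reached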
neural_feature_importance-0.5.0.dist-info/METADATA
@@ -0,0 +1,10 @@
Metadata-Version: 2.4
Name: neural-feature-importance
Version: 0.5.0
Summary: Variance-based feature importance for Neural Networks using callbacks for Keras and PyTorch
Author: CR de Sá
Requires-Dist: numpy
Provides-Extra: tensorflow
Requires-Dist: tensorflow; extra == "tensorflow"
Provides-Extra: torch
Requires-Dist: torch; extra == "torch"
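The two extras map onto the supported backends, so a Keras-oriented install would be pip install "neural-feature-importance[tensorflow]" and a PyTorch one pip install "neural-feature-importance[torch]". Worth noting: callbacks.py imports tensorflow.keras unconditionally at module level, so importing the package appears to require TensorFlow even under a torch-only install.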
neural_feature_importance-0.5.0.dist-info/RECORD
@@ -0,0 +1,8 @@
neural_feature_importance/__init__.py,sha256=z3Rve0a7QTAhEpCesDhSdbkOwfSsNZgiwDVep1Is_c0,566
neural_feature_importance/callbacks.py,sha256=HMHsmVaqZOzy5NSbxN-8CWvq82vzgZgZD53zqp2nAz0,4811
neural_feature_importance/utils/__init__.py,sha256=dMjBUCx8DCoJKAEAnjj_daXfEu9Q5va1k8XupmWdZiE,114
neural_feature_importance/utils/monitors.py,sha256=LTz7oE0-WgZ50DHyHDnTwfzWSSWMnjWd0xlwt7BWKuU,1763
neural_feature_importance-0.5.0.dist-info/METADATA,sha256=Uo5R11NK7P-XXpeu3VWIl7I1w5K0l_30y0gp_J5GR-k,346
neural_feature_importance-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
neural_feature_importance-0.5.0.dist-info/top_level.txt,sha256=yP0Q-BG7hDLLu1H1_x5bGEKwkCso5NxxvScnlmICb-o,26
neural_feature_importance-0.5.0.dist-info/RECORD,,
neural_feature_importance-0.5.0.dist-info/top_level.txt
@@ -0,0 +1 @@
neural_feature_importance