oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/extractor/feature_extractor.py CHANGED

@@ -25,6 +25,7 @@ from abc import abstractmethod
 
 from ..types import Callable
 from ..types import DatasetType
+from ..types import ItemType
 from ..types import List
 from ..types import Optional
 from ..types import TensorType

@@ -51,6 +52,11 @@ class FeatureExtractor(ABC):
             Defaults to None.
         react_threshold: if not None, penultimate layer activations are clipped under
             this threshold value (useful for ReAct). Defaults to None.
+        scale_percentile: if not None, the features are scaled
+            following the method of Xu et al., ICLR 2024.
+            Defaults to None.
+        ash_percentile: if not None, the features are scaled following
+            the method of Djurisic et al., ICLR 2023.
     """
 
     def __init__(

@@ -59,6 +65,8 @@ class FeatureExtractor(ABC):
         feature_layers_id: List[Union[int, str]] = [-1],
         input_layer_id: Union[int, str] = [0],
         react_threshold: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
     ):
         if not isinstance(feature_layers_id, list):
             feature_layers_id = [feature_layers_id]

@@ -66,6 +74,8 @@ class FeatureExtractor(ABC):
         self.feature_layers_id = feature_layers_id
         self.input_layer_id = input_layer_id
         self.react_threshold = react_threshold
+        self.scale_percentile = scale_percentile
+        self.ash_percentile = ash_percentile
         self.model = model
         self.extractor = self.prepare_extractor()
 
@@ -91,28 +101,39 @@ class FeatureExtractor(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def predict_tensor(
-        """
-        Projects input samples "inputs" into the feature space
+    def predict_tensor(
+        self,
+        tensor: TensorType,
+        postproc_fns: Optional[List[Callable]] = None,
+    ) -> Tuple[List[TensorType], TensorType]:
+        """Get the projection of tensor in the feature space of self.model
 
         Args:
-            tensor (TensorType): input tensor
+            tensor (TensorType): input tensor (or dataset elem)
+            postproc_fns (Optional[Callable]): postprocessing function to apply to each
+                feature immediately after forward. Default to None.
 
         Returns:
-            List[TensorType], TensorType: features, logits
+            Tuple[List[TensorType], TensorType]: features, logits
         """
         raise NotImplementedError()
 
     @abstractmethod
     def predict(
         self,
-        dataset: Union[
+        dataset: Union[ItemType, DatasetType],
+        postproc_fns: Optional[List[Callable]] = None,
+        verbose: bool = False,
+        **kwargs,
     ) -> Tuple[List[TensorType], dict]:
-        """
-        Projects input samples "inputs" into the feature space for a batched dataset
+        """Get the projection of the dataset in the feature space of self.model
 
         Args:
-            dataset (Union[
+            dataset (Union[ItemType, DatasetType]): input dataset
+            postproc_fns (Optional[Callable]): postprocessing function to apply to each
+                feature immediately after forward. Default to None.
+            verbose (bool): if True, display a progress bar. Defaults to False.
+            kwargs (dict): additional arguments not considered for prediction
 
         Returns:
             List[TensorType], dict: features and extra information (logits, labels) as a
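
Both concrete extractors now accept the same post-hoc options (react_threshold, scale_percentile, ash_percentile) plus postproc_fns and verbose on predict. A minimal sketch of what the new constructor arguments look like in use, assuming a toy Keras classifier (the model, layer choice and percentile value are illustrative, not taken from this diff):

    import tensorflow as tf
    from oodeel.extractor.keras_feature_extractor import KerasFeatureExtractor

    model = tf.keras.Sequential([
        tf.keras.layers.Input((32,)),
        tf.keras.layers.Dense(128, activation="relu"),   # penultimate layer
        tf.keras.layers.Dense(10),                       # logits
    ])

    # one rescaling strategy at a time: react_threshold, scale_percentile or ash_percentile
    extractor = KerasFeatureExtractor(
        model,
        feature_layers_id=[-2],
        scale_percentile=0.85,     # SCALE, Xu et al., ICLR 2024
    )

    x = tf.random.uniform((8, 32))
    features, logits = extractor.predict_tensor(x)   # features is always a list now
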
oodeel/extractor/keras_feature_extractor.py CHANGED

@@ -24,6 +24,8 @@ from typing import get_args
 from typing import Optional
 
 import tensorflow as tf
+import tensorflow_probability as tfp
+from tqdm import tqdm
 
 from ..datasets.tf_data_handler import TFDataHandler
 from ..types import Callable

@@ -54,6 +56,11 @@ class KerasFeatureExtractor(FeatureExtractor):
             Defaults to None.
         react_threshold: if not None, penultimate layer activations are clipped under
             this threshold value (useful for ReAct). Defaults to None.
+        scale_percentile: if not None, the features are scaled
+            following the method of Xu et al., ICLR 2024.
+            Defaults to None.
+        ash_percentile: if not None, the features are scaled following
+            the method of Djurisic et al., ICLR 2023.
     """
 
     def __init__(

@@ -62,6 +69,8 @@ class KerasFeatureExtractor(FeatureExtractor):
         feature_layers_id: List[Union[int, str]] = [-1],
         input_layer_id: Optional[Union[int, str]] = None,
         react_threshold: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
     ):
         if input_layer_id is None:
             input_layer_id = 0

@@ -70,6 +79,8 @@ class KerasFeatureExtractor(FeatureExtractor):
             feature_layers_id=feature_layers_id,
             input_layer_id=input_layer_id,
             react_threshold=react_threshold,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
         )
 
         self.backend = "tensorflow"
@@ -143,6 +154,43 @@ class KerasFeatureExtractor(FeatureExtractor):
             )
             # apply ultimate layer on clipped activations
             output_tensors.append(last_layer(x))
+
+        # === If SCALE method, scale activations from penultimate layer ===
+        # === If ASH method, scale and prune activations from penultimate layer ===
+        elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
+            penultimate_layer = self.find_layer(self.model, -2)
+            penult_extractor = tf.keras.models.Model(
+                new_input, penultimate_layer.output
+            )
+            last_layer = self.find_layer(self.model, -1)
+
+            # apply scaling on penultimate activations
+            penultimate = penult_extractor(new_input)
+            if self.scale_percentile is not None:
+                output_percentile = tfp.stats.percentile(
+                    penultimate, 100 * self.scale_percentile, axis=1
+                )
+            else:
+                output_percentile = tfp.stats.percentile(
+                    penultimate, 100 * self.ash_percentile, axis=1
+                )
+
+            mask = penultimate > tf.reshape(output_percentile, (-1, 1))
+            filtered_penultimate = tf.where(
+                mask, penultimate, tf.zeros_like(penultimate)
+            )
+            s = tf.math.exp(
+                tf.reduce_sum(penultimate, axis=1)
+                / tf.reduce_sum(filtered_penultimate, axis=1)
+            )
+
+            if self.scale_percentile is not None:
+                x = penultimate * tf.expand_dims(s, 1)
+            else:
+                x = filtered_penultimate * tf.expand_dims(s, 1)
+            # apply ultimate layer on scaled activations
+            output_tensors.append(last_layer(x))
+
         else:
             output_tensors.append(self.find_layer(self.model, -1).output)
 
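
The SCALE/ASH branch computes a per-sample percentile threshold, masks the penultimate activations, and rescales by exp(total activation mass / kept mass). The same computation can be reproduced standalone on a dummy activation matrix (the helper name and shapes below are illustrative; tensorflow_probability is the new dependency pulled in by the import hunk above):

    import tensorflow as tf
    import tensorflow_probability as tfp

    def scale_or_ash(penultimate: tf.Tensor, percentile: float, prune: bool) -> tf.Tensor:
        """Rescale per-sample activations; prune=False mimics SCALE, prune=True mimics ASH."""
        # per-sample threshold at the requested percentile of the activations
        thr = tfp.stats.percentile(penultimate, 100 * percentile, axis=1)
        mask = penultimate > tf.reshape(thr, (-1, 1))
        kept = tf.where(mask, penultimate, tf.zeros_like(penultimate))
        # scaling factor: total activation mass over the mass that survives the mask
        s = tf.math.exp(
            tf.reduce_sum(penultimate, axis=1) / tf.reduce_sum(kept, axis=1)
        )
        base = penultimate if not prune else kept
        return base * tf.expand_dims(s, 1)

    acts = tf.random.uniform((4, 128))                 # dummy penultimate activations
    scaled = scale_or_ash(acts, 0.85, prune=False)     # SCALE
    pruned = scale_or_ash(acts, 0.90, prune=True)      # ASH
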
@@ -150,11 +198,17 @@ class KerasFeatureExtractor(FeatureExtractor):
         return extractor
 
     @sanitize_input
-    def predict_tensor(
+    def predict_tensor(
+        self,
+        tensor: TensorType,
+        postproc_fns: Optional[List[Callable]] = None,
+    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
         """Get the projection of tensor in the feature space of self.model
 
         Args:
             tensor (TensorType): input tensor (or dataset elem)
+            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
+                each feature immediately after forward. Default to None.
 
         Returns:
             Tuple[List[tf.Tensor], tf.Tensor]: features, logits
@@ -166,8 +220,12 @@ class KerasFeatureExtractor(FeatureExtractor):
 
         # split features and logits
         logits = features.pop()
-        if len(features) == 1:
-            features = features[0]
+
+        if postproc_fns is not None:
+            features = [
+                postproc_fn(feature)
+                for feature, postproc_fn in zip(features, postproc_fns)
+            ]
 
         self._last_logits = logits
         return features, logits
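
postproc_fns is zipped against the extracted features, so it is expected to hold one callable per entry of feature_layers_id, applied right after the forward pass. A hedged usage sketch, where the toy convnet, the integer layer ids and the pooling choice are illustrative rather than taken from this diff:

    import tensorflow as tf
    from oodeel.extractor.keras_feature_extractor import KerasFeatureExtractor

    cnn = tf.keras.Sequential([
        tf.keras.layers.Input((32, 32, 3)),
        tf.keras.layers.Conv2D(16, 3, activation="relu"),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10),
    ])
    extractor = KerasFeatureExtractor(cnn, feature_layers_id=[0, 1])

    # one callable per extracted feature: here, pool each 4D feature map into a vector
    gap = tf.keras.layers.GlobalAveragePooling2D()
    x = tf.random.uniform((8, 32, 32, 3))
    features, logits = extractor.predict_tensor(x, postproc_fns=[gap, gap])
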
@@ -179,12 +237,17 @@ class KerasFeatureExtractor(FeatureExtractor):
     def predict(
         self,
         dataset: Union[ItemType, tf.data.Dataset],
+        postproc_fns: Optional[List[Callable]] = None,
+        verbose: bool = False,
         **kwargs,
     ) -> Tuple[List[tf.Tensor], dict]:
         """Get the projection of the dataset in the feature space of self.model
 
         Args:
             dataset (Union[ItemType, tf.data.Dataset]): input dataset
+            postproc_fns (Optional[Callable]): postprocessing function to apply to each
+                feature immediately after forward. Default to None.
+            verbose (bool): if True, display a progress bar. Defaults to False.
             kwargs (dict): additional arguments not considered for prediction
 
         Returns:

@@ -195,7 +258,7 @@ class KerasFeatureExtractor(FeatureExtractor):
 
         if isinstance(dataset, get_args(ItemType)):
             tensor = TFDataHandler.get_input_from_dataset_item(dataset)
-            features, logits = self.predict_tensor(tensor)
+            features, logits = self.predict_tensor(tensor, postproc_fns)
 
             # Get labels if dataset is a tuple/list
             if isinstance(dataset, (list, tuple)):

@@ -205,12 +268,10 @@ class KerasFeatureExtractor(FeatureExtractor):
             features = [None for i in range(len(self.feature_layers_id))]
             logits = None
             contains_labels = TFDataHandler.get_item_length(dataset) > 1
-            for elem in dataset:
+            for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
                 tensor = TFDataHandler.get_input_from_dataset_item(elem)
-                features_batch, logits_batch = self.predict_tensor(tensor)
-
-                if len(features) == 1:
-                    features_batch = [features_batch]
+                features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)
+
                 for i, f in enumerate(features_batch):
                     features[i] = (
                         f

@@ -234,10 +295,6 @@ class KerasFeatureExtractor(FeatureExtractor):
 
         # store extra information in a dict
         info = dict(labels=labels, logits=logits)
-
-        if len(features) == 1:
-            features = features[0]
-
         return features, info
 
     def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
oodeel/extractor/torch_feature_extractor.py CHANGED

@@ -20,13 +20,13 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from collections import OrderedDict
 from typing import get_args
 from typing import Optional
 
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from ..datasets.torch_data_handler import TorchDataHandler
 from ..types import Callable

@@ -57,6 +57,12 @@ class TorchFeatureExtractor(FeatureExtractor):
             Defaults to None.
         react_threshold: if not None, penultimate layer activations are clipped under
             this threshold value (useful for ReAct). Defaults to None.
+        scale_percentile: if not None, the features are scaled
+            following the method of Xu et al., ICLR 2024.
+            Defaults to None.
+        ash_percentile: if not None, the features are scaled following
+            the method of Djurisic et al., ICLR 2023.
+
     """
 
     def __init__(

@@ -65,6 +71,8 @@ class TorchFeatureExtractor(FeatureExtractor):
         feature_layers_id: List[Union[int, str]] = [],
         input_layer_id: Optional[Union[int, str]] = None,
         react_threshold: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
     ):
         model = model.eval()
         super().__init__(

@@ -72,6 +80,8 @@ class TorchFeatureExtractor(FeatureExtractor):
             feature_layers_id=feature_layers_id,
             input_layer_id=input_layer_id,
             react_threshold=react_threshold,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
         )
         self._device = next(model.parameters()).device
         self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}
@@ -140,18 +150,43 @@ class TorchFeatureExtractor(FeatureExtractor):
 
     def prepare_extractor(self) -> None:
        """Prepare the feature extractor by adding hooks to self.model"""
-        #
-
+        # prepare self.model for ood hooks (add _ood_handles attribute or
+        # remove ood forward hooks attached to the model)
+        self._prepare_ood_handles()
 
         # === If react method, clip activations from penultimate layer ===
         if self.react_threshold is not None:
-            pen_layer = self.find_layer(self.model, -
-            pen_layer.register_forward_pre_hook(self._get_clip_hook(self.react_threshold))
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_clip_hook(self.react_threshold)
+                )
+            )
+
+        # === If SCALE method, scale activations from penultimate layer ===
+        if self.scale_percentile is not None:
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_scale_hook(self.scale_percentile)
+                )
+            )
+
+        # === If ASH method, scale and prune activations from penultimate layer ===
+        if self.ash_percentile is not None:
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_ash_hook(self.ash_percentile)
+                )
+            )
 
         # Register a hook to store feature values for each considered layer + last layer
         for layer_id in self._hook_layers_id:
             layer = self.find_layer(self.model, layer_id)
-            layer.register_forward_hook(self._get_features_hook(layer_id))
+            self.model._ood_handles.append(
+                layer.register_forward_hook(self._get_features_hook(layer_id))
+            )
 
         # Crop model if input layer is provided
         if not (self.input_layer_id) is None:
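
All of the ReAct/SCALE/ASH hooks are forward pre-hooks on the last layer, so they see (and may rewrite) the penultimate activations just before the classification head consumes them, and every handle is tracked in model._ood_handles so that _prepare_ood_handles can remove stale hooks later. A standalone sketch of the mechanism in plain PyTorch, independent of oodeel:

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    def clip_hook(_, inputs):
        # inputs is a tuple of the layer's positional inputs; clipping them here
        # mimics ReAct applied to the penultimate activations
        return (torch.clip(inputs[0], max=1.0),)

    handle = model[-1].register_forward_pre_hook(clip_hook)
    logits = model(torch.randn(8, 16))
    handle.remove()   # what _prepare_ood_handles does for previously attached hooks
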
@@ -181,12 +216,17 @@ class TorchFeatureExtractor(FeatureExtractor):
 
     @sanitize_input
     def predict_tensor(
-        self,
+        self,
+        x: TensorType,
+        postproc_fns: Optional[List[Callable]] = None,
+        detach: bool = True,
     ) -> Tuple[List[torch.Tensor], torch.Tensor]:
         """Get the projection of tensor in the feature space of self.model
 
         Args:
             x (TensorType): input tensor (or dataset elem)
+            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
+                each feature immediately after forward. Default to None.
             detach (bool): if True, return features detached from the computational
                 graph. Defaults to True.
 
@@ -206,8 +246,12 @@ class TorchFeatureExtractor(FeatureExtractor):
 
         # split features and logits
         logits = features.pop()
-        if len(features) == 1:
-            features = features[0]
+
+        if postproc_fns is not None:
+            features = [
+                postproc_fn(feature)
+                for feature, postproc_fn in zip(features, postproc_fns)
+            ]
 
         self._last_logits = logits
         return features, logits
@@ -215,15 +259,20 @@ class TorchFeatureExtractor(FeatureExtractor):
     def predict(
         self,
         dataset: Union[DataLoader, ItemType],
+        postproc_fns: Optional[List[Callable]] = None,
         detach: bool = True,
+        verbose: bool = False,
         **kwargs,
     ) -> Tuple[List[torch.Tensor], dict]:
         """Get the projection of the dataset in the feature space of self.model
 
         Args:
             dataset (Union[DataLoader, ItemType]): input dataset
+            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
+                each feature immediately after forward. Default to None.
             detach (bool): if True, return features detached from the computational
                 graph. Defaults to True.
+            verbose (bool): if True, display a progress bar. Defaults to False.
             kwargs (dict): additional arguments not considered for prediction
 
         Returns:

@@ -234,7 +283,7 @@ class TorchFeatureExtractor(FeatureExtractor):
 
         if isinstance(dataset, get_args(ItemType)):
             tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
-            features, logits = self.predict_tensor(tensor, detach=detach)
+            features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)
 
             # Get labels if dataset is a tuple/list
             if isinstance(dataset, (list, tuple)) and len(dataset) > 1:

@@ -245,14 +294,11 @@ class TorchFeatureExtractor(FeatureExtractor):
             logits = None
             batch = next(iter(dataset))
             contains_labels = isinstance(batch, (list, tuple)) and len(batch) > 1
-            for elem in dataset:
+            for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
                 tensor = TorchDataHandler.get_input_from_dataset_item(elem)
                 features_batch, logits_batch = self.predict_tensor(
-                    tensor, detach=detach
+                    tensor, postproc_fns, detach=detach
                 )
-                # concatenate features
-                if len(features) == 1:
-                    features_batch = [features_batch]
                 for i, f in enumerate(features_batch):
                     features[i] = (
                         f if features[i] is None else torch.cat([features[i], f], dim=0)
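
With the loop wrapped in tqdm and the single-feature squeeze removed, predict always returns a list of feature tensors plus an info dict. A hedged sketch of a call on a DataLoader (the toy model, the layer id and the shapes are illustrative, and assume negative integer ids resolve to modules the same way as on the Keras side):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from oodeel.extractor.torch_feature_extractor import TorchFeatureExtractor

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    extractor = TorchFeatureExtractor(model, feature_layers_id=[-2])

    loader = DataLoader(
        TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,))),
        batch_size=64,
    )
    features, info = extractor.predict(loader, verbose=True)   # tqdm progress bar
    # features is a list (one tensor per feature layer); labels/logits live in info
    print(features[0].shape, info["logits"].shape, info["labels"].shape)
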
@@ -274,10 +320,6 @@ class TorchFeatureExtractor(FeatureExtractor):
 
         # store extra information in a dict
         info = dict(labels=labels, logits=logits)
-
-        if len(features) == 1:
-            features = features[0]
-
         return features, info
 
     def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
@@ -303,24 +345,69 @@ class TorchFeatureExtractor(FeatureExtractor):
             Callable: hook function
         """
 
-        def hook(_,
-            input = torch.clip(input[0], max=threshold)
-            return input
+        def hook(_, input):
+            input = input[0]
+            input = torch.clip(input, max=threshold)
+            return input
 
         return hook
 
-    def
+    def _get_scale_hook(self, percentile: float) -> Callable:
         """
-
-
-
+        Hook that scales activation features.
+
+        Args:
+            threshold (float): threshold value
+
+        Returns:
+            Callable: hook function
         """
 
-        def
-
-
-
-
-
+        def hook(_, input):
+            input = input[0]
+            output_percentile = torch.quantile(input, percentile, dim=1)
+            mask = input > output_percentile[:, None]
+            output_masked = input * mask
+            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
+            s = torch.unsqueeze(s, 1)
+            input = input * s
+            return input
 
-        return
+        return hook
+
+    def _get_ash_hook(self, percentile: float) -> Callable:
+        """
+        Hook that scales and prunes activation features under a threshold value
+
+        Args:
+            threshold (float): threshold value
+
+        Returns:
+            Callable: hook function
+        """
+
+        def hook(_, input):
+            input = input[0]
+            output_percentile = torch.quantile(input, percentile, dim=1)
+            mask = input > output_percentile[:, None]
+            output_masked = input * mask
+            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
+            s = torch.unsqueeze(s, 1)
+            input = output_masked * s
+            return input
+
+        return hook
+
+    def _prepare_ood_handles(self) -> None:
+        """
+        Prepare the model by either setting a new attribute to self.model
+        as a list which will contain all the ood specific hooks, or by cleaning
+        the existing ood specific hooks if the attribute already exists.
+        """
+
+        if not hasattr(self.model, "_ood_handles"):
+            setattr(self.model, "_ood_handles", [])
+        else:
+            for handle in self.model._ood_handles:
+                handle.remove()
+            self.model._ood_handles = []
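
The two new hooks share the per-sample percentile mask and the scaling factor s = exp(sum(acts) / sum(kept)); they differ only in whether s multiplies all activations (SCALE) or only the kept ones (ASH). A standalone sketch of that difference on dummy data, in plain PyTorch:

    import torch

    acts = torch.rand(4, 128)                       # dummy penultimate activations
    thr = torch.quantile(acts, 0.85, dim=1)         # per-sample percentile threshold
    mask = acts > thr[:, None]
    kept = acts * mask
    s = torch.exp(acts.sum(dim=1) / kept.sum(dim=1)).unsqueeze(1)

    scaled_scale = acts * s    # SCALE: rescale every activation
    scaled_ash = kept * s      # ASH: rescale only the activations that survive the mask
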
oodeel/methods/__init__.py CHANGED

@@ -23,9 +23,25 @@
 from .dknn import DKNN
 from .energy import Energy
 from .entropy import Entropy
+from .gen import GEN
+from .gram import Gram
 from .mahalanobis import Mahalanobis
 from .mls import MLS
 from .odin import ODIN
+from .rmds import RMDS
+from .she import SHE
 from .vim import VIM
 
-__all__ = [
+__all__ = [
+    "DKNN",
+    "Energy",
+    "Entropy",
+    "GEN",
+    "SHE",
+    "Gram",
+    "Mahalanobis",
+    "MLS",
+    "ODIN",
+    "RMDS",
+    "VIM",
+]
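
GEN, Gram, RMDS and SHE are now exported next to the existing detectors. A hedged sketch of how one of them would be instantiated; the fit/score call pattern follows oodeel's usual detector workflow and is not part of this diff, and model / ds_test are placeholders:

    from oodeel.methods import GEN, Gram, RMDS, SHE   # new exports in 0.3.0

    detector = GEN()                  # logit-based, in the spirit of MLS/Energy
    detector.fit(model)               # model: your trained classifier (placeholder)
    scores = detector.score(ds_test)  # ds_test: dataset or batch to score (placeholder)
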
|