oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
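
Each entry above is a file added by this release, with its added/removed line counts. To reproduce the listing locally from the public registry, a minimal sketch (the download directory is arbitrary; requires network access and the pip CLI):

import pathlib
import subprocess
import tempfile
import zipfile

# Fetch the published wheel without its dependencies, then list its contents.
tmp = pathlib.Path(tempfile.mkdtemp())
subprocess.run(
    ["pip", "download", "oodeel==0.4.0", "--no-deps", "-d", str(tmp)],
    check=True,
)
whl = next(tmp.glob("oodeel-0.4.0-*.whl"))
with zipfile.ZipFile(whl) as zf:
    for name in zf.namelist():
        print(name)

Two of these files are shown in full below: oodeel/extractor/torch_feature_extractor.py (+506) and oodeel/methods/__init__.py (+47).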
oodeel/extractor/torch_feature_extractor.py
@@ -0,0 +1,506 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import get_args
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from ..datasets.torch_data_handler import TorchDataHandler
from ..types import Callable
from ..types import ItemType
from ..types import List
from ..types import TensorType
from ..types import Tuple
from ..types import Union
from ..utils.torch_operator import sanitize_input
from .feature_extractor import FeatureExtractor


class TorchFeatureExtractor(FeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers, or the output of the
    model (logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list.
            If str, the name of the layer. Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list.
            If str, the name of the layer. Defaults to -1.
        input_layer_id: input layer of the feature extractor (to avoid useless
            forwards when working on the feature space without finetuning the
            bottom of the model). Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped
            at this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled following the
            method of Xu et al., ICLR 2024. Defaults to None.
        ash_percentile: if not None, the features are scaled and pruned
            following the method of Djurisic et al., ICLR 2023. Defaults to
            None.
        return_penultimate (bool): if True, the penultimate values are
            returned, i.e. the input to the head layer. Defaults to False.
    """

    def __init__(
        self,
        model: nn.Module,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        model = model.eval()

        if return_penultimate:
            feature_layers_id.append("penultimate")

        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            head_layer_id=head_layer_id,
            input_layer_id=input_layer_id,
            react_threshold=react_threshold,
            scale_percentile=scale_percentile,
            ash_percentile=ash_percentile,
            return_penultimate=return_penultimate,
        )
        self._device = next(model.parameters()).device
        self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}
        self._last_logits = None
        self.backend = "torch"

    @property
    def _hook_layers_id(self):
        return self.feature_layers_id + [self.head_layer_id]

    def _get_features_hook(self, layer_id: Union[str, int]) -> Callable:
        """
        Hook that stores features corresponding to a specific layer
        in a class dictionary.

        Args:
            layer_id (Union[str, int]): layer identifier

        Returns:
            Callable: hook function
        """

        def hook(_, __, output):
            if isinstance(output, torch.Tensor):
                self._features[layer_id] = output
            else:
                raise NotImplementedError

        return hook

    def _get_penultimate_hook(self) -> Callable:
        """
        Hook that stores the input to the head layer (i.e. the penultimate
        features) in the class dictionary, under the "penultimate" key.

        Returns:
            Callable: hook function
        """

        def hook(_, input):
            if isinstance(input[0], torch.Tensor):
                self._features["penultimate"] = input[0]
            else:
                raise NotImplementedError

        return hook

    @staticmethod
    def find_layer(
        model: nn.Module,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[nn.Module, Tuple[nn.Module, str]]:
        """Find a layer in a model either by its name or by its index.

        Args:
            model (nn.Module): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before
                (negative offset) or after (positive offset) the identified
                layer
            return_id (bool): if True, the layer will be returned with its id

        Returns:
            Union[nn.Module, Tuple[nn.Module, str]]: the corresponding layer,
                and its id if return_id is True.
        """
        if isinstance(layer_id, int):
            layer_id += index_offset
            if isinstance(model, nn.Sequential):
                layer = model[layer_id]
            else:
                layer = list(model.named_modules())[layer_id][1]
        else:
            layer_id = list(dict(model.named_modules()).keys()).index(layer_id)
            layer_id += index_offset
            layer = list(model.named_modules())[layer_id][1]

        if return_id:
            return layer, layer_id
        else:
            return layer

    def prepare_extractor(self) -> None:
        """Prepare the feature extractor by adding hooks to self.model"""
        # prepare self.model for ood hooks (add _ood_handles attribute or
        # remove ood forward hooks attached to the model)
        self._prepare_ood_handles()

        # === If react method, clip activations from penultimate layer ===
        if self.react_threshold is not None:
            pen_layer = self.find_layer(self.model, self.head_layer_id)
            self.model._ood_handles.append(
                pen_layer.register_forward_pre_hook(
                    self._get_clip_hook(self.react_threshold)
                )
            )

        # === If SCALE method, scale activations from penultimate layer ===
        if self.scale_percentile is not None:
            pen_layer = self.find_layer(self.model, self.head_layer_id)
            self.model._ood_handles.append(
                pen_layer.register_forward_pre_hook(
                    self._get_scale_hook(self.scale_percentile)
                )
            )

        # === If ASH method, scale and prune activations from penultimate layer ===
        if self.ash_percentile is not None:
            pen_layer = self.find_layer(self.model, self.head_layer_id)
            self.model._ood_handles.append(
                pen_layer.register_forward_pre_hook(
                    self._get_ash_hook(self.ash_percentile)
                )
            )

        # Register a hook to store feature values for each considered layer
        # + last layer
        for layer_id in self._hook_layers_id:
            if layer_id == "penultimate":
                # Register penultimate hook
                layer = self.find_layer(self.model, self.head_layer_id)
                self.model._ood_handles.append(
                    layer.register_forward_pre_hook(self._get_penultimate_hook())
                )
                continue

            layer = self.find_layer(self.model, layer_id)
            self.model._ood_handles.append(
                layer.register_forward_hook(self._get_features_hook(layer_id))
            )

        # Crop model if input layer is provided
        if self.input_layer_id is not None:
            if isinstance(self.input_layer_id, int):
                if isinstance(self.model, nn.Sequential):
                    self.model = nn.Sequential(
                        *list(self.model.modules())[self.input_layer_id :]
                    )
                else:
                    raise NotImplementedError
            elif isinstance(self.input_layer_id, str):
                if isinstance(self.model, nn.Sequential):
                    module_names = list(
                        filter(
                            lambda x: x != "",
                            map(lambda x: x[0], self.model.named_modules()),
                        )
                    )
                    input_module_idx = module_names.index(self.input_layer_id)
                    self.model = nn.Sequential(
                        *list(self.model.modules())[(input_module_idx + 1) :]
                    )
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError

    @sanitize_input
    def predict_tensor(
        self,
        x: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get the projection of a tensor in the feature space of self.model

        Args:
            x (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing functions to
                apply to each feature immediately after the forward pass.
                Defaults to None.
            detach (bool): if True, return features detached from the
                computational graph. Defaults to True.

        Returns:
            List[torch.Tensor], torch.Tensor: features, logits
        """
        if x.device != self._device:
            x = x.to(self._device)

        with torch.set_grad_enabled(not detach):
            _ = self.model(x)

        if detach:
            features = [
                self._features[layer_id].detach() for layer_id in self._hook_layers_id
            ]
        else:
            features = [self._features[layer_id] for layer_id in self._hook_layers_id]

        # split features and logits
        logits = features.pop()

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits

    def predict(
        self,
        dataset: Union[DataLoader, ItemType],
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
        verbose: bool = False,
        numpy_concat: bool = False,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[DataLoader, ItemType]): input dataset
            postproc_fns (Optional[List[Callable]]): postprocessing functions to
                apply to each feature immediately after the forward pass.
                Defaults to None.
            detach (bool): if True, return features detached from the
                computational graph. No gradient will be computed. Defaults to
                True.
            verbose (bool): if True, display a progress bar. Defaults to False.
            numpy_concat (bool): if True, each mini-batch is immediately moved
                to CPU and converted to a NumPy array before concatenation.
                That keeps GPU memory constant at one batch, at the cost of a
                small host-device transfer overhead. Defaults to False.

        Returns:
            List[torch.Tensor], dict: features and extra information (logits,
                labels) as a dictionary.
        """
        labels = None

        if isinstance(dataset, get_args(ItemType)):
            tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
            features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)) and len(dataset) > 1:
                labels = TorchDataHandler.get_label_from_dataset_item(dataset)

        else:
            # Check if batches include labels
            first_batch = next(iter(dataset))
            contains_labels = (
                isinstance(first_batch, (list, tuple)) and len(first_batch) > 1
            )

            # Prepare buffers
            features_per_layer = [[] for _ in self.feature_layers_id]
            logits_list = []
            labels_list = [] if contains_labels else None

            # Process batches
            for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
                tensor = TorchDataHandler.get_input_from_dataset_item(elem)
                feats_batch, logits_batch = self.predict_tensor(
                    tensor, postproc_fns, detach=detach
                )

                # Move to host and convert if requested
                if numpy_concat:
                    feats_batch = [f.detach().cpu().numpy() for f in feats_batch]
                    logits_batch = (
                        logits_batch.detach().cpu().numpy()
                        if logits_batch is not None
                        else None
                    )

                # Accumulate
                for i, f in enumerate(feats_batch):
                    features_per_layer[i].append(f)
                if logits_batch is not None:
                    logits_list.append(logits_batch)
                if contains_labels:
                    lbl = TorchDataHandler.get_label_from_dataset_item(elem)
                    labels_list.append(lbl)

            # Concatenate
            labels = torch.cat(labels_list, dim=0) if labels_list is not None else None

            if numpy_concat:
                features = [np.concatenate(lst, axis=0) for lst in features_per_layer]
                logits = np.concatenate(logits_list, axis=0) if logits_list else None
                labels = labels.cpu().numpy() if labels is not None else None

            else:
                features = [torch.cat(lst, dim=0) for lst in features_per_layer]
                logits = torch.cat(logits_list, dim=0) if logits_list else None

        # Package extra info
        info = {"labels": labels, "logits": logits}
        return features, info

    def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
        """Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            List[torch.Tensor]: weight and bias matrices (as NumPy arrays)
        """
        layer = self.find_layer(self.model, layer_id)
        return [layer.weight.detach().cpu().numpy(), layer.bias.detach().cpu().numpy()]

    def _get_clip_hook(self, threshold: float) -> Callable:
        """
        Hook that clips activation features at a threshold value (values above
        the threshold are truncated to it).

        Args:
            threshold (float): threshold value

        Returns:
            Callable: hook function
        """

        def hook(_, input):
            input = input[0]
            input = torch.clip(input, max=threshold)
            return input

        return hook

    def _get_scale_hook(self, percentile: float) -> Callable:
        """
        Hook that scales activation features (SCALE).

        Args:
            percentile (float): percentile (in [0, 1]) used to compute the
                per-sample scaling factor

        Returns:
            Callable: hook function
        """

        def hook(_, input):
            input = input[0]
            output_percentile = torch.quantile(input, percentile, dim=1)
            mask = input > output_percentile[:, None]
            output_masked = input * mask
            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
            s = torch.unsqueeze(s, 1)
            input = input * s
            return input

        return hook

    def _get_ash_hook(self, percentile: float) -> Callable:
        """
        Hook that prunes activation features below a percentile threshold and
        scales the surviving activations (ASH).

        Args:
            percentile (float): percentile (in [0, 1]) below which activations
                are pruned

        Returns:
            Callable: hook function
        """

        def hook(_, input):
            input = input[0]
            output_percentile = torch.quantile(input, percentile, dim=1)
            mask = input > output_percentile[:, None]
            output_masked = input * mask
            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
            s = torch.unsqueeze(s, 1)
            input = output_masked * s
            return input

        return hook

    def _prepare_ood_handles(self) -> None:
        """
        Prepare the model by attaching a new attribute to self.model, a list
        that will hold all the OOD-specific hook handles, or by removing the
        existing OOD-specific hooks if the attribute already exists.
        """

        if not hasattr(self.model, "_ood_handles"):
            setattr(self.model, "_ood_handles", [])
        else:
            for handle in self.model._ood_handles:
                handle.remove()
            self.model._ood_handles = []

    def _default_postproc_fn(self, feat: TensorType) -> TensorType:
        """Default postprocessing function to apply to each feature immediately
        after forward.

        Args:
            feat (TensorType): input tensor

        Returns:
            TensorType: postprocessed tensor
        """
        if len(feat.shape) == 4:
            feat = nn.AdaptiveAvgPool2d(1)(feat)
            feat = feat.view(feat.size(0), -1)
        elif len(feat.shape) == 3:
            feat = nn.AdaptiveAvgPool1d(1)(feat)
            feat = feat.view(feat.size(0), -1)
        elif len(feat.shape) == 2:
            feat = feat
        else:
            raise NotImplementedError(
                "Postprocessing function not implemented for this feature shape"
            )
        return feat
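
To make the extractor concrete, here is a hypothetical usage sketch, not part of the wheel: it assumes a torchvision ResNet, whose layer names ("avgpool", "fc") are illustrative, and uses only the constructor and predict_tensor signature shown above. Whether the base FeatureExtractor already calls prepare_extractor() in __init__ is not visible in this diff, so the sketch calls it explicitly; repeating the call is harmless because _prepare_ood_handles() removes previously registered hooks first.

import torch
from torchvision.models import resnet18

from oodeel.extractor.torch_feature_extractor import TorchFeatureExtractor

model = resnet18(weights=None).eval()

extractor = TorchFeatureExtractor(
    model,
    feature_layers_id=["avgpool"],  # hook the pooled features
    head_layer_id="fc",  # classification head; the default -1 also resolves to it
    react_threshold=1.0,  # ReAct: clip the head's input activations at 1.0
)
extractor.prepare_extractor()  # registers the forward hooks on the model

x = torch.randn(8, 3, 224, 224)
features, logits = extractor.predict_tensor(x)
print(features[0].shape)  # torch.Size([8, 512, 1, 1]) -- avgpool output
print(logits.shape)       # torch.Size([8, 1000]) -- head output

For intuition on the SCALE/ASH hooks above: for a single row x = [4, 1, 3, 2] with percentile 0.65, torch.quantile returns 2.95, the mask keeps [4, 0, 3, 0], and s = exp(10 / 7) ≈ 4.17; the SCALE hook multiplies the whole row by s, while the ASH hook multiplies only the surviving activations by s.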
oodeel/methods/__init__.py
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from .dknn import DKNN
from .energy import Energy
from .entropy import Entropy
from .gen import GEN
from .gram import Gram
from .mahalanobis import Mahalanobis
from .mls import MLS
from .odin import ODIN
from .rmds import RMDS
from .she import SHE
from .vim import VIM

__all__ = [
    "DKNN",
    "Energy",
    "Entropy",
    "GEN",
    "SHE",
    "Gram",
    "Mahalanobis",
    "MLS",
    "ODIN",
    "RMDS",
    "VIM",
]
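
These exports share the detector interface defined in oodeel/methods/base.py (+570 lines in this wheel, not displayed in this diff). Below is a hedged end-to-end sketch of how one export is typically driven; the fit()/score() calls follow the project's documentation and are assumptions as far as this diff is concerned.

import torch
from torch import nn

from oodeel.methods import MLS  # Maximum Logit Score baseline

# Toy in-distribution classifier; any trained nn.Module is used the same way.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

detector = MLS()
detector.fit(model)  # assumed: wraps the model in a TorchFeatureExtractor
scores, info = detector.score(torch.randn(16, 3, 32, 32))  # assumed return: (scores, info dict)

Logit-based detectors such as MLS, Energy, or Entropy need no fitting data, whereas feature-based ones such as Mahalanobis, DKNN, or VIM additionally take in-distribution samples (typically via a fit_dataset argument) during fit().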