oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
oodeel/methods/base.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import inspect
|
|
24
|
+
from abc import ABC
|
|
25
|
+
from typing import Dict
|
|
26
|
+
from typing import get_args
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
from tqdm import tqdm
|
|
30
|
+
|
|
31
|
+
from ..aggregator import BaseAggregator
|
|
32
|
+
from ..aggregator import StdNormalizedAggregator
|
|
33
|
+
from ..extractor.feature_extractor import FeatureExtractor
|
|
34
|
+
from ..types import Callable
|
|
35
|
+
from ..types import DatasetType
|
|
36
|
+
from ..types import ItemType
|
|
37
|
+
from ..types import List
|
|
38
|
+
from ..types import Optional
|
|
39
|
+
from ..types import TensorType
|
|
40
|
+
from ..types import Union
|
|
41
|
+
from ..utils import import_backend_specific_stuff
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class OODBaseDetector(ABC):
    """OODBaseDetector is an abstract base class for Out-of-Distribution (OOD)
    detection.

    Attributes:
        feature_extractor (FeatureExtractor): The feature extractor instance
            (None until `fit()` has been called).
        use_react (bool): Flag to indicate if ReAct method is used.
        use_scale (bool): Flag to indicate if scaling method is used.
        use_ash (bool): Flag to indicate if ASH method is used.
        react_quantile (float): Quantile value for ReAct threshold.
        scale_percentile (float): Percentile value for scaling.
        ash_percentile (float): Percentile value for ASH.
        eps (float): Perturbation noise for input perturbation.
        temperature (float): Temperature parameter for input perturbation.
        react_threshold (Optional[float]): Threshold for ReAct clipping.

    Public Methods:
        - fit(): Prepare the detector by setting up feature extraction and
          calibrating.
        - score(): Compute OOD scores on input data (batched or single item).
        - __call__(): Shorthand for score().

    Internal Methods (for subclassing or advanced usage):
        - _fit_to_dataset(): Optional calibration routine on a dataset.
        - _score_tensor(): Compute scores for a single batch of input data.
        - _load_feature_extractor(): Initialize feature extraction pipeline.
        - _sanitize_posproc_fns(): Normalize post-processing function list.
        - _compute_react_threshold(): Calibrate ReAct clipping threshold.
        - _input_perturbation(): Apply perturbation to input data.

    Abstract Properties:
        - requires_to_fit_dataset: Whether fit_dataset is mandatory for calibration.
        - requires_internal_features: Whether the detector uses internal features.
    """

    def __init__(
        self,
        use_react: Optional[bool] = False,
        use_scale: Optional[bool] = False,
        use_ash: Optional[bool] = False,
        react_quantile: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        eps: float = 0.0,
        temperature: float = 1000.0,
    ):
        """Store the detector configuration.

        Args:
            use_react (Optional[bool]): enable ReAct clipping. Defaults to False.
            use_scale (Optional[bool]): enable scaling. Defaults to False.
            use_ash (Optional[bool]): enable ASH. Defaults to False.
            react_quantile (Optional[float]): quantile of the fit-dataset
                penultimate activations used as ReAct threshold. Required when
                `use_react` is True. Defaults to None.
            scale_percentile (Optional[float]): percentile used for scaling.
                Defaults to None.
            ash_percentile (Optional[float]): percentile used for ASH.
                Defaults to None.
            eps (float): input perturbation magnitude. Defaults to 0.0.
            temperature (float): temperature used for input perturbation.
                Defaults to 1000.0.

        Raises:
            ValueError: if more than one of ReAct, scale and ASH is enabled.
        """
        self.feature_extractor: FeatureExtractor = None
        self.use_react = use_react
        self.use_scale = use_scale
        self.use_ash = use_ash
        self.react_quantile = react_quantile
        self.scale_percentile = scale_percentile
        self.ash_percentile = ash_percentile
        self.eps = eps
        self.temperature = temperature
        self.react_threshold = None

        # ReAct, scale and ASH are mutually exclusive.
        if use_scale and use_react:
            raise ValueError("Cannot use both ReAct and scale at the same time")
        if use_scale and use_ash:
            raise ValueError("Cannot use both ASH and scale at the same time")
        if use_ash and use_react:
            raise ValueError("Cannot use both ReAct and ASH at the same time")

    # === Public API ===
    def fit(
        self,
        model: Callable,
        fit_dataset: Optional[Union[ItemType, DatasetType]] = None,
        feature_layers_id: Optional[List[Union[int, str]]] = None,
        postproc_fns: Optional[List[Callable]] = None,
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Prepare the detector for scoring:
        * Constructs the feature extractor based on the model
        * Calibrates the detector on ID data "fit_dataset" if needed,
            using self._fit_to_dataset

        Args:
            model: model to extract the features from
            fit_dataset: dataset to fit the detector on
            feature_layers_id (Optional[List[Union[int, str]]]): list of str or
                int that identify features to output.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None, which is
                treated as an empty list (mandatory for feature-based
                detectors).
            postproc_fns (Optional[List[Callable]]): list of postproc functions,
                one per output layer. Defaults to None.
                If None, identity function is used.
            head_layer_id (int, str): identifier of the head layer.
                If int, the rank of the layer in the layer list
                If str, the name of the layer.
                Defaults to -1
            input_layer_id (int, str): identifier of the input layer of the
                feature extractor.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            **kwargs: forwarded to `self._fit_to_dataset`.

        Raises:
            ValueError: if `fit_dataset` is missing while required (by the
                detector or by ReAct), if ReAct is enabled without
                `react_quantile`, or if a feature-based detector is given no
                `feature_layers_id`.
        """
        # A fresh list per call instead of a shared mutable default.
        if feature_layers_id is None:
            feature_layers_id = []

        (
            self.backend,
            self.data_handler,
            self.op,
            self.FeatureExtractorClass,
        ) = import_backend_specific_stuff(model)

        # if required by the method, check that fit_dataset is not None
        if self.requires_to_fit_dataset and fit_dataset is None:
            raise ValueError(
                "`fit_dataset` argument must be provided for this OOD detector"
            )

        # The feature extractor is only built further below (it is still None on
        # a first call), so validate postproc_fns against the requested layers.
        self.postproc_fns = self._sanitize_posproc_fns(
            postproc_fns, feature_layers_id=feature_layers_id
        )

        # react: compute threshold (activation percentiles)
        if self.use_react:
            if fit_dataset is None:
                raise ValueError(
                    "if react quantile is not None, fit_dataset must be"
                    " provided to compute react activation threshold"
                )
            if self.react_quantile is None:
                raise ValueError(
                    "`react_quantile` must be provided when `use_react=True`"
                )
            self._compute_react_threshold(
                model, fit_dataset, verbose=verbose, head_layer_id=head_layer_id
            )

        if (feature_layers_id == []) and isinstance(self, FeatureBasedDetector):
            raise ValueError(
                "Explicitly specify feature_layers_id=[layer0, layer1,...], "
                + "where layer0, layer1,... are the names of the desired output "
                + "layers of your model. These can be int or str (even though str"
                + " is safer). To know what to put, have a look at model.summary() "
                + "with keras or model.named_modules() with pytorch"
            )

        self.feature_extractor = self._load_feature_extractor(
            model, feature_layers_id, head_layer_id, input_layer_id
        )

        if fit_dataset is not None:
            # Only forward `verbose` to implementations that accept it.
            if "verbose" in inspect.signature(self._fit_to_dataset).parameters:
                kwargs["verbose"] = verbose
            self._fit_to_dataset(fit_dataset, **kwargs)

    def score(
        self,
        dataset: Union[ItemType, DatasetType],
        verbose: bool = False,
    ) -> tuple:
        """
        Computes an OOD score for the samples in `dataset`.

        Args:
            dataset (Union[ItemType, DatasetType]): dataset or tensors to score
            verbose (bool): if True, display a progress bar. Defaults to False.

        Returns:
            tuple: `(scores, info)` where `scores` is a numpy array of scores and
            `info` is a dict with keys "labels" and "logits" (numpy arrays, or
            None when not available).

        Raises:
            NotImplementedError: if `dataset` is neither an item nor a dataset
                type supported by the backend.
        """
        assert self.feature_extractor is not None, "Call .fit() before .score()"
        labels = None

        # Case 1: dataset is neither a tf.data.Dataset nor a torch.DataLoader
        if isinstance(dataset, get_args(ItemType)):
            tensor = self.data_handler.get_input_from_dataset_item(dataset)
            scores = self._score_tensor(tensor)
            logits = self.op.convert_to_numpy(self.feature_extractor._last_logits)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)):
                labels = self.data_handler.get_label_from_dataset_item(dataset)
                labels = self.op.convert_to_numpy(labels)

        # Case 2: dataset is a tf.data.Dataset or a torch.DataLoader
        elif isinstance(dataset, get_args(DatasetType)):
            # Collect per-batch results and concatenate once at the end,
            # instead of re-copying the growing arrays at every batch.
            score_batches, logit_batches, label_batches = [], [], []

            for item in tqdm(dataset, desc="Scoring", disable=not verbose):
                tensor = self.data_handler.get_input_from_dataset_item(item)
                score_batches.append(self._score_tensor(tensor))
                logit_batches.append(
                    self.op.convert_to_numpy(self.feature_extractor._last_logits)
                )

                # get the label if available
                if len(item) > 1:
                    labels_batch = self.data_handler.get_label_from_dataset_item(item)
                    # Convert every batch (the first one included) to numpy.
                    label_batches.append(
                        np.ravel(self.op.convert_to_numpy(labels_batch))
                    )

            # Same result as successive np.append calls starting from an empty
            # float array: flattened scores (an empty array for no batch).
            scores = np.concatenate(
                [np.array([])] + [np.ravel(batch) for batch in score_batches]
            )
            logits = (
                np.concatenate(logit_batches, axis=0) if logit_batches else None
            )
            if label_batches:
                labels = np.concatenate(label_batches)

        else:
            raise NotImplementedError(
                f"OODBaseDetector.score() not implemented for {type(dataset)}"
            )

        info = dict(labels=labels, logits=logits)
        return scores, info

    def __call__(
        self, inputs: Union[ItemType, DatasetType], verbose: bool = False
    ) -> tuple:
        """
        Convenience wrapper for score

        Args:
            inputs (Union[ItemType, DatasetType]): dataset or tensors to score.
            verbose (bool): if True, display a progress bar. Defaults to False.

        Returns:
            tuple: `(scores, info)`, exactly as returned by `score`.
        """
        return self.score(inputs, verbose=verbose)

    # === Internal: Feature Extractor ===
    def _load_feature_extractor(
        self,
        model: Callable,
        feature_layers_id: Optional[List[Union[int, str]]] = None,
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        return_penultimate: bool = False,
    ) -> Callable:
        """
        Loads feature extractor

        Args:
            model: a model (Keras or PyTorch) to load.
            feature_layers_id (Optional[List[Union[int, str]]]): list of str or
                int that identify features to output.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None (treated as an
                empty list).
            head_layer_id (int): identifier of the head layer.
                -1 when the last layer is the head
                -2 when the last layer is a softmax activation layer
                ...
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to -1
            input_layer_id (int, str): identifier of the input layer of the
                feature extractor.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None.
            return_penultimate (bool): if True, the penultimate values are returned,
                i.e. the input to the head_layer.

        Returns:
            FeatureExtractor: a feature extractor instance
        """
        if feature_layers_id is None:
            feature_layers_id = []

        # Percentiles of disabled methods are dropped so the extractor does not
        # apply them.
        if not self.use_ash:
            self.ash_percentile = None
        if not self.use_scale:
            self.scale_percentile = None

        feature_extractor = self.FeatureExtractorClass(
            model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            head_layer_id=head_layer_id,
            react_threshold=self.react_threshold,
            scale_percentile=self.scale_percentile,
            ash_percentile=self.ash_percentile,
            return_penultimate=return_penultimate,
        )
        return feature_extractor

    def _sanitize_posproc_fns(
        self,
        postproc_fns: Union[List[Callable], None],
        feature_layers_id: Optional[List[Union[int, str]]] = None,
    ) -> Optional[List[Callable]]:
        """Sanitize postproc fns used at each layer output of the feature extractor.

        Args:
            postproc_fns (Optional[List[Callable]], optional): List of postproc
                functions, one per output layer. Defaults to None.
            feature_layers_id (Optional[List[Union[int, str]]]): layers the
                functions will be applied to. Defaults to None, in which case
                the layers of the current feature extractor are used.

        Returns:
            Optional[List[Callable]]: Sanitized postproc_fns list, where None
            entries are replaced by the identity; None if `postproc_fns` is None.
        """
        if postproc_fns is None:
            return None

        if feature_layers_id is None:
            feature_layers_id = self.feature_extractor.feature_layers_id
        # A single (non-list) layer identifier counts as one layer.
        if isinstance(feature_layers_id, (list, tuple)):
            n_layers = len(feature_layers_id)
        else:
            n_layers = 1
        assert (
            len(postproc_fns) == n_layers
        ), "len of postproc_fns and output_layers_id must match"

        def identity(x):
            return x

        return [identity if fn is None else fn for fn in postproc_fns]

    # === Internal: ODIN input perturbation ===
    def _input_perturbation(
        self, inputs: TensorType, eps: float, temperature: float = 1000
    ) -> TensorType:
        """Apply the ODIN gradient-based perturbation to the inputs.

        Args:
            inputs: Batch of samples to perturb.
            eps: Magnitude of the perturbation. If zero, `inputs` are
                returned unchanged.
            temperature: Temperature used for the softmax in the loss.

        Returns:
            Perturbed input tensor of the same shape as `inputs`.
        """

        if eps == 0:
            return inputs

        if self.feature_extractor.backend == "torch":
            inputs = inputs.to(self.feature_extractor._device)

        # The model's own argmax predictions are used as targets.
        preds = self.feature_extractor.model(inputs)
        labels = self.op.argmax(preds, dim=1)
        gradients = self.op.gradient(
            self._temperature_loss, inputs, labels, temperature
        )
        return inputs - eps * self.op.sign(gradients)

    def _temperature_loss(
        self, inputs: TensorType, labels: TensorType, temperature: float
    ) -> TensorType:
        """Cross-entropy loss (summed) on temperature-scaled model outputs,
        used for ODIN input perturbation."""

        preds = self.feature_extractor.model(inputs) / temperature
        loss = self.op.CrossEntropyLoss(reduction="sum")(inputs=preds, targets=labels)
        return loss

    # === Internal: Fitting logic ===
    def _fit_to_dataset(
        self,
        fit_dataset: DatasetType,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Optional fitting routine for detectors using a calibration dataset.

        The base implementation does nothing; subclasses override it.
        """

        return None

    # === Internal: Scoring logic ===
    def _score_tensor(self, inputs: TensorType) -> np.ndarray:
        """Compute an OOD score for a batch of inputs.

        Child classes must implement this method. It should return one score per
        sample of `inputs`.
        """

        raise NotImplementedError("_score_tensor must be implemented in subclasses")

    # === Internal calibration methods ===
    def _compute_react_threshold(
        self,
        model: Callable,
        fit_dataset: DatasetType,
        verbose: bool = False,
        head_layer_id: int = -1,
    ):
        """Set `self.react_threshold` to the `react_quantile` quantile of the
        penultimate activations computed on `fit_dataset`.

        Args:
            model: model to extract the penultimate activations from.
            fit_dataset: in-distribution data used for calibration.
            verbose (bool): if True, display a progress bar. Defaults to False.
            head_layer_id (int): identifier of the head layer. Defaults to -1.
        """
        # Drop any threshold left by a previous fit(): _load_feature_extractor
        # forwards self.react_threshold to the extractor, and the quantile must
        # be computed on unclipped activations.
        self.react_threshold = None

        penult_feat_extractor = self._load_feature_extractor(
            model, head_layer_id=head_layer_id, return_penultimate=True
        )
        unclipped_features, _ = penult_feat_extractor.predict(
            fit_dataset,
            verbose=verbose,
            postproc_fns=self.postproc_fns,
            numpy_concat=True,
        )
        self.react_threshold = np.quantile(unclipped_features[0], self.react_quantile)

    # === Properties ===
    @property
    def requires_to_fit_dataset(self) -> bool:
        """
        Whether an OOD detector needs a `fit_dataset` argument in the fit function.

        Returns:
            bool: True if `fit_dataset` is required else False.
        """
        raise NotImplementedError(
            "Property `requires_to_fit_dataset` is not implemented. It should return"
            + " a True or False boolean."
        )

    @property
    def requires_internal_features(self) -> bool:
        """
        Whether an OOD detector acts on internal model features.

        Returns:
            bool: True if the detector perform computations on an intermediate layer
            else False.
        """
        raise NotImplementedError(
            "Property `requires_internal_features` is not implemented. It should"
            + " return a True or False boolean."
        )
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class FeatureBasedDetector(OODBaseDetector):
    """Base class for detectors operating on internal feature representations.

    Subclasses implement :func:`_fit_layer` (per-layer statistics computed on
    the fit dataset) and :func:`_score_layer` (per-layer scores). When several
    feature layers are used, per-layer scores are combined by ``aggregator`` if
    one is set, and averaged otherwise.
    """

    def __init__(self, aggregator: Optional[BaseAggregator] = None, *args, **kwargs):
        """Initialize the feature-based OOD detector.

        Args:
            aggregator (Optional[BaseAggregator]): Aggregator to normalize scores
                across multiple feature layers. Defaults to None; a
                StdNormalizedAggregator is then created at fit time when more
                than one layer is used.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.aggregator = aggregator
        # None means "use the feature extractor's default postproc function"
        # (resolved in _fit_to_dataset).
        self.postproc_fns = None

    # === Internal: Fitting logic ===
    def _fit_to_dataset(
        self,
        fit_dataset: DatasetType,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Extract features from `fit_dataset` and compute layer statistics which
        will be used for scoring in _score_layer.

        Child classes must implement :func:`_fit_layer` to compute the
        statistics required by the detector on each feature layer. If an
        `aggregator` attribute is present, or more than one layer is used (in
        which case a StdNormalizedAggregator is created), the per-layer scores of
        the fit data are computed and fed to the aggregator for normalization.

        Args:
            fit_dataset (DatasetType): in-distribution data.
            verbose (bool): if True, display a progress bar. Defaults to False.
            **kwargs: forwarded to `_fit_layer` and `_score_layer`; the optional
                `batch_size` key (default 128) sets the batch size used to score
                the fit features.
        """
        n_layers = len(self.feature_extractor.feature_layers_id)

        if self.postproc_fns is None:
            self.postproc_fns = [self.feature_extractor._default_postproc_fn] * n_layers

        feats, info = self.feature_extractor.predict(
            fit_dataset,
            postproc_fns=self.postproc_fns,
            verbose=verbose,
            return_labels=True,
            numpy_concat=True,
        )

        aggregator = getattr(self, "aggregator", None)
        if aggregator is None and n_layers > 1:
            aggregator = StdNormalizedAggregator()
            self.aggregator = aggregator

        # batch the scoring to avoid memory issues
        batch_size = kwargs.get("batch_size", 128)

        per_layer_scores = []
        for idx in range(n_layers):
            layer_feats = feats[idx]
            self._fit_layer(idx, layer_feats, info, **kwargs)

            # Per-layer scores are only needed to fit an aggregator.
            if aggregator is None:
                continue

            n_samples = layer_feats.shape[0]
            layer_scores = []
            for start in range(0, n_samples, batch_size):
                end = min(start + batch_size, n_samples)
                # Entries such as labels may be None (nothing to slice).
                batch_info = {
                    k: (None if v is None else v[start:end]) for k, v in info.items()
                }
                batch_feats = self.op.from_numpy(layer_feats[start:end])
                layer_scores.append(
                    self._score_layer(idx, batch_feats, batch_info, fit=True, **kwargs)
                )
            per_layer_scores.append(np.concatenate(layer_scores, axis=0))

        if aggregator is not None and per_layer_scores:
            aggregator.fit(per_layer_scores)

    def _fit_layer(
        self,
        layer_id: int,
        layer_features: np.ndarray,
        info: Dict[str, TensorType],
        **kwargs,
    ) -> None:
        """Compute statistics for a single feature layer (implemented by
        subclasses)."""

        raise NotImplementedError

    # === Internal: Scoring logic ===
    def _score_tensor(self, inputs: TensorType) -> np.ndarray:
        """Compute the OOD score for a batch using internal features.

        Returns the single layer's scores, the aggregator's output when several
        layers and an aggregator are available, or the mean of per-layer scores
        otherwise.
        """

        if getattr(self, "eps", 0) > 0:
            # NOTE(review): self.temperature is not forwarded here, so the
            # perturbation uses _input_perturbation's default temperature —
            # confirm this is intended.
            inputs = self._input_perturbation(inputs, self.eps)

        feats, logits = self.feature_extractor.predict_tensor(
            inputs, postproc_fns=self.postproc_fns
        )

        info: Dict[str, TensorType] = {"logits": logits}
        per_layer_scores = [
            self._score_layer(idx, feats[idx], info) for idx in range(len(feats))
        ]

        aggregator = getattr(self, "aggregator", None)
        if aggregator is not None and len(per_layer_scores) > 1:
            return aggregator.aggregate(per_layer_scores)
        if len(per_layer_scores) > 1:
            return np.mean(np.stack(per_layer_scores, axis=1), axis=1)
        return per_layer_scores[0]

    def _score_layer(
        self,
        layer_id: int,
        layer_features: TensorType,
        info: Dict[str, TensorType],
        fit: bool = False,
        **kwargs,
    ) -> np.ndarray:
        """Score samples for a single feature layer (implemented by subclasses).

        `fit` is True when called from `_fit_to_dataset` on fit-dataset features.
        """

        raise NotImplementedError
|