oodeel 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.


Files changed (42)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/extractor/feature_extractor.py +11 -0
  13. oodeel/extractor/keras_feature_extractor.py +51 -1
  14. oodeel/extractor/torch_feature_extractor.py +103 -21
  15. oodeel/methods/__init__.py +16 -1
  16. oodeel/methods/base.py +72 -15
  17. oodeel/methods/dknn.py +20 -7
  18. oodeel/methods/energy.py +8 -0
  19. oodeel/methods/entropy.py +8 -0
  20. oodeel/methods/gen.py +118 -0
  21. oodeel/methods/gram.py +15 -4
  22. oodeel/methods/mahalanobis.py +9 -7
  23. oodeel/methods/mls.py +8 -0
  24. oodeel/methods/odin.py +8 -0
  25. oodeel/methods/rmds.py +122 -0
  26. oodeel/methods/she.py +197 -0
  27. oodeel/methods/vim.py +1 -1
  28. oodeel/preprocess/__init__.py +31 -0
  29. oodeel/preprocess/tf_preprocess.py +95 -0
  30. oodeel/preprocess/torch_preprocess.py +97 -0
  31. oodeel/utils/operator.py +17 -0
  32. oodeel/utils/tf_operator.py +15 -0
  33. oodeel/utils/tf_training_tools.py +2 -2
  34. oodeel/utils/torch_operator.py +19 -0
  35. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/METADATA +139 -105
  36. oodeel-0.3.0.dist-info/RECORD +57 -0
  37. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  38. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  39. tests/tests_torch/torch_methods_utils.py +34 -27
  40. oodeel-0.2.0.dist-info/RECORD +0 -47
  41. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  42. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/extractor/keras_feature_extractor.py CHANGED
@@ -24,6 +24,8 @@ from typing import get_args
 from typing import Optional
 
 import tensorflow as tf
+import tensorflow_probability as tfp
+from tqdm import tqdm
 
 from ..datasets.tf_data_handler import TFDataHandler
 from ..types import Callable
@@ -54,6 +56,11 @@ class KerasFeatureExtractor(FeatureExtractor):
            Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
+        scale_percentile: if not None, the features are scaled
+            following the method of Xu et al., ICLR 2024.
+            Defaults to None.
+        ash_percentile: if not None, the features are scaled following
+            the method of Djurisic et al., ICLR 2023.
    """
 
    def __init__(
@@ -62,6 +69,8 @@ class KerasFeatureExtractor(FeatureExtractor):
        feature_layers_id: List[Union[int, str]] = [-1],
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
    ):
        if input_layer_id is None:
            input_layer_id = 0
@@ -70,6 +79,8 @@ class KerasFeatureExtractor(FeatureExtractor):
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            react_threshold=react_threshold,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
        )
 
        self.backend = "tensorflow"
@@ -143,6 +154,43 @@ class KerasFeatureExtractor(FeatureExtractor):
            )
            # apply ultimate layer on clipped activations
            output_tensors.append(last_layer(x))
+
+        # === If SCALE method, scale activations from penultimate layer ===
+        # === If ASH method, scale and prune activations from penultimate layer ===
+        elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
+            penultimate_layer = self.find_layer(self.model, -2)
+            penult_extractor = tf.keras.models.Model(
+                new_input, penultimate_layer.output
+            )
+            last_layer = self.find_layer(self.model, -1)
+
+            # apply scaling on penultimate activations
+            penultimate = penult_extractor(new_input)
+            if self.scale_percentile is not None:
+                output_percentile = tfp.stats.percentile(
+                    penultimate, 100 * self.scale_percentile, axis=1
+                )
+            else:
+                output_percentile = tfp.stats.percentile(
+                    penultimate, 100 * self.ash_percentile, axis=1
+                )
+
+            mask = penultimate > tf.reshape(output_percentile, (-1, 1))
+            filtered_penultimate = tf.where(
+                mask, penultimate, tf.zeros_like(penultimate)
+            )
+            s = tf.math.exp(
+                tf.reduce_sum(penultimate, axis=1)
+                / tf.reduce_sum(filtered_penultimate, axis=1)
+            )
+
+            if self.scale_percentile is not None:
+                x = penultimate * tf.expand_dims(s, 1)
+            else:
+                x = filtered_penultimate * tf.expand_dims(s, 1)
+            # apply ultimate layer on scaled activations
+            output_tensors.append(last_layer(x))
+
        else:
            output_tensors.append(self.find_layer(self.model, -1).output)
 
@@ -190,6 +238,7 @@ class KerasFeatureExtractor(FeatureExtractor):
        self,
        dataset: Union[ItemType, tf.data.Dataset],
        postproc_fns: Optional[List[Callable]] = None,
+        verbose: bool = False,
        **kwargs,
    ) -> Tuple[List[tf.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model
@@ -198,6 +247,7 @@ class KerasFeatureExtractor(FeatureExtractor):
            dataset (Union[ItemType, tf.data.Dataset]): input dataset
            postproc_fns (Optional[Callable]): postprocessing function to apply to each
                feature immediately after forward. Default to None.
+            verbose (bool): if True, display a progress bar. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction
 
        Returns:
@@ -218,7 +268,7 @@ class KerasFeatureExtractor(FeatureExtractor):
        features = [None for i in range(len(self.feature_layers_id))]
        logits = None
        contains_labels = TFDataHandler.get_item_length(dataset) > 1
-        for elem in dataset:
+        for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
            tensor = TFDataHandler.get_input_from_dataset_item(elem)
            features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)
 
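Note on the SCALE/ASH branch added to the Keras extractor above: both methods compute a per-sample percentile of the penultimate activations and a factor s = exp(sum(all) / sum(kept)); SCALE (Xu et al., ICLR 2024) rescales the full activation vector, while ASH (Djurisic et al., ICLR 2023) rescales only the pruned top-percentile part. A minimal NumPy sketch of that arithmetic, independent of any Keras model (the random feats array and the 0.85 value are illustrative placeholders):

    import numpy as np

    rng = np.random.default_rng(0)
    feats = rng.random((4, 8))   # stand-in for penultimate activations (batch, features)
    percentile = 0.85            # illustrative; echoes the scale_percentile default in energy.py

    # per-sample threshold; activations above it are kept, the rest zeroed
    thresh = np.quantile(feats, percentile, axis=1, keepdims=True)
    kept = np.where(feats > thresh, feats, 0.0)

    # shared scaling factor: exp(sum of all activations / sum of kept activations)
    s = np.exp(feats.sum(axis=1) / kept.sum(axis=1))[:, None]

    scale_out = feats * s   # SCALE: rescale the whole activation vector
    ash_out = kept * s      # ASH: rescale only the pruned activations
    print(scale_out.shape, ash_out.shape)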
oodeel/extractor/torch_feature_extractor.py CHANGED
@@ -20,13 +20,13 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from collections import OrderedDict
 from typing import get_args
 from typing import Optional
 
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from ..datasets.torch_data_handler import TorchDataHandler
 from ..types import Callable
@@ -57,6 +57,12 @@ class TorchFeatureExtractor(FeatureExtractor):
            Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
+        scale_percentile: if not None, the features are scaled
+            following the method of Xu et al., ICLR 2024.
+            Defaults to None.
+        ash_percentile: if not None, the features are scaled following
+            the method of Djurisic et al., ICLR 2023.
+
    """
 
    def __init__(
@@ -65,6 +71,8 @@ class TorchFeatureExtractor(FeatureExtractor):
        feature_layers_id: List[Union[int, str]] = [],
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
    ):
        model = model.eval()
        super().__init__(
@@ -72,6 +80,8 @@ class TorchFeatureExtractor(FeatureExtractor):
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            react_threshold=react_threshold,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
        )
        self._device = next(model.parameters()).device
        self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}
@@ -140,18 +150,43 @@ class TorchFeatureExtractor(FeatureExtractor):
 
    def prepare_extractor(self) -> None:
        """Prepare the feature extractor by adding hooks to self.model"""
-        # remove forward hooks attached to the model
-        self._clean_forward_hooks()
+        # prepare self.model for ood hooks (add _ood_handles attribute or
+        # remove ood forward hooks attached to the model)
+        self._prepare_ood_handles()
 
        # === If react method, clip activations from penultimate layer ===
        if self.react_threshold is not None:
-            pen_layer = self.find_layer(self.model, -2)
-            pen_layer.register_forward_hook(self._get_clip_hook(self.react_threshold))
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_clip_hook(self.react_threshold)
+                )
+            )
+
+        # === If SCALE method, scale activations from penultimate layer ===
+        if self.scale_percentile is not None:
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_scale_hook(self.scale_percentile)
+                )
+            )
+
+        # === If ASH method, scale and prune activations from penultimate layer ===
+        if self.ash_percentile is not None:
+            pen_layer = self.find_layer(self.model, -1)
+            self.model._ood_handles.append(
+                pen_layer.register_forward_pre_hook(
+                    self._get_ash_hook(self.ash_percentile)
+                )
+            )
 
        # Register a hook to store feature values for each considered layer + last layer
        for layer_id in self._hook_layers_id:
            layer = self.find_layer(self.model, layer_id)
-            layer.register_forward_hook(self._get_features_hook(layer_id))
+            self.model._ood_handles.append(
+                layer.register_forward_hook(self._get_features_hook(layer_id))
+            )
 
        # Crop model if input layer is provided
        if not (self.input_layer_id) is None:
@@ -226,6 +261,7 @@ class TorchFeatureExtractor(FeatureExtractor):
        dataset: Union[DataLoader, ItemType],
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
+        verbose: bool = False,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model
@@ -236,6 +272,7 @@ class TorchFeatureExtractor(FeatureExtractor):
                each feature immediately after forward. Default to None.
            detach (bool): if True, return features detached from the computational
                graph. Defaults to True.
+            verbose (bool): if True, display a progress bar. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction
 
        Returns:
@@ -257,7 +294,7 @@ class TorchFeatureExtractor(FeatureExtractor):
        logits = None
        batch = next(iter(dataset))
        contains_labels = isinstance(batch, (list, tuple)) and len(batch) > 1
-        for elem in dataset:
+        for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
            tensor = TorchDataHandler.get_input_from_dataset_item(elem)
            features_batch, logits_batch = self.predict_tensor(
                tensor, postproc_fns, detach=detach
@@ -308,24 +345,69 @@ class TorchFeatureExtractor(FeatureExtractor):
            Callable: hook function
        """
 
-        def hook(_, __, output):
-            output = torch.clip(output, max=threshold)
-            return output
+        def hook(_, input):
+            input = input[0]
+            input = torch.clip(input, max=threshold)
+            return input
+
+        return hook
+
+    def _get_scale_hook(self, percentile: float) -> Callable:
+        """
+        Hook that scales activation features.
+
+        Args:
+            threshold (float): threshold value
+
+        Returns:
+            Callable: hook function
+        """
+
+        def hook(_, input):
+            input = input[0]
+            output_percentile = torch.quantile(input, percentile, dim=1)
+            mask = input > output_percentile[:, None]
+            output_masked = input * mask
+            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
+            s = torch.unsqueeze(s, 1)
+            input = input * s
+            return input
 
        return hook
 
-    def _clean_forward_hooks(self) -> None:
+    def _get_ash_hook(self, percentile: float) -> Callable:
        """
-        Remove all the forward hook attached to the model's layers. This function should
-        be called at the __init__, and prevent from accumulating the hooks when
-        defining a new TorchFeatureExtractor for the same model.
+        Hook that scales and prunes activation features under a threshold value
+
+        Args:
+            threshold (float): threshold value
+
+        Returns:
+            Callable: hook function
        """
 
-        def __clean_hooks(m: nn.Module):
-            for _, child in m._modules.items():
-                if child is not None:
-                    if hasattr(child, "_forward_hooks"):
-                        child._forward_hooks = OrderedDict()
-                    __clean_hooks(child)
+        def hook(_, input):
+            input = input[0]
+            output_percentile = torch.quantile(input, percentile, dim=1)
+            mask = input > output_percentile[:, None]
+            output_masked = input * mask
+            s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
+            s = torch.unsqueeze(s, 1)
+            input = output_masked * s
+            return input
+
+        return hook
 
-        return __clean_hooks(self.model)
+    def _prepare_ood_handles(self) -> None:
+        """
+        Prepare the model by either setting a new attribute to self.model
+        as a list which will contain all the ood specific hooks, or by cleaning
+        the existing ood specific hooks if the attribute already exists.
+        """
+
+        if not hasattr(self.model, "_ood_handles"):
+            setattr(self.model, "_ood_handles", [])
+        else:
+            for handle in self.model._ood_handles:
+                handle.remove()
+            self.model._ood_handles = []
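Note on the Torch extractor changes above: the ReAct/SCALE/ASH transforms are now forward pre-hooks that rewrite the input of the last layer (instead of a forward hook on the penultimate output), and every handle is kept in model._ood_handles so hooks can be removed individually rather than wiping _forward_hooks wholesale. A minimal PyTorch sketch of that pattern on a throwaway two-layer model (the net, the clip value and the hook name are illustrative, not oodeel code):

    import torch
    from torch import nn

    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))

    def clip_pre_hook(module, args):
        # args is the tuple of positional inputs to the hooked layer;
        # returning a value replaces that input for the forward pass
        return torch.clip(args[0], max=0.9)

    handles = [net[-1].register_forward_pre_hook(clip_pre_hook)]
    out = net(torch.randn(2, 8))   # the last Linear now sees clipped activations
    print(out.shape)

    # removing the handles restores the unmodified forward, mirroring what
    # _prepare_ood_handles() does when the same model is reused
    for h in handles:
        h.remove()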
oodeel/methods/__init__.py CHANGED
@@ -23,10 +23,25 @@
 from .dknn import DKNN
 from .energy import Energy
 from .entropy import Entropy
+from .gen import GEN
 from .gram import Gram
 from .mahalanobis import Mahalanobis
 from .mls import MLS
 from .odin import ODIN
+from .rmds import RMDS
+from .she import SHE
 from .vim import VIM
 
-__all__ = ["MLS", "DKNN", "ODIN", "Energy", "VIM", "Mahalanobis", "Entropy", "Gram"]
+__all__ = [
+    "DKNN",
+    "Energy",
+    "Entropy",
+    "GEN",
+    "SHE",
+    "Gram",
+    "Mahalanobis",
+    "MLS",
+    "ODIN",
+    "RMDS",
+    "VIM",
+]
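A quick sanity check of the expanded public API (a usage sketch; it only echoes the __all__ list above and assumes oodeel 0.3.0 is installed):

    from oodeel import methods

    # GEN, RMDS and SHE are new in 0.3.0, next to the existing detectors
    print(sorted(methods.__all__))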
oodeel/methods/base.py CHANGED
@@ -20,11 +20,13 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
+import inspect
 from abc import ABC
 from abc import abstractmethod
 from typing import get_args
 
 import numpy as np
+from tqdm import tqdm
 
 from ..extractor.feature_extractor import FeatureExtractor
 from ..types import Callable
@@ -38,28 +40,66 @@ from ..utils import import_backend_specific_stuff
 
 
 class OODBaseDetector(ABC):
-    """Base Class for methods that assign a score to unseen samples.
-
-    Args:
-        use_react (bool): if true, apply ReAct method by clipping penultimate
-            activations under a threshold value.
-        react_quantile (Optional[float]): q value in the range [0, 1] used to compute
-            the react clipping threshold defined as the q-th quantile penultimate layer
-            activations. Defaults to 0.8.
+    """OODBaseDetector is an abstract base class for Out-of-Distribution (OOD)
+    detection.
+
+    Attributes:
+        feature_extractor (FeatureExtractor): The feature extractor instance.
+        use_react (bool): Flag to indicate if ReAct method is used.
+        use_scale (bool): Flag to indicate if scaling method is used.
+        use_ash (bool): Flag to indicate if ASH method is used.
+        react_quantile (float): Quantile value for ReAct threshold.
+        scale_percentile (float): Percentile value for scaling.
+        ash_percentile (float): Percentile value for ASH.
+        react_threshold (float): Threshold value for ReAct.
+        postproc_fns (List[Callable]): List of post-processing functions.
+
+    Methods:
+        __init__: Initializes the OODBaseDetector with specified parameters.
+        _score_tensor: Abstract method to compute OOD score for input samples.
+        _sanitize_posproc_fns: Sanitizes post-processing functions used at each layer
+            output.
+        fit: Prepares the detector for scoring by constructing the feature extractor
+            and calibrating on ID data.
+        _load_feature_extractor: Loads the feature extractor based on the model and
+            specified layers.
+        _fit_to_dataset: Abstract method to fit the OOD detector to a dataset.
+        score: Computes an OOD score for input samples.
+        compute_react_threshold: Computes the ReAct threshold using the fit dataset.
+        __call__: Convenience wrapper for the score method.
+        requires_to_fit_dataset: Property indicating if the detector needs a fit
+            dataset.
+        requires_internal_features: Property indicating if the detector acts on
+            internal model features.
    """
 
    def __init__(
        self,
-        use_react: bool = False,
-        react_quantile: float = 0.8,
-        postproc_fns: List[Callable] = None,
+        use_react: Optional[bool] = False,
+        use_scale: Optional[bool] = False,
+        use_ash: Optional[bool] = False,
+        react_quantile: Optional[float] = None,
+        scale_percentile: Optional[float] = None,
+        ash_percentile: Optional[float] = None,
+        postproc_fns: Optional[List[Callable]] = None,
    ):
        self.feature_extractor: FeatureExtractor = None
        self.use_react = use_react
+        self.use_scale = use_scale
+        self.use_ash = use_ash
        self.react_quantile = react_quantile
+        self.scale_percentile = scale_percentile
+        self.ash_percentile = ash_percentile
        self.react_threshold = None
        self.postproc_fns = self._sanitize_posproc_fns(postproc_fns)
 
+        if use_scale and use_react:
+            raise ValueError("Cannot use both ReAct and scale at the same time")
+        if use_scale and use_ash:
+            raise ValueError("Cannot use both ASH and scale at the same time")
+        if use_ash and use_react:
+            raise ValueError("Cannot use both ReAct and ASH at the same time")
+
    @abstractmethod
    def _score_tensor(self, inputs: TensorType) -> np.ndarray:
        """Computes an OOD score for input samples "inputs".
@@ -104,6 +144,7 @@ class OODBaseDetector(ABC):
        fit_dataset: Optional[Union[ItemType, DatasetType]] = None,
        feature_layers_id: List[Union[int, str]] = [],
        input_layer_id: Optional[Union[int, str]] = None,
+        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Prepare the detector for scoring:
@@ -122,6 +163,7 @@ class OODBaseDetector(ABC):
                layer of the feature extractor.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None.
+            verbose (bool): if True, display a progress bar. Defaults to False.
        """
        (
            self.backend,
@@ -144,7 +186,7 @@ class OODBaseDetector(ABC):
                    " provided to compute react activation threshold"
                )
            else:
-                self.compute_react_threshold(model, fit_dataset)
+                self.compute_react_threshold(model, fit_dataset, verbose=verbose)
 
        if (feature_layers_id == []) and (self.requires_internal_features):
            raise ValueError(
@@ -160,6 +202,8 @@ class OODBaseDetector(ABC):
            )
 
        if fit_dataset is not None:
+            if "verbose" in inspect.signature(self._fit_to_dataset).parameters.keys():
+                kwargs.update({"verbose": verbose})
            self._fit_to_dataset(fit_dataset, **kwargs)
 
    def _load_feature_extractor(
@@ -185,11 +229,18 @@ class OODBaseDetector(ABC):
        Returns:
            FeatureExtractor: a feature extractor instance
        """
+        if not self.use_ash:
+            self.ash_percentile = None
+        if not self.use_scale:
+            self.scale_percentile = None
+
        feature_extractor = self.FeatureExtractorClass(
            model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            react_threshold=self.react_threshold,
+            scale_percentile=self.scale_percentile,
+            ash_percentile=self.ash_percentile,
        )
        return feature_extractor
 
@@ -207,12 +258,14 @@ class OODBaseDetector(ABC):
    def score(
        self,
        dataset: Union[ItemType, DatasetType],
+        verbose: bool = False,
    ) -> np.ndarray:
        """
        Computes an OOD score for input samples "inputs".
 
        Args:
            dataset (Union[ItemType, DatasetType]): dataset or tensors to score
+            verbose (bool): if True, display a progress bar. Defaults to False.
 
        Returns:
            tuple: scores or list of scores (depending on the input) and a dictionary
@@ -236,7 +289,7 @@ class OODBaseDetector(ABC):
        scores = np.array([])
        logits = None
 
-        for item in dataset:
+        for item in tqdm(dataset, desc="Scoring", disable=not verbose):
            tensor = self.data_handler.get_input_from_dataset_item(item)
            score_batch = self._score_tensor(tensor)
            logits_batch = self.op.convert_to_numpy(
@@ -267,9 +320,13 @@ class OODBaseDetector(ABC):
        info = dict(labels=labels, logits=logits)
        return scores, info
 
-    def compute_react_threshold(self, model: Callable, fit_dataset: DatasetType):
+    def compute_react_threshold(
+        self, model: Callable, fit_dataset: DatasetType, verbose: bool = False
+    ):
        penult_feat_extractor = self._load_feature_extractor(model, [-2])
-        unclipped_features, _ = penult_feat_extractor.predict(fit_dataset)
+        unclipped_features, _ = penult_feat_extractor.predict(
+            fit_dataset, verbose=verbose
+        )
        self.react_threshold = self.op.quantile(
            unclipped_features[0], self.react_quantile
        )
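Seen from user code, the new base-class options surface roughly as below. This is a usage sketch only: model, ds_fit and ds_test are placeholders for your own model and datasets, and the 0.90 value simply echoes the ash_percentile default shown in energy.py further down.

    from oodeel.methods import Energy

    # ReAct, SCALE and ASH all rewrite the penultimate activations,
    # so the base class refuses to combine them:
    try:
        Energy(use_react=True, use_scale=True)
    except ValueError as err:
        print(err)  # Cannot use both ReAct and scale at the same time

    detector = Energy(use_ash=True, ash_percentile=0.90)
    # detector.fit(model, fit_dataset=ds_fit, verbose=True)   # tqdm bar while calibrating
    # scores, info = detector.score(ds_test, verbose=True)    # tqdm bar while scoring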
oodeel/methods/dknn.py CHANGED
@@ -38,21 +38,28 @@ class DKNN(OODBaseDetector):
    Args:
        nearest: number of nearest neighbors to consider.
            Defaults to 1.
+        use_gpu (bool): Flag to enable GPU acceleration for FAISS. Defaults to False.
    """
 
-    def __init__(
-        self,
-        nearest: int = 1,
-    ):
+    def __init__(self, nearest: int = 50, use_gpu: bool = False):
        super().__init__()
-
        self.index = None
        self.nearest = nearest
+        self.use_gpu = use_gpu
+
+        if self.use_gpu:
+            try:
+                self.res = faiss.StandardGpuResources()
+            except AttributeError as e:
+                raise ImportError(
+                    "faiss-gpu is not installed, but use_gpu was set to True."
+                    + "Please install faiss-gpu or set use_gpu to False."
+                ) from e
 
    def _fit_to_dataset(self, fit_dataset: Union[TensorType, DatasetType]) -> None:
        """
        Constructs the index from ID data "fit_dataset", which will be used for
-        nearest neighbor search.
+        nearest neighbor search. Can operate on CPU or GPU based on the `use_gpu` flag.
 
        Args:
            fit_dataset: input dataset (ID) to construct the index with.
@@ -61,7 +68,13 @@ class DKNN(OODBaseDetector):
        fit_projected = self.op.convert_to_numpy(fit_projected[0])
        fit_projected = fit_projected.reshape(fit_projected.shape[0], -1)
        norm_fit_projected = self._l2_normalization(fit_projected)
-        self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1])
+
+        if self.use_gpu:
+            cpu_index = faiss.IndexFlatL2(norm_fit_projected.shape[1])
+            self.index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index)
+        else:
+            self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1])
+
        self.index.add(norm_fit_projected)
 
    def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
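The new use_gpu flag only changes where the flat L2 index lives; the nearest-neighbor search itself is unchanged. A standalone FAISS sketch of the two branches (the random vectors stand in for the L2-normalized features oodeel would produce, and use_gpu is a local variable here):

    import numpy as np
    import faiss

    d = 128
    xb = np.random.rand(1000, d).astype("float32")   # stand-in for normalized fit features
    xq = np.random.rand(5, d).astype("float32")      # stand-in for normalized test features

    use_gpu = False
    index = faiss.IndexFlatL2(d)                     # exact L2 index on CPU
    if use_gpu:
        # StandardGpuResources only exists in faiss-gpu, which is why DKNN
        # turns the AttributeError into an ImportError when it is missing
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)

    index.add(xb)
    dists, ids = index.search(xq, 50)                # 50 mirrors the new `nearest` default
    print(dists.shape, ids.shape)                    # (5, 50) (5, 50)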
oodeel/methods/energy.py CHANGED
@@ -59,11 +59,19 @@ class Energy(OODBaseDetector):
    def __init__(
        self,
        use_react: bool = False,
+        use_scale: bool = False,
+        use_ash: bool = False,
        react_quantile: float = 0.8,
+        scale_percentile: float = 0.85,
+        ash_percentile: float = 0.90,
    ):
        super().__init__(
            use_react=use_react,
+            use_scale=use_scale,
+            use_ash=use_ash,
            react_quantile=react_quantile,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
        )
 
    def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
oodeel/methods/entropy.py CHANGED
@@ -52,11 +52,19 @@ class Entropy(OODBaseDetector):
    def __init__(
        self,
        use_react: bool = False,
+        use_scale: bool = False,
+        use_ash: bool = False,
        react_quantile: float = 0.8,
+        scale_percentile: float = 0.85,
+        ash_percentile: float = 0.90,
    ):
        super().__init__(
            use_react=use_react,
+            use_scale=use_scale,
+            use_ash=use_ash,
            react_quantile=react_quantile,
+            scale_percentile=scale_percentile,
+            ash_percentile=ash_percentile,
        )
 
    def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]: