oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of oodeel might be problematic; see the registry's advisory page for more details.

Files changed (47):
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from abc import abstractmethod
25
25
 
26
26
  from ..types import Callable
27
27
  from ..types import DatasetType
28
+ from ..types import ItemType
28
29
  from ..types import List
29
30
  from ..types import Optional
30
31
  from ..types import TensorType
@@ -51,6 +52,11 @@ class FeatureExtractor(ABC):
51
52
  Defaults to None.
52
53
  react_threshold: if not None, penultimate layer activations are clipped under
53
54
  this threshold value (useful for ReAct). Defaults to None.
55
+ scale_percentile: if not None, the features are scaled
56
+ following the method of Xu et al., ICLR 2024.
57
+ Defaults to None.
58
+ ash_percentile: if not None, the features are scaled following
59
+ the method of Djurisic et al., ICLR 2023.
54
60
  """
55
61
 
56
62
  def __init__(
@@ -59,6 +65,8 @@ class FeatureExtractor(ABC):
59
65
  feature_layers_id: List[Union[int, str]] = [-1],
60
66
  input_layer_id: Union[int, str] = [0],
61
67
  react_threshold: Optional[float] = None,
68
+ scale_percentile: Optional[float] = None,
69
+ ash_percentile: Optional[float] = None,
62
70
  ):
63
71
  if not isinstance(feature_layers_id, list):
64
72
  feature_layers_id = [feature_layers_id]
@@ -66,6 +74,8 @@ class FeatureExtractor(ABC):
66
74
  self.feature_layers_id = feature_layers_id
67
75
  self.input_layer_id = input_layer_id
68
76
  self.react_threshold = react_threshold
77
+ self.scale_percentile = scale_percentile
78
+ self.ash_percentile = ash_percentile
69
79
  self.model = model
70
80
  self.extractor = self.prepare_extractor()
71
81
 
@@ -91,28 +101,39 @@ class FeatureExtractor(ABC):
91
101
  raise NotImplementedError()
92
102
 
93
103
  @abstractmethod
94
- def predict_tensor(self, tensor: TensorType) -> Tuple[List[TensorType], TensorType]:
95
- """
96
- Projects input samples "inputs" into the feature space
104
+ def predict_tensor(
105
+ self,
106
+ tensor: TensorType,
107
+ postproc_fns: Optional[List[Callable]] = None,
108
+ ) -> Tuple[List[TensorType], TensorType]:
109
+ """Get the projection of tensor in the feature space of self.model
97
110
 
98
111
  Args:
99
- tensor (TensorType): input tensor
112
+ tensor (TensorType): input tensor (or dataset elem)
113
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
114
+ feature immediately after forward. Default to None.
100
115
 
101
116
  Returns:
102
- List[TensorType], TensorType: features, logits
117
+ Tuple[List[TensorType], TensorType]: features, logits
103
118
  """
104
119
  raise NotImplementedError()
105
120
 
106
121
  @abstractmethod
107
122
  def predict(
108
123
  self,
109
- dataset: Union[DatasetType, TensorType],
124
+ dataset: Union[ItemType, DatasetType],
125
+ postproc_fns: Optional[List[Callable]] = None,
126
+ verbose: bool = False,
127
+ **kwargs,
110
128
  ) -> Tuple[List[TensorType], dict]:
111
- """
112
- Projects input samples "inputs" into the feature space for a batched dataset
129
+ """Get the projection of the dataset in the feature space of self.model
113
130
 
114
131
  Args:
115
- dataset (Union[DatasetType, TensorType]): iterable of tensor batches
132
+ dataset (Union[ItemType, DatasetType]): input dataset
133
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
134
+ feature immediately after forward. Default to None.
135
+ verbose (bool): if True, display a progress bar. Defaults to False.
136
+ kwargs (dict): additional arguments not considered for prediction
116
137
 
117
138
  Returns:
118
139
  List[TensorType], dict: features and extra information (logits, labels) as a
@@ -24,6 +24,8 @@ from typing import get_args
24
24
  from typing import Optional
25
25
 
26
26
  import tensorflow as tf
27
+ import tensorflow_probability as tfp
28
+ from tqdm import tqdm
27
29
 
28
30
  from ..datasets.tf_data_handler import TFDataHandler
29
31
  from ..types import Callable
@@ -54,6 +56,11 @@ class KerasFeatureExtractor(FeatureExtractor):
54
56
  Defaults to None.
55
57
  react_threshold: if not None, penultimate layer activations are clipped under
56
58
  this threshold value (useful for ReAct). Defaults to None.
59
+ scale_percentile: if not None, the features are scaled
60
+ following the method of Xu et al., ICLR 2024.
61
+ Defaults to None.
62
+ ash_percentile: if not None, the features are scaled following
63
+ the method of Djurisic et al., ICLR 2023.
57
64
  """
58
65
 
59
66
  def __init__(
@@ -62,6 +69,8 @@ class KerasFeatureExtractor(FeatureExtractor):
62
69
  feature_layers_id: List[Union[int, str]] = [-1],
63
70
  input_layer_id: Optional[Union[int, str]] = None,
64
71
  react_threshold: Optional[float] = None,
72
+ scale_percentile: Optional[float] = None,
73
+ ash_percentile: Optional[float] = None,
65
74
  ):
66
75
  if input_layer_id is None:
67
76
  input_layer_id = 0
@@ -70,6 +79,8 @@ class KerasFeatureExtractor(FeatureExtractor):
70
79
  feature_layers_id=feature_layers_id,
71
80
  input_layer_id=input_layer_id,
72
81
  react_threshold=react_threshold,
82
+ scale_percentile=scale_percentile,
83
+ ash_percentile=ash_percentile,
73
84
  )
74
85
 
75
86
  self.backend = "tensorflow"
@@ -143,6 +154,43 @@ class KerasFeatureExtractor(FeatureExtractor):
143
154
  )
144
155
  # apply ultimate layer on clipped activations
145
156
  output_tensors.append(last_layer(x))
157
+
158
+ # === If SCALE method, scale activations from penultimate layer ===
159
+ # === If ASH method, scale and prune activations from penultimate layer ===
160
+ elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
161
+ penultimate_layer = self.find_layer(self.model, -2)
162
+ penult_extractor = tf.keras.models.Model(
163
+ new_input, penultimate_layer.output
164
+ )
165
+ last_layer = self.find_layer(self.model, -1)
166
+
167
+ # apply scaling on penultimate activations
168
+ penultimate = penult_extractor(new_input)
169
+ if self.scale_percentile is not None:
170
+ output_percentile = tfp.stats.percentile(
171
+ penultimate, 100 * self.scale_percentile, axis=1
172
+ )
173
+ else:
174
+ output_percentile = tfp.stats.percentile(
175
+ penultimate, 100 * self.ash_percentile, axis=1
176
+ )
177
+
178
+ mask = penultimate > tf.reshape(output_percentile, (-1, 1))
179
+ filtered_penultimate = tf.where(
180
+ mask, penultimate, tf.zeros_like(penultimate)
181
+ )
182
+ s = tf.math.exp(
183
+ tf.reduce_sum(penultimate, axis=1)
184
+ / tf.reduce_sum(filtered_penultimate, axis=1)
185
+ )
186
+
187
+ if self.scale_percentile is not None:
188
+ x = penultimate * tf.expand_dims(s, 1)
189
+ else:
190
+ x = filtered_penultimate * tf.expand_dims(s, 1)
191
+ # apply ultimate layer on scaled activations
192
+ output_tensors.append(last_layer(x))
193
+
146
194
  else:
147
195
  output_tensors.append(self.find_layer(self.model, -1).output)
148
196
 
@@ -150,11 +198,17 @@ class KerasFeatureExtractor(FeatureExtractor):
150
198
  return extractor
151
199
 
152
200
  @sanitize_input
153
- def predict_tensor(self, tensor: TensorType) -> Tuple[List[tf.Tensor], tf.Tensor]:
201
+ def predict_tensor(
202
+ self,
203
+ tensor: TensorType,
204
+ postproc_fns: Optional[List[Callable]] = None,
205
+ ) -> Tuple[List[tf.Tensor], tf.Tensor]:
154
206
  """Get the projection of tensor in the feature space of self.model
155
207
 
156
208
  Args:
157
209
  tensor (TensorType): input tensor (or dataset elem)
210
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
211
+ each feature immediately after forward. Default to None.
158
212
 
159
213
  Returns:
160
214
  Tuple[List[tf.Tensor], tf.Tensor]: features, logits
@@ -166,8 +220,12 @@ class KerasFeatureExtractor(FeatureExtractor):
166
220
 
167
221
  # split features and logits
168
222
  logits = features.pop()
169
- if len(features) == 1:
170
- features = features[0]
223
+
224
+ if postproc_fns is not None:
225
+ features = [
226
+ postproc_fn(feature)
227
+ for feature, postproc_fn in zip(features, postproc_fns)
228
+ ]
171
229
 
172
230
  self._last_logits = logits
173
231
  return features, logits
@@ -179,12 +237,17 @@ class KerasFeatureExtractor(FeatureExtractor):
179
237
  def predict(
180
238
  self,
181
239
  dataset: Union[ItemType, tf.data.Dataset],
240
+ postproc_fns: Optional[List[Callable]] = None,
241
+ verbose: bool = False,
182
242
  **kwargs,
183
243
  ) -> Tuple[List[tf.Tensor], dict]:
184
244
  """Get the projection of the dataset in the feature space of self.model
185
245
 
186
246
  Args:
187
247
  dataset (Union[ItemType, tf.data.Dataset]): input dataset
248
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
249
+ feature immediately after forward. Default to None.
250
+ verbose (bool): if True, display a progress bar. Defaults to False.
188
251
  kwargs (dict): additional arguments not considered for prediction
189
252
 
190
253
  Returns:
@@ -195,7 +258,7 @@ class KerasFeatureExtractor(FeatureExtractor):
195
258
 
196
259
  if isinstance(dataset, get_args(ItemType)):
197
260
  tensor = TFDataHandler.get_input_from_dataset_item(dataset)
198
- features, logits = self.predict_tensor(tensor)
261
+ features, logits = self.predict_tensor(tensor, postproc_fns)
199
262
 
200
263
  # Get labels if dataset is a tuple/list
201
264
  if isinstance(dataset, (list, tuple)):
@@ -205,12 +268,10 @@ class KerasFeatureExtractor(FeatureExtractor):
205
268
  features = [None for i in range(len(self.feature_layers_id))]
206
269
  logits = None
207
270
  contains_labels = TFDataHandler.get_item_length(dataset) > 1
208
- for elem in dataset:
271
+ for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
209
272
  tensor = TFDataHandler.get_input_from_dataset_item(elem)
210
- features_batch, logits_batch = self.predict_tensor(tensor)
211
- # concatenate features
212
- if len(features) == 1:
213
- features_batch = [features_batch]
273
+ features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)
274
+
214
275
  for i, f in enumerate(features_batch):
215
276
  features[i] = (
216
277
  f
@@ -234,10 +295,6 @@ class KerasFeatureExtractor(FeatureExtractor):
234
295
 
235
296
  # store extra information in a dict
236
297
  info = dict(labels=labels, logits=logits)
237
-
238
- if len(features) == 1:
239
- features = features[0]
240
-
241
298
  return features, info
242
299
 
243
300
  def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
@@ -20,13 +20,13 @@
20
20
  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
21
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
22
  # SOFTWARE.
23
- from collections import OrderedDict
24
23
  from typing import get_args
25
24
  from typing import Optional
26
25
 
27
26
  import torch
28
27
  from torch import nn
29
28
  from torch.utils.data import DataLoader
29
+ from tqdm import tqdm
30
30
 
31
31
  from ..datasets.torch_data_handler import TorchDataHandler
32
32
  from ..types import Callable
@@ -57,6 +57,12 @@ class TorchFeatureExtractor(FeatureExtractor):
57
57
  Defaults to None.
58
58
  react_threshold: if not None, penultimate layer activations are clipped under
59
59
  this threshold value (useful for ReAct). Defaults to None.
60
+ scale_percentile: if not None, the features are scaled
61
+ following the method of Xu et al., ICLR 2024.
62
+ Defaults to None.
63
+ ash_percentile: if not None, the features are scaled following
64
+ the method of Djurisic et al., ICLR 2023.
65
+
60
66
  """
61
67
 
62
68
  def __init__(
@@ -65,6 +71,8 @@ class TorchFeatureExtractor(FeatureExtractor):
65
71
  feature_layers_id: List[Union[int, str]] = [],
66
72
  input_layer_id: Optional[Union[int, str]] = None,
67
73
  react_threshold: Optional[float] = None,
74
+ scale_percentile: Optional[float] = None,
75
+ ash_percentile: Optional[float] = None,
68
76
  ):
69
77
  model = model.eval()
70
78
  super().__init__(
@@ -72,6 +80,8 @@ class TorchFeatureExtractor(FeatureExtractor):
72
80
  feature_layers_id=feature_layers_id,
73
81
  input_layer_id=input_layer_id,
74
82
  react_threshold=react_threshold,
83
+ scale_percentile=scale_percentile,
84
+ ash_percentile=ash_percentile,
75
85
  )
76
86
  self._device = next(model.parameters()).device
77
87
  self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}
@@ -140,18 +150,43 @@ class TorchFeatureExtractor(FeatureExtractor):
140
150
 
141
151
  def prepare_extractor(self) -> None:
142
152
  """Prepare the feature extractor by adding hooks to self.model"""
143
- # remove forward hooks attached to the model
144
- self._clean_forward_hooks()
153
+ # prepare self.model for ood hooks (add _ood_handles attribute or
154
+ # remove ood forward hooks attached to the model)
155
+ self._prepare_ood_handles()
145
156
 
146
157
  # === If react method, clip activations from penultimate layer ===
147
158
  if self.react_threshold is not None:
148
- pen_layer = self.find_layer(self.model, -2)
149
- pen_layer.register_forward_hook(self._get_clip_hook(self.react_threshold))
159
+ pen_layer = self.find_layer(self.model, -1)
160
+ self.model._ood_handles.append(
161
+ pen_layer.register_forward_pre_hook(
162
+ self._get_clip_hook(self.react_threshold)
163
+ )
164
+ )
165
+
166
+ # === If SCALE method, scale activations from penultimate layer ===
167
+ if self.scale_percentile is not None:
168
+ pen_layer = self.find_layer(self.model, -1)
169
+ self.model._ood_handles.append(
170
+ pen_layer.register_forward_pre_hook(
171
+ self._get_scale_hook(self.scale_percentile)
172
+ )
173
+ )
174
+
175
+ # === If ASH method, scale and prune activations from penultimate layer ===
176
+ if self.ash_percentile is not None:
177
+ pen_layer = self.find_layer(self.model, -1)
178
+ self.model._ood_handles.append(
179
+ pen_layer.register_forward_pre_hook(
180
+ self._get_ash_hook(self.ash_percentile)
181
+ )
182
+ )
150
183
 
151
184
  # Register a hook to store feature values for each considered layer + last layer
152
185
  for layer_id in self._hook_layers_id:
153
186
  layer = self.find_layer(self.model, layer_id)
154
- layer.register_forward_hook(self._get_features_hook(layer_id))
187
+ self.model._ood_handles.append(
188
+ layer.register_forward_hook(self._get_features_hook(layer_id))
189
+ )
155
190
 
156
191
  # Crop model if input layer is provided
157
192
  if not (self.input_layer_id) is None:
@@ -181,12 +216,17 @@ class TorchFeatureExtractor(FeatureExtractor):
181
216
 
182
217
  @sanitize_input
183
218
  def predict_tensor(
184
- self, x: TensorType, detach: bool = True
219
+ self,
220
+ x: TensorType,
221
+ postproc_fns: Optional[List[Callable]] = None,
222
+ detach: bool = True,
185
223
  ) -> Tuple[List[torch.Tensor], torch.Tensor]:
186
224
  """Get the projection of tensor in the feature space of self.model
187
225
 
188
226
  Args:
189
227
  x (TensorType): input tensor (or dataset elem)
228
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
229
+ each feature immediately after forward. Default to None.
190
230
  detach (bool): if True, return features detached from the computational
191
231
  graph. Defaults to True.
192
232
 
@@ -206,8 +246,12 @@ class TorchFeatureExtractor(FeatureExtractor):
206
246
 
207
247
  # split features and logits
208
248
  logits = features.pop()
209
- if len(features) == 1:
210
- features = features[0]
249
+
250
+ if postproc_fns is not None:
251
+ features = [
252
+ postproc_fn(feature)
253
+ for feature, postproc_fn in zip(features, postproc_fns)
254
+ ]
211
255
 
212
256
  self._last_logits = logits
213
257
  return features, logits
@@ -215,15 +259,20 @@ class TorchFeatureExtractor(FeatureExtractor):
215
259
  def predict(
216
260
  self,
217
261
  dataset: Union[DataLoader, ItemType],
262
+ postproc_fns: Optional[List[Callable]] = None,
218
263
  detach: bool = True,
264
+ verbose: bool = False,
219
265
  **kwargs,
220
266
  ) -> Tuple[List[torch.Tensor], dict]:
221
267
  """Get the projection of the dataset in the feature space of self.model
222
268
 
223
269
  Args:
224
270
  dataset (Union[DataLoader, ItemType]): input dataset
271
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
272
+ each feature immediately after forward. Default to None.
225
273
  detach (bool): if True, return features detached from the computational
226
274
  graph. Defaults to True.
275
+ verbose (bool): if True, display a progress bar. Defaults to False.
227
276
  kwargs (dict): additional arguments not considered for prediction
228
277
 
229
278
  Returns:
@@ -234,7 +283,7 @@ class TorchFeatureExtractor(FeatureExtractor):
234
283
 
235
284
  if isinstance(dataset, get_args(ItemType)):
236
285
  tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
237
- features, logits = self.predict_tensor(tensor, detach=detach)
286
+ features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)
238
287
 
239
288
  # Get labels if dataset is a tuple/list
240
289
  if isinstance(dataset, (list, tuple)) and len(dataset) > 1:
@@ -245,14 +294,11 @@ class TorchFeatureExtractor(FeatureExtractor):
245
294
  logits = None
246
295
  batch = next(iter(dataset))
247
296
  contains_labels = isinstance(batch, (list, tuple)) and len(batch) > 1
248
- for elem in dataset:
297
+ for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
249
298
  tensor = TorchDataHandler.get_input_from_dataset_item(elem)
250
299
  features_batch, logits_batch = self.predict_tensor(
251
- tensor, detach=detach
300
+ tensor, postproc_fns, detach=detach
252
301
  )
253
- # concatenate features
254
- if len(features) == 1:
255
- features_batch = [features_batch]
256
302
  for i, f in enumerate(features_batch):
257
303
  features[i] = (
258
304
  f if features[i] is None else torch.cat([features[i], f], dim=0)
@@ -274,10 +320,6 @@ class TorchFeatureExtractor(FeatureExtractor):
274
320
 
275
321
  # store extra information in a dict
276
322
  info = dict(labels=labels, logits=logits)
277
-
278
- if len(features) == 1:
279
- features = features[0]
280
-
281
323
  return features, info
282
324
 
283
325
  def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
@@ -303,24 +345,69 @@ class TorchFeatureExtractor(FeatureExtractor):
303
345
  Callable: hook function
304
346
  """
305
347
 
306
- def hook(_, __, output):
307
- output = torch.clip(output, max=threshold)
308
- return output
348
+ def hook(_, input):
349
+ input = input[0]
350
+ input = torch.clip(input, max=threshold)
351
+ return input
309
352
 
310
353
  return hook
311
354
 
312
- def _clean_forward_hooks(self) -> None:
355
+ def _get_scale_hook(self, percentile: float) -> Callable:
313
356
  """
314
- Remove all the forward hook attached to the model's layers. This function should
315
- be called at the __init__, and prevent from accumulating the hooks when
316
- defining a new TorchFeatureExtractor for the same model.
357
+ Hook that scales activation features.
358
+
359
+ Args:
360
+ threshold (float): threshold value
361
+
362
+ Returns:
363
+ Callable: hook function
317
364
  """
318
365
 
319
- def __clean_hooks(m: nn.Module):
320
- for _, child in m._modules.items():
321
- if child is not None:
322
- if hasattr(child, "_forward_hooks"):
323
- child._forward_hooks = OrderedDict()
324
- __clean_hooks(child)
366
+ def hook(_, input):
367
+ input = input[0]
368
+ output_percentile = torch.quantile(input, percentile, dim=1)
369
+ mask = input > output_percentile[:, None]
370
+ output_masked = input * mask
371
+ s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
372
+ s = torch.unsqueeze(s, 1)
373
+ input = input * s
374
+ return input
325
375
 
326
- return __clean_hooks(self.model)
376
+ return hook
377
+
378
+ def _get_ash_hook(self, percentile: float) -> Callable:
379
+ """
380
+ Hook that scales and prunes activation features under a threshold value
381
+
382
+ Args:
383
+ threshold (float): threshold value
384
+
385
+ Returns:
386
+ Callable: hook function
387
+ """
388
+
389
+ def hook(_, input):
390
+ input = input[0]
391
+ output_percentile = torch.quantile(input, percentile, dim=1)
392
+ mask = input > output_percentile[:, None]
393
+ output_masked = input * mask
394
+ s = torch.exp(torch.sum(input, dim=1) / torch.sum(output_masked, dim=1))
395
+ s = torch.unsqueeze(s, 1)
396
+ input = output_masked * s
397
+ return input
398
+
399
+ return hook
400
+
401
+ def _prepare_ood_handles(self) -> None:
402
+ """
403
+ Prepare the model by either setting a new attribute to self.model
404
+ as a list which will contain all the ood specific hooks, or by cleaning
405
+ the existing ood specific hooks if the attribute already exists.
406
+ """
407
+
408
+ if not hasattr(self.model, "_ood_handles"):
409
+ setattr(self.model, "_ood_handles", [])
410
+ else:
411
+ for handle in self.model._ood_handles:
412
+ handle.remove()
413
+ self.model._ood_handles = []
@@ -23,9 +23,25 @@
23
23
  from .dknn import DKNN
24
24
  from .energy import Energy
25
25
  from .entropy import Entropy
26
+ from .gen import GEN
27
+ from .gram import Gram
26
28
  from .mahalanobis import Mahalanobis
27
29
  from .mls import MLS
28
30
  from .odin import ODIN
31
+ from .rmds import RMDS
32
+ from .she import SHE
29
33
  from .vim import VIM
30
34
 
31
- __all__ = ["MLS", "DKNN", "ODIN", "Energy", "VIM", "Mahalanobis", "Entropy"]
35
+ __all__ = [
36
+ "DKNN",
37
+ "Energy",
38
+ "Entropy",
39
+ "GEN",
40
+ "SHE",
41
+ "Gram",
42
+ "Mahalanobis",
43
+ "MLS",
44
+ "ODIN",
45
+ "RMDS",
46
+ "VIM",
47
+ ]