oodeel 0.1.0__tar.gz → 0.2.0__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.

Potentially problematic release.


This version of oodeel might be problematic. See the package's registry page for more details.

Files changed (52) hide show
  1. {oodeel-0.1.0 → oodeel-0.2.0}/PKG-INFO +3 -3
  2. {oodeel-0.1.0 → oodeel-0.2.0}/README.md +2 -2
  3. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/__init__.py +1 -1
  4. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/plots/features.py +2 -2
  5. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/plots/plotly.py +2 -2
  6. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/extractor/feature_extractor.py +19 -9
  7. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/extractor/keras_feature_extractor.py +19 -12
  8. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/extractor/torch_feature_extractor.py +17 -12
  9. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/__init__.py +2 -1
  10. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/base.py +38 -5
  11. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/dknn.py +2 -2
  12. oodeel-0.2.0/oodeel/methods/gram.py +296 -0
  13. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/mahalanobis.py +5 -5
  14. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/vim.py +4 -4
  15. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/operator.py +55 -2
  16. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/tf_operator.py +57 -4
  17. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/tf_training_tools.py +24 -1
  18. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/torch_operator.py +56 -4
  19. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/torch_training_tools.py +31 -2
  20. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel.egg-info/PKG-INFO +3 -3
  21. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel.egg-info/SOURCES.txt +1 -0
  22. {oodeel-0.1.0 → oodeel-0.2.0}/setup.py +1 -1
  23. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_torch/tools_torch.py +9 -9
  24. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tools_operator.py +10 -1
  25. {oodeel-0.1.0 → oodeel-0.2.0}/LICENSE +0 -0
  26. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/datasets/__init__.py +0 -0
  27. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/datasets/data_handler.py +0 -0
  28. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/datasets/ooddataset.py +0 -0
  29. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/datasets/tf_data_handler.py +0 -0
  30. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/datasets/torch_data_handler.py +0 -0
  31. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/__init__.py +0 -0
  32. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/metrics.py +0 -0
  33. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/plots/__init__.py +0 -0
  34. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/eval/plots/metrics.py +0 -0
  35. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/extractor/__init__.py +0 -0
  36. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/energy.py +0 -0
  37. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/entropy.py +0 -0
  38. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/mls.py +0 -0
  39. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/methods/odin.py +0 -0
  40. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/types/__init__.py +0 -0
  41. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/__init__.py +0 -0
  42. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel/utils/general_utils.py +0 -0
  43. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel.egg-info/dependency_links.txt +0 -0
  44. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel.egg-info/requires.txt +0 -0
  45. {oodeel-0.1.0 → oodeel-0.2.0}/oodeel.egg-info/top_level.txt +0 -0
  46. {oodeel-0.1.0 → oodeel-0.2.0}/setup.cfg +0 -0
  47. {oodeel-0.1.0 → oodeel-0.2.0}/tests/__init__.py +0 -0
  48. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_tensorflow/__init__.py +0 -0
  49. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_tensorflow/tf_methods_utils.py +0 -0
  50. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_tensorflow/tools_tf.py +0 -0
  51. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_torch/__init__.py +0 -0
  52. {oodeel-0.1.0 → oodeel-0.2.0}/tests/tests_torch/torch_methods_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: oodeel
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Simple, compact, and hackable post-hoc deep OOD detection for alreadytrained tensorflow or pytorch image classifiers.
5
5
  Author: DEEL Core Team
6
6
  Author-email: paul.novello@irt-saintexupery.com
@@ -194,14 +194,14 @@ Currently, **oodeel** includes the following baselines:
194
194
  | MSP | [A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks](http://arxiv.org/abs/1610.02136) | ICLR 2017 | avail [tensorflow & torch](docs/pages/getting_started.ipynb)|
195
195
  | Mahalanobis | [A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks](http://arxiv.org/abs/1807.03888) | NeurIPS 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_mahalanobis_tf.ipynb) or [torch](docs/notebooks/torch/demo_mahalanobis_torch.ipynb)|
196
196
  | Energy | [Energy-based Out-of-distribution Detection](http://arxiv.org/abs/2010.03759) | NeurIPS 2020 |avail [tensorflow](docs/notebooks/tensorflow/demo_energy_tf.ipynb) or [torch](docs/notebooks/torch/demo_energy_torch.ipynb) |
197
- | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 |avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
197
+ | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
198
198
  | DKNN | [Out-of-Distribution Detection with Deep Nearest Neighbors](http://arxiv.org/abs/2204.06507) | ICML 2022 | avail [tensorflow](docs/notebooks/tensorflow/demo_dknn_tf.ipynb) or [torch](docs/notebooks/torch/demo_dknn_torch.ipynb) |
199
199
  | VIM | [ViM: Out-Of-Distribution with Virtual-logit Matching](http://arxiv.org/abs/2203.10807) | CVPR 2022 |avail [tensorflow](docs/notebooks/tensorflow/demo_vim_tf.ipynb) or [torch](docs/notebooks/torch/demo_vim_torch.ipynb) |
200
200
  | Entropy | [Likelihood Ratios for Out-of-Distribution Detection](https://proceedings.neurips.cc/paper/2019/hash/1e79596878b2320cac26dd792a6c51c9-Abstract.html) | NeurIPS 2019 |avail [tensorflow](docs/notebooks/tensorflow/demo_entropy_tf.ipynb) or [torch](docs/notebooks/torch/demo_entropy_torch.ipynb) |
201
201
  | GODIN | [Generalized ODIN: Detecting Out-of-Distribution Image Without Learning From Out-of-Distribution Data](https://ieeexplore.ieee.org/document/9156473/) | CVPR 2020 | planned |
202
202
  | ReAct | [ReAct: Out-of-distribution Detection With Rectified Activations](http://arxiv.org/abs/2111.12797) | NeurIPS 2021 | avail [tensorflow](docs/notebooks/tensorflow/demo_react_tf.ipynb) or [torch](docs/notebooks/torch/demo_react_torch.ipynb) |
203
203
  | NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
204
- | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | planned |
204
+ | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
205
205
 
206
206
 
207
207
 
@@ -173,14 +173,14 @@ Currently, **oodeel** includes the following baselines:
173
173
  | MSP | [A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks](http://arxiv.org/abs/1610.02136) | ICLR 2017 | avail [tensorflow & torch](docs/pages/getting_started.ipynb)|
174
174
  | Mahalanobis | [A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks](http://arxiv.org/abs/1807.03888) | NeurIPS 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_mahalanobis_tf.ipynb) or [torch](docs/notebooks/torch/demo_mahalanobis_torch.ipynb)|
175
175
  | Energy | [Energy-based Out-of-distribution Detection](http://arxiv.org/abs/2010.03759) | NeurIPS 2020 |avail [tensorflow](docs/notebooks/tensorflow/demo_energy_tf.ipynb) or [torch](docs/notebooks/torch/demo_energy_torch.ipynb) |
176
- | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 |avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
176
+ | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
177
177
  | DKNN | [Out-of-Distribution Detection with Deep Nearest Neighbors](http://arxiv.org/abs/2204.06507) | ICML 2022 | avail [tensorflow](docs/notebooks/tensorflow/demo_dknn_tf.ipynb) or [torch](docs/notebooks/torch/demo_dknn_torch.ipynb) |
178
178
  | VIM | [ViM: Out-Of-Distribution with Virtual-logit Matching](http://arxiv.org/abs/2203.10807) | CVPR 2022 |avail [tensorflow](docs/notebooks/tensorflow/demo_vim_tf.ipynb) or [torch](docs/notebooks/torch/demo_vim_torch.ipynb) |
179
179
  | Entropy | [Likelihood Ratios for Out-of-Distribution Detection](https://proceedings.neurips.cc/paper/2019/hash/1e79596878b2320cac26dd792a6c51c9-Abstract.html) | NeurIPS 2019 |avail [tensorflow](docs/notebooks/tensorflow/demo_entropy_tf.ipynb) or [torch](docs/notebooks/torch/demo_entropy_torch.ipynb) |
180
180
  | GODIN | [Generalized ODIN: Detecting Out-of-Distribution Image Without Learning From Out-of-Distribution Data](https://ieeexplore.ieee.org/document/9156473/) | CVPR 2020 | planned |
181
181
  | ReAct | [ReAct: Out-of-distribution Detection With Rectified Activations](http://arxiv.org/abs/2111.12797) | NeurIPS 2021 | avail [tensorflow](docs/notebooks/tensorflow/demo_react_tf.ipynb) or [torch](docs/notebooks/torch/demo_react_torch.ipynb) |
182
182
  | NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
183
- | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | planned |
183
+ | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
184
184
 
185
185
 
186
186
 
@@ -25,4 +25,4 @@ oodeel
25
25
  -------
26
26
  """
27
27
 
28
- __version__ = "0.1.0"
28
+ __version__ = "0.2.0"
@@ -168,7 +168,7 @@ def _plot_features(
168
168
  # === extract id features ===
169
169
  # features
170
170
  in_features, _ = feature_extractor.predict(in_dataset)
171
- in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
171
+ in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
172
172
 
173
173
  # labels
174
174
  in_labels = []
@@ -181,7 +181,7 @@ def _plot_features(
181
181
  if out_dataset is not None:
182
182
  # features
183
183
  out_features, _ = feature_extractor.predict(out_dataset)
184
- out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
184
+ out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
185
185
 
186
186
  # labels
187
187
  out_labels_str = np.array(["unknown"] * len(out_features))
@@ -82,7 +82,7 @@ def plotly_3D_features(
82
82
  # === extract id features ===
83
83
  # features
84
84
  in_features, _ = feature_extractor.predict(in_dataset)
85
- in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
85
+ in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
86
86
 
87
87
  # labels
88
88
  in_labels = []
@@ -95,7 +95,7 @@ def plotly_3D_features(
95
95
  if out_dataset is not None:
96
96
  # features
97
97
  out_features, _ = feature_extractor.predict(out_dataset)
98
- out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
98
+ out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
99
99
 
100
100
  # labels
101
101
  out_labels = np.array(["unknown"] * len(out_features))
@@ -25,6 +25,7 @@ from abc import abstractmethod
25
25
 
26
26
  from ..types import Callable
27
27
  from ..types import DatasetType
28
+ from ..types import ItemType
28
29
  from ..types import List
29
30
  from ..types import Optional
30
31
  from ..types import TensorType
@@ -91,28 +92,37 @@ class FeatureExtractor(ABC):
91
92
  raise NotImplementedError()
92
93
 
93
94
  @abstractmethod
94
- def predict_tensor(self, tensor: TensorType) -> Tuple[List[TensorType], TensorType]:
95
- """
96
- Projects input samples "inputs" into the feature space
95
+ def predict_tensor(
96
+ self,
97
+ tensor: TensorType,
98
+ postproc_fns: Optional[List[Callable]] = None,
99
+ ) -> Tuple[List[TensorType], TensorType]:
100
+ """Get the projection of tensor in the feature space of self.model
97
101
 
98
102
  Args:
99
- tensor (TensorType): input tensor
103
+ tensor (TensorType): input tensor (or dataset elem)
104
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
105
+ feature immediately after forward. Default to None.
100
106
 
101
107
  Returns:
102
- List[TensorType], TensorType: features, logits
108
+ Tuple[List[TensorType], TensorType]: features, logits
103
109
  """
104
110
  raise NotImplementedError()
105
111
 
106
112
  @abstractmethod
107
113
  def predict(
108
114
  self,
109
- dataset: Union[DatasetType, TensorType],
115
+ dataset: Union[ItemType, DatasetType],
116
+ postproc_fns: Optional[List[Callable]] = None,
117
+ **kwargs,
110
118
  ) -> Tuple[List[TensorType], dict]:
111
- """
112
- Projects input samples "inputs" into the feature space for a batched dataset
119
+ """Get the projection of the dataset in the feature space of self.model
113
120
 
114
121
  Args:
115
- dataset (Union[DatasetType, TensorType]): iterable of tensor batches
122
+ dataset (Union[ItemType, DatasetType]): input dataset
123
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
124
+ feature immediately after forward. Default to None.
125
+ kwargs (dict): additional arguments not considered for prediction
116
126
 
117
127
  Returns:
118
128
  List[TensorType], dict: features and extra information (logits, labels) as a
@@ -150,11 +150,17 @@ class KerasFeatureExtractor(FeatureExtractor):
150
150
  return extractor
151
151
 
152
152
  @sanitize_input
153
- def predict_tensor(self, tensor: TensorType) -> Tuple[List[tf.Tensor], tf.Tensor]:
153
+ def predict_tensor(
154
+ self,
155
+ tensor: TensorType,
156
+ postproc_fns: Optional[List[Callable]] = None,
157
+ ) -> Tuple[List[tf.Tensor], tf.Tensor]:
154
158
  """Get the projection of tensor in the feature space of self.model
155
159
 
156
160
  Args:
157
161
  tensor (TensorType): input tensor (or dataset elem)
162
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
163
+ each feature immediately after forward. Default to None.
158
164
 
159
165
  Returns:
160
166
  Tuple[List[tf.Tensor], tf.Tensor]: features, logits
@@ -166,8 +172,12 @@ class KerasFeatureExtractor(FeatureExtractor):
166
172
 
167
173
  # split features and logits
168
174
  logits = features.pop()
169
- if len(features) == 1:
170
- features = features[0]
175
+
176
+ if postproc_fns is not None:
177
+ features = [
178
+ postproc_fn(feature)
179
+ for feature, postproc_fn in zip(features, postproc_fns)
180
+ ]
171
181
 
172
182
  self._last_logits = logits
173
183
  return features, logits
@@ -179,12 +189,15 @@ class KerasFeatureExtractor(FeatureExtractor):
179
189
  def predict(
180
190
  self,
181
191
  dataset: Union[ItemType, tf.data.Dataset],
192
+ postproc_fns: Optional[List[Callable]] = None,
182
193
  **kwargs,
183
194
  ) -> Tuple[List[tf.Tensor], dict]:
184
195
  """Get the projection of the dataset in the feature space of self.model
185
196
 
186
197
  Args:
187
198
  dataset (Union[ItemType, tf.data.Dataset]): input dataset
199
+ postproc_fns (Optional[Callable]): postprocessing function to apply to each
200
+ feature immediately after forward. Default to None.
188
201
  kwargs (dict): additional arguments not considered for prediction
189
202
 
190
203
  Returns:
@@ -195,7 +208,7 @@ class KerasFeatureExtractor(FeatureExtractor):
195
208
 
196
209
  if isinstance(dataset, get_args(ItemType)):
197
210
  tensor = TFDataHandler.get_input_from_dataset_item(dataset)
198
- features, logits = self.predict_tensor(tensor)
211
+ features, logits = self.predict_tensor(tensor, postproc_fns)
199
212
 
200
213
  # Get labels if dataset is a tuple/list
201
214
  if isinstance(dataset, (list, tuple)):
@@ -207,10 +220,8 @@ class KerasFeatureExtractor(FeatureExtractor):
207
220
  contains_labels = TFDataHandler.get_item_length(dataset) > 1
208
221
  for elem in dataset:
209
222
  tensor = TFDataHandler.get_input_from_dataset_item(elem)
210
- features_batch, logits_batch = self.predict_tensor(tensor)
211
- # concatenate features
212
- if len(features) == 1:
213
- features_batch = [features_batch]
223
+ features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)
224
+
214
225
  for i, f in enumerate(features_batch):
215
226
  features[i] = (
216
227
  f
@@ -234,10 +245,6 @@ class KerasFeatureExtractor(FeatureExtractor):
234
245
 
235
246
  # store extra information in a dict
236
247
  info = dict(labels=labels, logits=logits)
237
-
238
- if len(features) == 1:
239
- features = features[0]
240
-
241
248
  return features, info
242
249
 
243
250
  def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
@@ -181,12 +181,17 @@ class TorchFeatureExtractor(FeatureExtractor):
181
181
 
182
182
  @sanitize_input
183
183
  def predict_tensor(
184
- self, x: TensorType, detach: bool = True
184
+ self,
185
+ x: TensorType,
186
+ postproc_fns: Optional[List[Callable]] = None,
187
+ detach: bool = True,
185
188
  ) -> Tuple[List[torch.Tensor], torch.Tensor]:
186
189
  """Get the projection of tensor in the feature space of self.model
187
190
 
188
191
  Args:
189
192
  x (TensorType): input tensor (or dataset elem)
193
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
194
+ each feature immediately after forward. Default to None.
190
195
  detach (bool): if True, return features detached from the computational
191
196
  graph. Defaults to True.
192
197
 
@@ -206,8 +211,12 @@ class TorchFeatureExtractor(FeatureExtractor):
206
211
 
207
212
  # split features and logits
208
213
  logits = features.pop()
209
- if len(features) == 1:
210
- features = features[0]
214
+
215
+ if postproc_fns is not None:
216
+ features = [
217
+ postproc_fn(feature)
218
+ for feature, postproc_fn in zip(features, postproc_fns)
219
+ ]
211
220
 
212
221
  self._last_logits = logits
213
222
  return features, logits
@@ -215,6 +224,7 @@ class TorchFeatureExtractor(FeatureExtractor):
215
224
  def predict(
216
225
  self,
217
226
  dataset: Union[DataLoader, ItemType],
227
+ postproc_fns: Optional[List[Callable]] = None,
218
228
  detach: bool = True,
219
229
  **kwargs,
220
230
  ) -> Tuple[List[torch.Tensor], dict]:
@@ -222,6 +232,8 @@ class TorchFeatureExtractor(FeatureExtractor):
222
232
 
223
233
  Args:
224
234
  dataset (Union[DataLoader, ItemType]): input dataset
235
+ postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
236
+ each feature immediately after forward. Default to None.
225
237
  detach (bool): if True, return features detached from the computational
226
238
  graph. Defaults to True.
227
239
  kwargs (dict): additional arguments not considered for prediction
@@ -234,7 +246,7 @@ class TorchFeatureExtractor(FeatureExtractor):
234
246
 
235
247
  if isinstance(dataset, get_args(ItemType)):
236
248
  tensor = TorchDataHandler.get_input_from_dataset_item(dataset)
237
- features, logits = self.predict_tensor(tensor, detach=detach)
249
+ features, logits = self.predict_tensor(tensor, postproc_fns, detach=detach)
238
250
 
239
251
  # Get labels if dataset is a tuple/list
240
252
  if isinstance(dataset, (list, tuple)) and len(dataset) > 1:
@@ -248,11 +260,8 @@ class TorchFeatureExtractor(FeatureExtractor):
248
260
  for elem in dataset:
249
261
  tensor = TorchDataHandler.get_input_from_dataset_item(elem)
250
262
  features_batch, logits_batch = self.predict_tensor(
251
- tensor, detach=detach
263
+ tensor, postproc_fns, detach=detach
252
264
  )
253
- # concatenate features
254
- if len(features) == 1:
255
- features_batch = [features_batch]
256
265
  for i, f in enumerate(features_batch):
257
266
  features[i] = (
258
267
  f if features[i] is None else torch.cat([features[i], f], dim=0)
@@ -274,10 +283,6 @@ class TorchFeatureExtractor(FeatureExtractor):
274
283
 
275
284
  # store extra information in a dict
276
285
  info = dict(labels=labels, logits=logits)
277
-
278
- if len(features) == 1:
279
- features = features[0]
280
-
281
286
  return features, info
282
287
 
283
288
  def get_weights(self, layer_id: Union[str, int]) -> List[torch.Tensor]:
@@ -23,9 +23,10 @@
23
23
  from .dknn import DKNN
24
24
  from .energy import Energy
25
25
  from .entropy import Entropy
26
+ from .gram import Gram
26
27
  from .mahalanobis import Mahalanobis
27
28
  from .mls import MLS
28
29
  from .odin import ODIN
29
30
  from .vim import VIM
30
31
 
31
- __all__ = ["MLS", "DKNN", "ODIN", "Energy", "VIM", "Mahalanobis", "Entropy"]
32
+ __all__ = ["MLS", "DKNN", "ODIN", "Energy", "VIM", "Mahalanobis", "Entropy", "Gram"]
@@ -52,11 +52,13 @@ class OODBaseDetector(ABC):
52
52
  self,
53
53
  use_react: bool = False,
54
54
  react_quantile: float = 0.8,
55
+ postproc_fns: List[Callable] = None,
55
56
  ):
56
57
  self.feature_extractor: FeatureExtractor = None
57
58
  self.use_react = use_react
58
59
  self.react_quantile = react_quantile
59
60
  self.react_threshold = None
61
+ self.postproc_fns = self._sanitize_posproc_fns(postproc_fns)
60
62
 
61
63
  @abstractmethod
62
64
  def _score_tensor(self, inputs: TensorType) -> np.ndarray:
@@ -66,18 +68,43 @@ class OODBaseDetector(ABC):
66
68
 
67
69
  Args:
68
70
  inputs (TensorType): tensor to score
69
-
70
71
  Returns:
71
72
  Tuple[TensorType]: OOD scores, predicted logits
72
73
  """
73
74
  raise NotImplementedError()
74
75
 
76
+ def _sanitize_posproc_fns(
77
+ self,
78
+ postproc_fns: Union[List[Callable], None],
79
+ ) -> List[Callable]:
80
+ """Sanitize postproc fns used at each layer output of the feature extractor.
81
+
82
+ Args:
83
+ postproc_fns (Optional[List[Callable]], optional): List of postproc
84
+ functions, one per output layer. Defaults to None.
85
+
86
+ Returns:
87
+ List[Callable]: Sanitized postproc_fns list
88
+ """
89
+ if postproc_fns is not None:
90
+ assert len(postproc_fns) == len(
91
+ self.output_layers_id
92
+ ), "len of postproc_fns and output_layers_id must match"
93
+
94
+ def identity(x):
95
+ return x
96
+
97
+ postproc_fns = [identity if fn is None else fn for fn in postproc_fns]
98
+
99
+ return postproc_fns
100
+
75
101
  def fit(
76
102
  self,
77
103
  model: Callable,
78
104
  fit_dataset: Optional[Union[ItemType, DatasetType]] = None,
79
105
  feature_layers_id: List[Union[int, str]] = [],
80
106
  input_layer_id: Optional[Union[int, str]] = None,
107
+ **kwargs,
81
108
  ) -> None:
82
109
  """Prepare the detector for scoring:
83
110
  * Constructs the feature extractor based on the model
@@ -133,7 +160,7 @@ class OODBaseDetector(ABC):
133
160
  )
134
161
 
135
162
  if fit_dataset is not None:
136
- self._fit_to_dataset(fit_dataset)
163
+ self._fit_to_dataset(fit_dataset, **kwargs)
137
164
 
138
165
  def _load_feature_extractor(
139
166
  self,
@@ -207,7 +234,7 @@ class OODBaseDetector(ABC):
207
234
  # Case 2: dataset is a tf.data.Dataset or a torch.DataLoader
208
235
  elif isinstance(dataset, get_args(DatasetType)):
209
236
  scores = np.array([])
210
- logits = np.empty((0, 2))
237
+ logits = None
211
238
 
212
239
  for item in dataset:
213
240
  tensor = self.data_handler.get_input_from_dataset_item(item)
@@ -226,7 +253,11 @@ class OODBaseDetector(ABC):
226
253
  )
227
254
 
228
255
  scores = np.append(scores, score_batch)
229
- logits = np.concatenate([logits, logits_batch])
256
+ logits = (
257
+ logits_batch
258
+ if logits is None
259
+ else np.concatenate([logits, logits_batch], axis=0)
260
+ )
230
261
 
231
262
  else:
232
263
  raise NotImplementedError(
@@ -239,7 +270,9 @@ class OODBaseDetector(ABC):
239
270
  def compute_react_threshold(self, model: Callable, fit_dataset: DatasetType):
240
271
  penult_feat_extractor = self._load_feature_extractor(model, [-2])
241
272
  unclipped_features, _ = penult_feat_extractor.predict(fit_dataset)
242
- self.react_threshold = self.op.quantile(unclipped_features, self.react_quantile)
273
+ self.react_threshold = self.op.quantile(
274
+ unclipped_features[0], self.react_quantile
275
+ )
243
276
 
244
277
  def __call__(self, inputs: Union[ItemType, DatasetType]) -> np.ndarray:
245
278
  """
@@ -58,7 +58,7 @@ class DKNN(OODBaseDetector):
58
58
  fit_dataset: input dataset (ID) to construct the index with.
59
59
  """
60
60
  fit_projected, _ = self.feature_extractor.predict(fit_dataset)
61
- fit_projected = self.op.convert_to_numpy(fit_projected)
61
+ fit_projected = self.op.convert_to_numpy(fit_projected[0])
62
62
  fit_projected = fit_projected.reshape(fit_projected.shape[0], -1)
63
63
  norm_fit_projected = self._l2_normalization(fit_projected)
64
64
  self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1])
@@ -77,7 +77,7 @@ class DKNN(OODBaseDetector):
77
77
  """
78
78
 
79
79
  input_projected, _ = self.feature_extractor.predict_tensor(inputs)
80
- input_projected = self.op.convert_to_numpy(input_projected)
80
+ input_projected = self.op.convert_to_numpy(input_projected[0])
81
81
  input_projected = input_projected.reshape(input_projected.shape[0], -1)
82
82
  norm_input_projected = self._l2_normalization(input_projected)
83
83
  scores, _ = self.index.search(norm_input_projected, self.nearest)