oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/methods/gram.py ADDED
@@ -0,0 +1,307 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+
+ from ..types import DatasetType
+ from ..types import List
+ from ..types import TensorType
+ from ..types import Union
+ from .base import OODBaseDetector
+
+
+ class Gram(OODBaseDetector):
+     r"""
+     "Detecting Out-of-Distribution Examples with Gram Matrices"
+     [link](https://proceedings.mlr.press/v119/sastry20a.html)
+
+     **Important Disclaimer**: Taking the min/max statistics of the deviation,
+     as in the paper, raises some problems.
+
+     The method often yields a score of zero for some tasks.
+     This is expected, since the min/max among the samples of a random
+     variable becomes more and more extreme with the sample size. As a result,
+     computing the min/max over the training set is likely to produce values so
+     extreme that none of the in-distribution correlations of the validation set
+     go beyond these thresholds. Worse, a significant part of OOD data does not
+     exceed the thresholds either. This can be alleviated by computing the
+     min/max over a limited number of samples. However, that is
+     counter-intuitive and, in our opinion, not desirable: adding more
+     information should only improve a method.
+
+     Hence, we decided to replace the min/max with the q / 1-q quantiles, with q
+     a new parameter of the method. Specifically, instead of the deviation as
+     defined in eq. 3 of the paper, we use the definition
+     $$
+     \delta(t_q, t_{1-q}, value) =
+     \begin{cases}
+         0 & \text{if } t_q \leq value \leq t_{1-q} \\
+         \frac{t_q - value}{|t_q|} & \text{if } value < t_q \\
+         \frac{value - t_{1-q}}{|t_{1-q}|} & \text{if } value > t_{1-q}
+     \end{cases}
+     $$
+     With this new deviation, the more points we add, the more accurate the
+     quantiles become. In addition, the method can be made more or less
+     discriminative by toggling the value of q.
+
+     Finally, we found that this approach improved the performance of the
+     baseline in our experiments.
+
+     Args:
+         orders (List[int]): power orders to consider for the correlation matrix.
+         quantile (float): quantile to consider for the correlations to build the
+             deviation threshold.
+     """
+
+     def __init__(
+         self,
+         orders: List[int] = [i for i in range(1, 6)],
+         quantile: float = 0.01,
+     ):
+         super().__init__()
+         if isinstance(orders, int):
+             orders = [orders]
+         self.orders = orders
+         self.postproc_fns = None
+         self.quantile = quantile
+
+     def _fit_to_dataset(
+         self,
+         fit_dataset: Union[TensorType, DatasetType],
+         val_split: float = 0.2,
+         verbose: bool = False,
+     ) -> None:
+         """
+         Compute the quantiles of channelwise correlations for each layer, power
+         of gram matrices, and class. Then, compute the normalization constants
+         for the deviation. To stay faithful to the spirit of the original
+         method, we still name the quantiles min/max.
+
+         Args:
+             fit_dataset (Union[TensorType, DatasetType]): input dataset (ID) to
+                 construct the index with.
+             val_split (float): The percentage of fit data to use as validation
+                 data for normalization. Defaults to 0.2.
+             verbose (bool): Whether to print information during the fitting
+                 process. Defaults to False.
+         """
+         self.postproc_fns = [
+             self._stat for i in range(len(self.feature_extractor.feature_layers_id))
+         ]
+
+         # fit_stats shape: [n_features, n_samples, n_orders, n_channels]
+         fit_stats, info = self.feature_extractor.predict(
+             fit_dataset,
+             postproc_fns=self.postproc_fns,
+             return_labels=True,
+             verbose=verbose,
+         )
+         labels = info["labels"]
+         self._classes = np.sort(np.unique(self.op.convert_to_numpy(labels)))
+
+         full_indices = np.arange(labels.shape[0])
+         train_indices, val_indices = train_test_split(
+             full_indices, test_size=val_split
+         )
+         train_indices = self.op.from_numpy(
+             [bool(ind in train_indices) for ind in full_indices]
+         )
+         val_indices = self.op.from_numpy(
+             [bool(ind in val_indices) for ind in full_indices]
+         )
+
+         val_stats = [fit_stat[val_indices] for fit_stat in fit_stats]
+         fit_stats = [fit_stat[train_indices] for fit_stat in fit_stats]
+         labels = labels[train_indices]
+
+         self.min_maxs = dict()
+         for cls in self._classes:
+             indexes = self.op.equal(labels, cls)
+             min_maxs = []
+             for fit_stat in fit_stats:
+                 fit_stat = fit_stat[indexes]
+                 mins = self.op.unsqueeze(
+                     self.op.quantile(fit_stat, self.quantile, dim=0), -1
+                 )
+                 maxs = self.op.unsqueeze(
+                     self.op.quantile(fit_stat, 1 - self.quantile, dim=0), -1
+                 )
+                 min_max = self.op.cat([mins, maxs], dim=-1)
+                 min_maxs.append(min_max)
+
+             self.min_maxs[cls] = min_maxs
+
+         devnorm = []
+         for cls in self._classes:
+             min_maxs = []
+             for min_max in self.min_maxs[cls]:
+                 min_maxs.append(
+                     self.op.stack([min_max for i in range(val_stats[0].shape[0])])
+                 )
+             devnorm.append(
+                 [
+                     float(self.op.mean(dev))
+                     for dev in self._deviation(val_stats, min_maxs)
+                 ]
+             )
+         self.devnorm = np.mean(np.array(devnorm), axis=0)
+
+     def _score_tensor(self, inputs: TensorType) -> np.ndarray:
+         """
+         Computes an OOD score for input samples "inputs" based on the
+         aggregation of deviations from quantiles of in-distribution
+         channel-wise correlations, evaluated for each layer, power of gram
+         matrices, and class.
+
+         Args:
+             inputs: input samples to score
+
+         Returns:
+             scores
+         """
+         tensor_stats, _ = self.feature_extractor.predict_tensor(
+             inputs, postproc_fns=self.postproc_fns
+         )
+
+         _, logits = self.feature_extractor.predict_tensor(inputs)
+         preds = self.op.convert_to_numpy(self.op.argmax(logits, dim=1))
+
+         # We stack the min_maxs for each class depending on the prediction for
+         # each sample
+         min_maxs = []
+         for i in range(len(tensor_stats)):
+             min_maxs.append(
+                 self.op.stack([self.min_maxs[label][i] for label in preds])
+             )
+
+         tensor_dev = self._deviation(tensor_stats, min_maxs)
+         score = self.op.mean(
+             self.op.cat(
+                 [
+                     self.op.unsqueeze(tensor_dev_l, dim=0) / devnorm_l
+                     for tensor_dev_l, devnorm_l in zip(tensor_dev, self.devnorm)
+                 ]
+             ),
+             dim=0,
+         )
+         return self.op.convert_to_numpy(score)
+
+     def _deviation(
+         self, stats: List[TensorType], min_maxs: List[TensorType]
+     ) -> List[TensorType]:
+         """Compute the deviation wrt quantiles (min/max) for feature maps.
+
+         Args:
+             stats (List[TensorType]): The list of gram matrices (stacked
+                 power-wise) for which we want to compute the deviation.
+             min_maxs (List[TensorType]): The quantiles (tensorised) to compute
+                 the deviation against.
+
+         Returns:
+             List[TensorType]: A list with one element per layer containing a
+                 tensor of per-sample deviations.
+         """
+         deviation = []
+         for stat, min_max in zip(stats, min_maxs):
+             where_min = self.op.where(stat < min_max[..., 0], 1.0, 0.0)
+             where_max = self.op.where(stat > min_max[..., 1], 1.0, 0.0)
+             deviation_min = (
+                 (min_max[..., 0] - stat)
+                 / (self.op.abs(min_max[..., 0]) + 1e-6)
+                 * where_min
+             )
+             deviation_max = (
+                 (stat - min_max[..., 1])
+                 / (self.op.abs(min_max[..., 1]) + 1e-6)
+                 * where_max
+             )
+             deviation.append(self.op.sum(deviation_min + deviation_max, dim=(1, 2)))
+         return deviation
+
+     def _stat(self, feature_map: TensorType) -> TensorType:
+         """Compute the correlation map (stat) for a given feature map. The
+         values for each power of gram matrix are contained in the same tensor.
+
+         Args:
+             feature_map (TensorType): The input feature map.
+
+         Returns:
+             TensorType: The stacked gram matrices, power-wise.
+         """
+         fm_s = feature_map.shape
+         stat = []
+         for p in self.orders:
+             feature_map_p = feature_map**p
+             # construct the Gram matrix
+             if len(fm_s) == 2:
+                 # build gram matrix for feature map of shape [dim_dense_layer, 1]
+                 feature_map_p = self.op.einsum(
+                     "bi,bj->bij", feature_map_p, feature_map_p
+                 )
+             elif len(fm_s) >= 3:
+                 # flatten the feature map
+                 if self.backend == "tensorflow":
+                     feature_map_p = self.op.reshape(
+                         self.op.einsum("i...j->ij...", feature_map_p),
+                         (fm_s[0], fm_s[-1], -1),
+                     )
+                 else:
+                     # batch, channel, spatial
+                     feature_map_p = self.op.reshape(
+                         feature_map_p, (fm_s[0], fm_s[1], -1)
+                     )
+                 # batch, channel, channel
+                 feature_map_p = self.op.matmul(
+                     feature_map_p, self.op.permute(feature_map_p, (0, 2, 1))
+                 )
+             # normalize the Gram matrix
+             feature_map_p = self.op.sign(feature_map_p) * (
+                 self.op.abs(feature_map_p) ** (1 / p)
+             )
+             # get the lower triangular part of the matrix
+             feature_map_p = self.op.tril(feature_map_p)
+             # directly sum row-wise (to limit computational burden) -> batch, channel
+             feature_map_p = self.op.sum(feature_map_p, dim=2)
+             stat.append(feature_map_p)
+         # batch, n_orders, channel
+         stat = self.op.stack(stat, 1)
+         return stat
+
+     @property
+     def requires_to_fit_dataset(self) -> bool:
+         """
+         Whether an OOD detector needs a `fit_dataset` argument in the fit
+         function.
+
+         Returns:
+             bool: True if `fit_dataset` is required, else False.
+         """
+         return True
+
+     @property
+     def requires_internal_features(self) -> bool:
+         """
+         Whether an OOD detector acts on internal model features.
+
+         Returns:
+             bool: True if the detector performs computations on an intermediate
+                 layer, else False.
+         """
+         return True
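
The quantile-based deviation that `gram.py` defines above is easy to sanity-check outside the diff. Below is a minimal, framework-free numpy sketch of $\delta(t_q, t_{1-q}, value)$; the helper name `quantile_deviation` and its arguments are illustrative only, not part of the oodeel API, which routes the same arithmetic through its backend-agnostic `self.op` operators.

```python
import numpy as np

def quantile_deviation(values, t_q, t_1mq, eps=1e-6):
    """Zero inside [t_q, t_1mq]; relative overshoot, normalized by the
    violated bound, outside of it."""
    below = np.where(values < t_q, (t_q - values) / (np.abs(t_q) + eps), 0.0)
    above = np.where(values > t_1mq, (values - t_1mq) / (np.abs(t_1mq) + eps), 0.0)
    return below + above

# toy check: fit quantiles on N(0, 1) statistics, then score one extreme value
rng = np.random.default_rng(0)
train_stats = rng.normal(size=(10_000, 4))
t_q = np.quantile(train_stats, 0.01, axis=0)
t_1mq = np.quantile(train_stats, 0.99, axis=0)
print(quantile_deviation(np.array([[0.0, 0.1, -0.1, 5.0]]), t_q, t_1mq))
# only the last column deviates; adding more training samples sharpens the
# quantiles instead of pushing them outward, which is the point of the variant
```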
oodeel/methods/mahalanobis.py CHANGED
@@ -55,7 +55,7 @@ class Mahalanobis(OODBaseDetector):
              fit_dataset (Union[TensorType, DatasetType]): input dataset (ID)
          """
          # extract features and labels
-         features, infos = self.feature_extractor.predict(fit_dataset)
+         features, infos = self.feature_extractor.predict(fit_dataset, detach=True)
          labels = infos["labels"]
 
          # unique sorted classes
@@ -63,22 +63,24 @@ class Mahalanobis(OODBaseDetector):
 
          # compute mus and covs
          mus = dict()
-         covs = dict()
+         mean_cov = None
          for cls in self._classes:
              indexes = self.op.equal(labels, cls)
-             _features_cls = self.op.flatten(features[indexes])
+             _features_cls = self.op.flatten(features[0][indexes])
              mus[cls] = self.op.mean(_features_cls, dim=0)
              _zero_f_cls = _features_cls - mus[cls]
-             covs[cls] = (
-                 self.op.matmul(self.op.transpose(_zero_f_cls), _zero_f_cls)
+             cov_cls = (
+                 self.op.matmul(self.op.t(_zero_f_cls), _zero_f_cls)
                  / _zero_f_cls.shape[0]
              )
+             if mean_cov is None:
+                 mean_cov = (len(_features_cls) / len(features[0])) * cov_cls
+             else:
+                 mean_cov += (len(_features_cls) / len(features[0])) * cov_cls
 
-         # mean cov and its inverse
-         mean_cov = self.op.mean(self.op.stack(list(covs.values())), dim=0)
-
-         self._mus = mus
+         # pseudo-inverse of the mean covariance matrix
          self._pinv_cov = self.op.pinv(mean_cov)
+         self._mus = mus
 
      def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
          """
@@ -99,7 +101,7 @@ class Mahalanobis(OODBaseDetector):
 
          # mahalanobis score on perturbed inputs
          features_p, _ = self.feature_extractor.predict_tensor(inputs_p)
-         features_p = self.op.flatten(features_p)
+         features_p = self.op.flatten(features_p[0])
          gaussian_score_p = self._mahalanobis_score(features_p)
 
          # take the highest score for each sample
@@ -132,7 +134,7 @@ class Mahalanobis(OODBaseDetector):
          """
          # extract features
          out_features, _ = self.feature_extractor.predict(inputs, detach=False)
-         out_features = self.op.flatten(out_features)
+         out_features = self.op.flatten(out_features[0])
          # get mahalanobis score for the class maximizing it
          gaussian_score = self._mahalanobis_score(out_features)
          log_probs_f = self.op.max(gaussian_score, dim=1)
@@ -167,7 +169,7 @@ class Mahalanobis(OODBaseDetector):
          # gaussian log prob density (mahalanobis)
          log_probs_f = -0.5 * self.op.diag(
              self.op.matmul(
-                 self.op.matmul(zero_f, self._pinv_cov), self.op.transpose(zero_f)
+                 self.op.matmul(zero_f, self._pinv_cov), self.op.t(zero_f)
              )
          )
          gaussian_scores.append(self.op.reshape(log_probs_f, (-1, 1)))
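
The `_fit_to_dataset` hunk above does more than rename variables: the unweighted mean over `covs.values()` becomes a covariance pooled by class frequency, accumulated class by class so the per-class covariance matrices never need to be stacked in memory at once. A sketch of the same computation in plain numpy (the `pooled_covariance` helper is hypothetical, for illustration only):

```python
import numpy as np

def pooled_covariance(features: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Class-frequency-weighted mean of per-class covariances."""
    n, d = features.shape
    mean_cov = np.zeros((d, d))
    for cls in np.unique(labels):
        f_cls = features[labels == cls]
        centered = f_cls - f_cls.mean(axis=0)
        cov_cls = centered.T @ centered / len(f_cls)  # per-class covariance
        mean_cov += (len(f_cls) / n) * cov_cls  # weight by class share
    return mean_cov
```

With balanced classes this coincides with the old unweighted mean; with imbalanced classes, small classes no longer contribute as much to the shared covariance as large ones.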
oodeel/methods/mls.py CHANGED
@@ -53,11 +53,19 @@ class MLS(OODBaseDetector):
          self,
          output_activation: str = "linear",
          use_react: bool = False,
+         use_scale: bool = False,
+         use_ash: bool = False,
          react_quantile: float = 0.8,
+         scale_percentile: float = 0.85,
+         ash_percentile: float = 0.90,
      ):
          super().__init__(
              use_react=use_react,
+             use_scale=use_scale,
+             use_ash=use_ash,
              react_quantile=react_quantile,
+             scale_percentile=scale_percentile,
+             ash_percentile=ash_percentile,
          )
          self.output_activation = output_activation
oodeel/methods/odin.py CHANGED
@@ -48,12 +48,20 @@ class ODIN(OODBaseDetector):
          temperature: float = 1000,
          noise: float = 0.014,
          use_react: bool = False,
+         use_scale: bool = False,
+         use_ash: bool = False,
          react_quantile: float = 0.8,
+         scale_percentile: float = 0.85,
+         ash_percentile: float = 0.90,
      ):
          self.temperature = temperature
          super().__init__(
              use_react=use_react,
+             use_scale=use_scale,
+             use_ash=use_ash,
              react_quantile=react_quantile,
+             scale_percentile=scale_percentile,
+             ash_percentile=ash_percentile,
          )
          self.noise = noise
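
The MLS and ODIN hunks are twins: both detectors gain the same opt-in activation-shaping switches (SCALE and ASH, next to the existing ReAct), each with its own percentile knob. A hedged usage sketch, assuming the usual oodeel detector workflow of constructing a detector, fitting it on a trained model, then scoring (`model` and `test_dataset` are placeholder names):

```python
from oodeel.methods import MLS, ODIN

# opt into one activation-shaping variant and tune its percentile
mls_scale = MLS(use_scale=True, scale_percentile=0.85)
odin_ash = ODIN(temperature=1000, noise=0.014, use_ash=True, ash_percentile=0.90)

# sketch of the downstream workflow on a trained ID classifier
# mls_scale.fit(model)
# scores = mls_scale.score(test_dataset)
```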
oodeel/methods/rmds.py ADDED
@@ -0,0 +1,122 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import numpy as np
+
+ from ..types import DatasetType
+ from ..types import TensorType
+ from ..types import Tuple
+ from oodeel.methods.mahalanobis import Mahalanobis
+
+
+ class RMDS(Mahalanobis):
+     """
+     "A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection"
+     https://arxiv.org/abs/2106.09022
+
+     Args:
+         eps (float): magnitude for gradient-based input perturbation.
+             Defaults to 0.002.
+     """
+
+     def __init__(self, eps: float = 0.002):
+         super().__init__(eps=eps)
+
+     def _fit_to_dataset(self, fit_dataset: DatasetType) -> None:
+         """
+         Constructs the per-class means and the covariance matrix,
+         as well as the background mean and covariance matrix,
+         from ID data "fit_dataset".
+         The means and pseudo-inverses of the covariance matrices
+         will be used for RMDS score computation.
+
+         Args:
+             fit_dataset (Union[TensorType, DatasetType]): input dataset (ID)
+         """
+         # means and pseudo-inverse of the mean covariance matrix from Mahalanobis
+         super()._fit_to_dataset(fit_dataset)
+
+         # extract features
+         features, _ = self.feature_extractor.predict(fit_dataset)
+
+         # compute background mu and cov
+         _features_bg = self.op.flatten(features[0])
+         mu_bg = self.op.mean(_features_bg, dim=0)
+         _zero_f_bg = _features_bg - mu_bg
+         cov_bg = (
+             self.op.matmul(self.op.t(_zero_f_bg), _zero_f_bg) / _zero_f_bg.shape[0]
+         )
+
+         # background mu and pseudo-inverse of the background covariance matrix
+         self._mu_bg = mu_bg
+         self._pinv_cov_bg = self.op.pinv(cov_bg)
+
+     def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
+         """
+         Computes an OOD score for input samples "inputs" based on the RMDS
+         distance with respect to the closest class-conditional Gaussian
+         distribution, and the background distribution.
+
+         Args:
+             inputs (TensorType): input samples
+
+         Returns:
+             Tuple[np.ndarray]: scores, logits
+         """
+         # input preprocessing (perturbation)
+         if self.eps > 0:
+             inputs_p = self._input_perturbation(inputs)
+         else:
+             inputs_p = inputs
+
+         # mahalanobis score on perturbed inputs
+         features_p, _ = self.feature_extractor.predict_tensor(inputs_p)
+         features_p = self.op.flatten(features_p[0])
+         gaussian_score_p = self._mahalanobis_score(features_p)
+
+         # background score on perturbed inputs
+         gaussian_score_bg = self._background_score(features_p)
+
+         # take the highest score for each sample
+         gaussian_score_corrected = self.op.max(
+             gaussian_score_p - gaussian_score_bg, dim=1
+         )
+         return -self.op.convert_to_numpy(gaussian_score_corrected)
+
+     def _background_score(self, out_features: TensorType) -> TensorType:
+         """
+         Mahalanobis distance-based background score. For each test sample, it
+         computes the log-probability density of its features (assuming a
+         normal distribution) using the Mahalanobis distance with respect to
+         the background distribution.
+
+         Args:
+             out_features (TensorType): test samples features
+
+         Returns:
+             TensorType: confidence scores (with respect to the background
+                 distribution)
+         """
+         zero_f = out_features - self._mu_bg
+         # gaussian log prob density (mahalanobis)
+         log_probs_f = -0.5 * self.op.diag(
+             self.op.matmul(
+                 self.op.matmul(zero_f, self._pinv_cov_bg), self.op.t(zero_f)
+             )
+         )
+         gaussian_score = self.op.reshape(log_probs_f, (-1, 1))
+         return gaussian_score
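
In short, `rmds.py` keeps the class-conditional Mahalanobis log-density from `mahalanobis.py` but subtracts a class-agnostic background log-density, so feature directions that look extreme under every class cancel out instead of inflating the score. A minimal numpy sketch of that correction (hypothetical `rmds_scores` helper over plain arrays, not the `self.op` operator API):

```python
import numpy as np

def rmds_scores(feats, class_mus, pinv_cov, mu_bg, pinv_cov_bg):
    """Background-corrected Mahalanobis scores; higher means more OOD."""
    def log_density(f, mu, pinv):
        z = f - mu
        # -0.5 * diag(z @ pinv @ z.T) without forming the full n x n matrix
        return -0.5 * np.einsum("ij,jk,ik->i", z, pinv, z)

    bg = log_density(feats, mu_bg, pinv_cov_bg)
    per_class = np.stack(
        [log_density(feats, mu, pinv_cov) for mu in class_mus], axis=1
    )
    # max over classes of (class log-density - background log-density), negated
    return -(per_class - bg[:, None]).max(axis=1)
```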