oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
oodeel/methods/vim.py ADDED
@@ -0,0 +1,273 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from typing import List
24
+ from typing import Optional
25
+ from typing import Union
26
+
27
+ import matplotlib.pyplot as plt
28
+ import numpy as np
29
+ from scipy.special import logsumexp
30
+
31
+ from ..aggregator import BaseAggregator
32
+ from ..types import TensorType
33
+ from .base import FeatureBasedDetector
34
+
35
+
36
class VIM(FeatureBasedDetector):
    """
    Virtual-logit Matching (ViM) out-of-distribution detector.

    Implements the ViM method from https://arxiv.org/abs/2203.10807:
    1. Energy-based score: log-sum-exp over classifier logits.
    2. PCA residual score: norm of the feature component lying outside a
       low-dimensional principal subspace, scaled by a per-layer factor alpha.

    The per-layer score is ``alpha * residual_norm - energy``.

    Supports multiple feature layers by computing PCA on each layer's features
    and combining per-layer ViM scores via an optional aggregator.

    Args:
        princ_dims (Union[int, float]):
            - If int: exact number of principal components to consider per
              layer. Must be in ``1..D-1`` where ``D`` is the flattened
              feature size of the layer.
            - If a float, it must be in (0, 1]: it represents the ratio of
              explained variance used to determine the number of principal
              components per layer, i.e. the smallest number of leading
              components whose cumulative explained variance exceeds the
              ratio (clipped to ``1..D-1``). Defaults to 0.99.
        pca_origin (str): Method to compute the subspace origin (center).
            - "pseudo": (Only for the final layer (ID -1), other layers will use
              empirical mean!)
              Weights are used to compute the pseudo-center -pinv(W) @ b, where
              W is the weight matrix of the final linear layer (ID -1) and b is
              the bias vector.
            - "center": use the empirical mean of features. Defaults to "center".
        aggregator (Optional[BaseAggregator]): Combines multi-layer ViM scores.
            If None and more than one layer is used, defaults to
            StdNormalizedAggregator.
    """

    def __init__(
        self,
        princ_dims: Union[int, float] = 0.99,
        pca_origin: str = "center",
        aggregator: Optional[BaseAggregator] = None,
        **kwargs,
    ):
        super().__init__(aggregator=aggregator, **kwargs)
        # Store PCA settings (the aggregator is handled by the base class)
        self._princ_dims = princ_dims
        self.pca_origin = pca_origin
        # Per-layer PCA parameters, one entry appended per fitted layer, in
        # the order in which `_fit_layer` is called.
        self.centers: List[TensorType] = []
        self.residual_projections: List[TensorType] = []
        self.eig_vals_list: List[np.ndarray] = []
        self.princ_dims_list: List[int] = []
        self.alphas: List[float] = []

    # === Per-layer logic ===
    def _fit_layer(
        self,
        layer_id: int,
        layer_features: np.ndarray,
        info: dict,
        **kwargs,
    ) -> None:
        r"""Compute PCA statistics for a single feature layer.

        The PCA subspace is estimated from the layer activations of the
        in-distribution training data. The ratio between the mean maximum
        logit and the mean norm of the residual component defines the
        :math:`\alpha` scaling used at inference.

        Args:
            layer_id: Index of the feature layer. When equal to ``-1`` and
                ``pca_origin == "pseudo"``, the subspace origin is computed
                from the final layer weights instead of the feature mean.
            layer_features: Features for this layer; they are flattened to
                shape `[N, D]`.
            info: Dictionary containing at least the training logits under
                the ``"logits"`` key.
        """

        logits_train = info["logits"]
        train_maxlogit = np.max(logits_train, axis=-1)

        feat = self.op.flatten(self.op.from_numpy(layer_features))
        N, D = feat.shape

        if layer_id == -1 and self.pca_origin == "pseudo":
            W, b = self.feature_extractor.get_weights(-1)
            # TensorFlow weights are transposed before use, presumably because
            # the Keras kernel layout differs from torch's -- TODO confirm.
            W_mat = (
                self.op.t(self.op.from_numpy(W))
                if self.backend == "tensorflow"
                else self.op.from_numpy(W)
            )
            b_vec = self.op.from_numpy(b.reshape(-1, 1))
            # Pseudo-origin o = -pinv(W) @ b
            center = -self.op.reshape(self.op.matmul(self.op.pinv(W_mat), b_vec), (-1,))
        else:
            center = self.op.mean(feat, dim=0)

        centered = feat - center
        cov = self.op.matmul(self.op.t(centered), centered) / N
        eig_vals, eig_vecs = self.op.eigh(cov)
        eig_vals_np = self.op.convert_to_numpy(eig_vals)

        if not self._uses_ratio():
            assert (
                0 < self._princ_dims < D
            ), f"princ_dims ({self._princ_dims}) must be in 1..{D-1}"
            princ_dim = int(self._princ_dims)
        else:
            assert (
                0 < self._princ_dims <= 1
            ), f"princ_dims ratio ({self._princ_dims}) must be in (0,1]"
            # explained_variance[k - 1] is the share of variance explained by
            # the k leading components (eigenvalues are flipped to descending).
            explained_variance = np.cumsum(np.flip(eig_vals_np)) / np.sum(eig_vals_np)
            exceeding = np.flatnonzero(explained_variance > self._princ_dims)
            # First exceeding index i means i + 1 components are needed. If the
            # ratio is never strictly exceeded (e.g. ratio == 1 with rounding),
            # fall back to keeping every component before clipping below.
            n_components = int(exceeding[0]) + 1 if exceeding.size else D
            # Keep at least one principal and, when possible, one residual
            # direction so the residual norm stays meaningful.
            princ_dim = max(1, min(n_components, D - 1))

        # The eigenvalues are assumed ascending (the flip above relies on it),
        # so the first D - princ_dim eigenvector columns span the residual
        # (low-variance) subspace.
        proj = eig_vecs[:, : D - princ_dim]

        residual_norms = self._compute_residual_norms(feat, center, proj)
        alpha = float(np.mean(train_maxlogit) / np.mean(residual_norms))

        self.centers.append(center)
        self.residual_projections.append(proj)
        self.eig_vals_list.append(eig_vals_np)
        self.princ_dims_list.append(princ_dim)
        self.alphas.append(alpha)

    def _score_layer(
        self,
        layer_id: int,
        layer_features: TensorType,
        info: dict,
        fit: bool = False,
        **kwargs,
    ) -> np.ndarray:
        """Compute the ViM score associated with one feature layer.

        Args:
            layer_id: Index of the processed layer. It is used directly as an
                index into the per-layer lists filled by `_fit_layer`
                (assumes it matches the fitting order -- TODO confirm).
            layer_features: Features from the current layer.
            info: Dictionary containing the logits of the batch under the
                ``"logits"`` key.
            fit: Whether scoring is performed during fitting. Unused here.

        Returns:
            np.ndarray: ViM scores for the layer (``alpha * residual - energy``).
        """
        energy = logsumexp(self.op.convert_to_numpy(info["logits"]), axis=-1)
        flat = self.op.flatten(layer_features)
        resid = self._compute_residual_score_tensor(flat, layer_id)
        return self.alphas[layer_id] * resid - energy

    # === Internal utilities ===
    def _uses_ratio(self) -> bool:
        """Return True when ``princ_dims`` is an explained-variance ratio.

        Integer values (including numpy integers) are component counts; any
        other value is treated as a ratio.
        """
        return not isinstance(self._princ_dims, (int, np.integer))

    def _compute_residual_score_tensor(
        self, features: TensorType, layer_idx: int
    ) -> np.ndarray:
        """
        Compute the residual norm of features orthogonal to the principal subspace.

        Args:
            features: Flattened feature matrix [N, D].
            layer_idx: Index of the feature layer.
        Returns:
            Numpy array of residual norms (shape [N]).
        """
        center = self.centers[layer_idx]
        proj = self.residual_projections[layer_idx]
        return self._compute_residual_norms(features, center, proj)

    def _compute_residual_norms(
        self, features: TensorType, center: TensorType, proj: TensorType
    ) -> np.ndarray:
        """Compute residual norms for the provided features.

        Args:
            features: Flattened feature matrix `[N, D]`.
            center: Center of the PCA subspace for the layer.
            proj: Basis of the residual subspace, one vector per column.

        Returns:
            Numpy array of residual norms of shape `[N]`.
        """
        coords = self.op.matmul(features - center, proj)
        norms = self.op.norm(coords, dim=-1)
        return self.op.convert_to_numpy(norms)

    # === Visualization ===
    def plot_spectrum(self) -> None:
        """
        Visualize residual explained variance per layer vs. principal dimensions
        being excluded.

        If princ_dims is an int: x-axis = number of components [0..D-1].
        If princ_dims is a ratio: x-axis = ratio of components k / D.

        Draws:
            - Curve: residual explained variance vs. number of principal components.
            - Dashed line: selected princ_dims marker.

        Raises:
            RuntimeError: If no layer has been fitted yet.
        """
        if not self.eig_vals_list:
            raise RuntimeError(
                "plot_spectrum requires a fitted detector: call fit() first."
            )

        is_ratio = self._uses_ratio()
        xlabel = (
            "Ratio of principal components"
            if is_ratio
            else "Number of principal components"
        )

        for idx, eig_vals in enumerate(self.eig_vals_list):
            D = eig_vals.size
            # eig_vals is ascending: the reversed cumulative sum at position k is
            # the variance remaining once the k leading components are excluded.
            residual_cumsum = np.cumsum(eig_vals)[::-1]
            residual_explained = residual_cumsum / residual_cumsum.max()

            n_excluded = np.arange(D)
            # Curve and marker share the same k / D scale so the marker lands
            # on the curve point of the selected number of components.
            if is_ratio:
                x = n_excluded / D
                marker = self.princ_dims_list[idx] / D
            else:
                x = n_excluded
                marker = self.princ_dims_list[idx]

            (line,) = plt.plot(x, residual_explained, label=f"layer {idx}")
            plt.axvline(
                x=marker,
                linestyle="--",
                color=line.get_color(),
                label=f"layer {idx} marker",
            )

        plt.xlabel(xlabel)
        plt.ylabel("Residual explained variance")
        plt.legend()
        plt.tight_layout()
        plt.show()

    # === Properties ===
    @property
    def requires_to_fit_dataset(self) -> bool:
        """
        Whether an OOD detector needs a `fit_dataset` argument in the fit function.

        Returns:
            bool: True if `fit_dataset` is required else False.
        """
        return True

    @property
    def requires_internal_features(self) -> bool:
        """
        Whether an OOD detector acts on internal model features.

        Returns:
            bool: True if the detector perform computations on an intermediate layer
            else False.
        """
        return True
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ try:
24
+ from .tf_preprocess import TFRandomPatchPermutation
25
+ except ImportError:
26
+ pass
27
+
28
+ try:
29
+ from .torch_preprocess import TorchRandomPatchPermutation
30
+ except ImportError:
31
+ pass
@@ -0,0 +1,95 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ..types import Optional
27
+ from ..types import Tuple
28
+
29
+
30
class TFRandomPatchPermutation:
    def __init__(self, patch_size: Tuple[int] = (8, 8)):
        """Shuffle the non-overlapping patches of an image.

        The NMD paper uses this transformation to craft artificial OOD samples
        from ID images.

        Source (NMD paper):
            "Neural Mean Discrepancy for Efficient Out-of-Distribution Detection"
            [link](https://arxiv.org/pdf/2104.11408.pdf)

        Args:
            patch_size (Tuple[int], optional): Patch dimensions (h, w); they
                should divide the image dimensions (H, W). Defaults to (8, 8).
        """
        self.patch_size = patch_size

    def __call__(self, tensor: tf.Tensor, seed: Optional[int] = None):
        """Apply the random patch permutation.

        Args:
            tensor (tf.Tensor): Image tensor of shape [H, W, C].
            seed (Optional[int]): Seed of the permutation, if not None.

        Returns:
            tf.Tensor: Tensor of shape [H, W, C] whose patches are shuffled.
                Trailing rows/columns that do not fill a whole patch are
                cropped before shuffling and restored as zeros.
        """
        patch_h, patch_w = self.patch_size
        height, width, channels = tensor.shape
        extra_rows = height % patch_h
        extra_cols = width % patch_w
        n_rows = height // patch_h
        n_cols = width // patch_w

        # Warn, then crop the border that does not fit in a whole patch.
        cropped = tensor
        if extra_rows != 0:
            print(
                f"Warning! Patch height ({patch_h}) should be a divisor of the image"
                f" height ({height}). Zero padding will be added to get the correct"
                " output shape."
            )
            cropped = tensor[:-extra_rows]
        if extra_cols != 0:
            print(
                f"Warning! Patch width ({patch_w}) should be a divisor of the image"
                f" width ({width}). Zero padding will be added to get the correct"
                " output shape."
            )
            cropped = cropped[:, :-extra_cols]

        # Cut the image into patches: [n_rows * n_cols, patch_h * patch_w, C]
        grid = tf.reshape(cropped, (n_rows, patch_h, n_cols, patch_w, channels))
        patches = tf.reshape(
            tf.transpose(grid, (0, 2, 1, 3, 4)), (-1, patch_h * patch_w, channels)
        )

        # Shuffle the patch order. A numpy generator gives deterministic
        # shuffling from a seed (`tf.stateless_shuffle` is still experimental).
        order = np.arange(patches.shape[0])
        np.random.default_rng(seed=seed).shuffle(order)
        shuffled = tf.gather(patches, order)

        # Reassemble the shuffled patches into the cropped image layout.
        regrid = tf.reshape(shuffled, (n_rows, n_cols, patch_h, patch_w, channels))
        reassembled = tf.reshape(tf.transpose(regrid, (0, 2, 1, 3, 4)), cropped.shape)

        # Pad the cropped border back with zeros to recover [H, W, C].
        padding = tf.constant([[0, extra_rows], [0, extra_cols], [0, 0]])
        return tf.pad(reassembled, padding)
@@ -0,0 +1,97 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from ..types import Optional
27
+ from ..types import Tuple
28
+
29
+
30
class TorchRandomPatchPermutation:
    def __init__(self, patch_size: Tuple[int] = (8, 8)):
        """Randomly permute the patches of an image. This transformation is used in NMD
        paper to artificially craft OOD data from ID images.

        Source (NMD paper):
            "Neural Mean Discrepancy for Efficient Out-of-Distribution Detection"
            [link](https://arxiv.org/pdf/2104.11408.pdf)

        Args:
            patch_size (Tuple[int], optional): Patch dimensions (h, w), should be
                divisors of the image dimensions (H, W). Defaults to (8, 8).
        """
        self.patch_size = patch_size

    def __call__(self, tensor: torch.Tensor, seed: Optional[int] = None):
        """Apply random patch permutation.

        Args:
            tensor (torch.Tensor): Tensor of shape [C, H, W].
            seed (Optional[int]): Seed number to set for the permutation if not None.

        Returns:
            torch.Tensor: Transformed tensor of shape [C, H, W]. When a patch
                dimension does not divide the matching image dimension, the
                trailing rows/columns not covered by a whole patch are zero.
        """
        h, w = self.patch_size
        # Channels-first layout [C, H, W]: the spatial dims are the last two.
        _, H, W = tensor.shape

        # set generator if seed is not None
        g = None
        if seed is not None:
            g = torch.Generator(device=tensor.device)
            g.manual_seed(seed)

        # raise warning if patch dimensions are not divisors of image dimensions
        if H % h != 0:
            print(
                f"Warning! Patch height ({h}) should be a divisor of the image height"
                + f" ({H}). Zero padding will be added to get the correct output shape."
            )
        if W % w != 0:
            print(
                f"Warning! Patch width ({w}) should be a divisor of the image width"
                + f" ({W}). Zero padding will be added to get the correct output shape."
            )

        # === patch permutation ===
        # [C, H, W] => [1, C, H, W]
        x = tensor.unsqueeze(0)
        # extract non-overlapping patches (stride == kernel size)
        # => [1, C * h * w, num_patches]
        u = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size, padding=0)
        # permute the patches (last dim) of each image in the batch. The
        # permutation is sampled on the tensor's device so that it matches the
        # device of the seeded generator created above.
        # => [1, C * h * w, num_patches]
        pu = torch.cat(
            [
                b_[:, torch.randperm(b_.shape[-1], generator=g, device=b_.device)][
                    None, ...
                ]
                for b_ in u
            ],
            dim=0,
        )
        # fold the permuted patches back together; output positions covered by
        # no patch (non-divisible borders) receive zeros.
        # => [1, C, H, W]
        f = F.fold(
            pu,
            x.shape[-2:],
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )
        return f.squeeze(0)
@@ -0,0 +1,75 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ """
24
+ Typing module
25
+ """
26
+ from typing import Any
27
+ from typing import Callable
28
+ from typing import Dict
29
+ from typing import Iterable
30
+ from typing import List
31
+ from typing import Optional
32
+ from typing import Sequence
33
+ from typing import Tuple
34
+ from typing import Type
35
+ from typing import TypeVar
36
+ from typing import Union
37
+
38
+ import numpy as np
39
+
40
+ avail_lib = []
41
+ try:
42
+ import tensorflow as tf
43
+
44
+ avail_lib.append("tensorflow")
45
+ except ImportError:
46
+ pass
47
+
48
+ try:
49
+ import torch
50
+
51
+ avail_lib.append("torch")
52
+ except ImportError:
53
+ pass
54
+
55
+
56
+ if len(avail_lib) == 2:
57
+ DatasetType = Union[tf.data.Dataset, torch.utils.data.DataLoader]
58
+ TensorType = Union[tf.Tensor, torch.Tensor, np.ndarray]
59
+ ItemType = Union[
60
+ tf.Tensor,
61
+ torch.Tensor,
62
+ np.ndarray,
63
+ tuple,
64
+ list,
65
+ dict,
66
+ ]
67
+
68
+ elif "tensorflow" in avail_lib:
69
+ DatasetType = Type[tf.data.Dataset]
70
+ TensorType = Union[tf.Tensor, np.ndarray]
71
+ ItemType = Union[tf.Tensor, np.ndarray, tuple, list, dict]
72
+ elif "torch" in avail_lib:
73
+ DatasetType = Type[torch.utils.data.DataLoader]
74
+ TensorType = Union[torch.Tensor, np.ndarray]
75
+ ItemType = Union[torch.Tensor, np.ndarray, tuple, list, dict]
@@ -0,0 +1,38 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from .general_utils import import_backend_specific_stuff
24
+ from .general_utils import is_from
25
+
26
+ try:
27
+ import tensorflow as tf
28
+ from .tf_operator import TFOperator
29
+ from .tf_training_tools import train_tf_model
30
+ except ImportError:
31
+ pass
32
+
33
+ try:
34
+ import torch
35
+ from .torch_operator import TorchOperator
36
+ from .torch_training_tools import train_torch_model
37
+ except ImportError:
38
+ pass