oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
oodeel/methods/vim.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
from typing import List
|
|
24
|
+
from typing import Optional
|
|
25
|
+
from typing import Union
|
|
26
|
+
|
|
27
|
+
import matplotlib.pyplot as plt
|
|
28
|
+
import numpy as np
|
|
29
|
+
from scipy.special import logsumexp
|
|
30
|
+
|
|
31
|
+
from ..aggregator import BaseAggregator
|
|
32
|
+
from ..types import TensorType
|
|
33
|
+
from .base import FeatureBasedDetector
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class VIM(FeatureBasedDetector):
    """
    Virtual-logit Matching (ViM) out-of-distribution detector.

    Implements the ViM method from https://arxiv.org/abs/2203.10807. For each
    feature layer, the score combines:
    1. Energy term: log-sum-exp over classifier logits.
    2. PCA residual term: norm of the (centered) features' component that lies
       outside a low-dimensional principal subspace, rescaled by a factor
       alpha fitted on the in-distribution data.
    The per-layer score is ``alpha * residual_norm - energy``: it grows with the
    residual norm and decreases with the energy.

    Supports multiple feature layers by computing PCA on each layer's features
    and combining per-layer VIM scores via an optional aggregator.

    Args:
        princ_dims (Union[int, float]):
            - If int: exact number of principal components to consider per
              layer; must be in [1, D-1], where D is the flattened feature
              dimension of the layer.
            - If float: must be in (0, 1]; it is the ratio of explained
              variance used to determine the number of principal components
              per layer.
            Defaults to 0.99.
        pca_origin (str): Method to compute the subspace origin (center).
            - "pseudo": only for the final layer (ID -1); other layers use the
              empirical mean. The weights of the final linear layer are used to
              compute the pseudo-center -W⁺ b, where W⁺ is the pseudo-inverse of
              the weight matrix and b is the bias vector.
            - "center": use the empirical mean of features. Defaults to "center".
        aggregator (Optional[BaseAggregator]): Combines multi-layer VIM scores.
            If None and more than one layer is used, defaults to
            StdNormalizedAggregator.
    """

    def __init__(
        self,
        princ_dims: Union[int, float] = 0.99,
        pca_origin: str = "center",
        aggregator: Optional[BaseAggregator] = None,
        **kwargs,
    ):
        super().__init__(aggregator=aggregator, **kwargs)
        # Store PCA settings (the aggregator is handled by the parent class).
        self._princ_dims = princ_dims
        self.pca_origin = pca_origin
        # Containers for per-layer PCA parameters, appended in fit order.
        self.centers: List[TensorType] = []
        self.residual_projections: List[TensorType] = []
        self.eig_vals_list: List[np.ndarray] = []
        self.princ_dims_list: List[int] = []
        self.alphas: List[float] = []

    # === Per-layer logic ===
    def _fit_layer(
        self,
        layer_id: int,
        layer_features: np.ndarray,
        info: dict,
        **kwargs,
    ) -> None:
        """Compute PCA statistics for a single feature layer.

        The PCA subspace is estimated from the layer activations of the
        in-distribution training data. The ratio between the mean maximum logit
        and the mean norm of the residual component defines the :math:`\\alpha`
        scaling used at inference.

        Args:
            layer_id: Index of the feature layer.
            layer_features: Features for this layer with shape `[N, ...]`;
                they are flattened to `[N, D]`.
            info: Dictionary containing at least the training logits under the
                key "logits".

        Raises:
            AssertionError: If `princ_dims` is out of its valid range.
        """
        logits_train = info["logits"]
        train_maxlogit = np.max(logits_train, axis=-1)

        feat = self.op.flatten(self.op.from_numpy(layer_features))
        N, D = feat.shape

        # NOTE(review): the pseudo-center is only used when layer_id == -1;
        # confirm the base class passes -1 for the final layer.
        if layer_id == -1 and self.pca_origin == "pseudo":
            W, b = self.feature_extractor.get_weights(-1)
            # TensorFlow weights are transposed so that W_mat has the same
            # layout as the torch weights (presumably [num_classes, D]).
            W_mat = (
                self.op.t(self.op.from_numpy(W))
                if self.backend == "tensorflow"
                else self.op.from_numpy(W)
            )
            b_vec = self.op.from_numpy(b.reshape(-1, 1))
            center = -self.op.reshape(self.op.matmul(self.op.pinv(W_mat), b_vec), (-1,))
        else:
            center = self.op.mean(feat, dim=0)

        centered = feat - center
        cov = self.op.matmul(self.op.t(centered), centered) / N
        # Eigenvalues are assumed to be in ascending order (eigh convention):
        # the principal directions are the LAST columns of eig_vecs.
        eig_vals, eig_vecs = self.op.eigh(cov)
        eig_vals_np = self.op.convert_to_numpy(eig_vals)

        if isinstance(self._princ_dims, int):
            assert (
                0 < self._princ_dims < D
            ), f"princ_dims ({self._princ_dims}) must be in 1..{D-1}"
            princ_dim = self._princ_dims
        else:
            assert (
                0 < self._princ_dims <= 1
            ), f"princ_dims ratio ({self._princ_dims}) must be in (0,1]"
            # Cumulative explained variance, largest eigenvalues first.
            explained_variance = np.cumsum(np.flip(eig_vals_np)) / np.sum(eig_vals_np)
            above = np.flatnonzero(explained_variance > self._princ_dims)
            # A ratio of 1.0 (accepted above) may never be strictly exceeded
            # because of floating-point rounding; fall back to the last index
            # so that the residual subspace keeps at least one direction.
            princ_dim = int(above[0]) if above.size else D - 1

        # Residual subspace: the D - princ_dim directions of smallest variance.
        proj = eig_vecs[:, : D - princ_dim]

        residual_norms = self._compute_residual_norms(feat, center, proj)
        alpha = float(np.mean(train_maxlogit) / np.mean(residual_norms))

        self.centers.append(center)
        self.residual_projections.append(proj)
        self.eig_vals_list.append(eig_vals_np)
        self.princ_dims_list.append(princ_dim)
        self.alphas.append(alpha)

    def _score_layer(
        self,
        layer_id: int,
        layer_features: TensorType,
        info: dict,
        fit: bool = False,
        **kwargs,
    ) -> np.ndarray:
        """Compute the VIM score associated with one feature layer.

        Args:
            layer_id: Index of the processed layer.
            layer_features: Features from the current layer.
            info: Dictionary containing the logits of the batch.
            fit: Whether scoring is performed during fitting. Unused here.

        Returns:
            np.ndarray: VIM scores for the layer (`alpha * residual - energy`).
        """
        energy = logsumexp(self.op.convert_to_numpy(info["logits"]), axis=-1)
        flat = self.op.flatten(layer_features)
        # NOTE(review): layer_id indexes the per-layer lists filled by
        # _fit_layer in fit order; confirm the base class uses matching ids.
        resid = self._compute_residual_score_tensor(flat, layer_id)
        return self.alphas[layer_id] * resid - energy

    # === Internal utilities ===
    def _compute_residual_score_tensor(
        self, features: TensorType, layer_idx: int
    ) -> np.ndarray:
        """
        Compute the residual norm of features orthogonal to the principal subspace.

        Args:
            features: Flattened feature matrix [N, D].
            layer_idx: Index of the feature layer.
        Returns:
            Numpy array of residual norms (shape [N]).
        """
        center = self.centers[layer_idx]
        proj = self.residual_projections[layer_idx]
        return self._compute_residual_norms(features, center, proj)

    def _compute_residual_norms(
        self, features: TensorType, center: TensorType, proj: TensorType
    ) -> np.ndarray:
        """Compute residual norms for the provided features.

        Args:
            features: Flattened feature matrix `[N, D]`.
            center: Center of the PCA subspace for the layer.
            proj: Projection matrix onto the residual subspace.

        Returns:
            Numpy array of residual norms of shape `[N]`.
        """
        coords = self.op.matmul(features - center, proj)
        norms = self.op.norm(coords, dim=-1)
        return self.op.convert_to_numpy(norms)

    # === Visualization ===
    def plot_spectrum(self) -> None:
        """
        Visualize residual explained variance per layer vs. principal dimensions being
        excluded.

        If princ_dims is int: x-axis = number of components [0..D-1].
        If princ_dims is float: x-axis = ratio [0..1].

        Draws:
        - Curve: residual explained variance vs. number of principal components.
        - Dashed line: selected princ_dims marker.
        """
        is_ratio = isinstance(self._princ_dims, float)
        # Chosen up front so the axis label is defined even before any fit.
        xlabel = (
            "Ratio of principal components"
            if is_ratio
            else "Number of principal components"
        )

        for idx, eig_vals in enumerate(self.eig_vals_list):
            D = eig_vals.size
            # With ascending eigenvalues, the reversed cumulative sum at index k
            # is the variance left after removing the k largest components.
            residual_cumsum = np.cumsum(eig_vals)[::-1]
            residual_explained = residual_cumsum / residual_cumsum.max()

            if is_ratio:
                x = np.linspace(0, 1, D)
                # Place the marker on the same grid as the curve: x[k] = k/(D-1).
                marker = self.princ_dims_list[idx] / max(D - 1, 1)
            else:
                x = np.arange(D)
                marker = self.princ_dims_list[idx]

            (line,) = plt.plot(x, residual_explained, label=f"layer {idx}")
            plt.axvline(
                x=marker,
                linestyle="--",
                color=line.get_color(),
                label=f"layer {idx} marker",
            )

        plt.xlabel(xlabel)
        plt.ylabel("Residual explained variance")
        plt.legend()
        plt.tight_layout()
        plt.show()

    # === Properties ===
    @property
    def requires_to_fit_dataset(self) -> bool:
        """
        Whether an OOD detector needs a `fit_dataset` argument in the fit function.

        Returns:
            bool: True if `fit_dataset` is required else False.
        """
        return True

    @property
    def requires_internal_features(self) -> bool:
        """
        Whether an OOD detector acts on internal model features.

        Returns:
            bool: True if the detector performs computations on an intermediate
            layer else False.
        """
        return True
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
try:
|
|
24
|
+
from .tf_preprocess import TFRandomPatchPermutation
|
|
25
|
+
except ImportError:
|
|
26
|
+
pass
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from .torch_preprocess import TorchRandomPatchPermutation
|
|
30
|
+
except ImportError:
|
|
31
|
+
pass
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import numpy as np
|
|
24
|
+
import tensorflow as tf
|
|
25
|
+
|
|
26
|
+
from ..types import Optional
|
|
27
|
+
from ..types import Tuple
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TFRandomPatchPermutation:
    def __init__(self, patch_size: Tuple[int] = (8, 8)):
        """Shuffle the non-overlapping patches of an image.

        This transformation comes from the NMD paper, where it is used to build
        artificial OOD samples out of ID images.

        Source (NMD paper):
            "Neural Mean Discrepancy for Efficient Out-of-Distribution Detection"
            [link](https://arxiv.org/pdf/2104.11408.pdf)

        Args:
            patch_size (Tuple[int], optional): Patch dimensions (h, w); ideally
                divisors of the image dimensions (H, W). Defaults to (8, 8).
        """
        self.patch_size = patch_size

    def __call__(self, tensor: tf.Tensor, seed: Optional[int] = None):
        """Shuffle the patches of a single channel-last image.

        When a dimension is not a multiple of the patch size, the leftover
        rows/columns are cropped before shuffling and zeros are padded back at
        the bottom/right, so the output keeps the input shape.

        NOTE(review): the permutation is drawn with NumPy on the Python side;
        if this is traced (tf.function / tf.data.map), it is presumably fixed
        at trace time — confirm if per-sample randomness is required.

        Args:
            tensor (tf.Tensor): Image tensor of shape [H, W, C].
            seed (Optional[int]): Seed for the permutation, if not None.

        Returns:
            tf.Tensor: Image with shuffled patches, shape [H, W, C].
        """
        ph, pw = self.patch_size
        height, width, channels = tensor.shape
        rem_h = height % ph
        rem_w = width % pw

        # Crop the rows / columns that do not fit in a whole patch.
        cropped = tensor
        if rem_h != 0:
            print(
                f"Warning! Patch height ({ph}) should be a divisor of the image height"
                + f" ({height}). Zero padding will be added to get the correct output shape."
            )
            cropped = tensor[:-rem_h]
        if rem_w != 0:
            print(
                f"Warning! Patch width ({pw}) should be a divisor of the image width"
                + f" ({width}). Zero padding will be added to get the correct output shape."
            )
            cropped = cropped[:, :-rem_w]

        n_rows = height // ph
        n_cols = width // pw

        # Cut the image into patches => [n_rows * n_cols, ph * pw, C].
        grid = tf.reshape(cropped, (n_rows, ph, n_cols, pw, channels))
        patches = tf.reshape(
            tf.transpose(grid, (0, 2, 1, 3, 4)), (-1, ph * pw, channels)
        )

        # Deterministic shuffling via NumPy (`tf.stateless_shuffle` is still
        # experimental).
        rng = np.random.default_rng(seed=seed)
        order = np.arange(patches.shape[0])
        rng.shuffle(order)
        shuffled = tf.gather(patches, order)

        # Reassemble the shuffled patches, then pad back to [H, W, C].
        tiles = tf.reshape(shuffled, (n_rows, n_cols, ph, pw, channels))
        folded = tf.reshape(tf.transpose(tiles, (0, 2, 1, 3, 4)), cropped.shape)
        paddings = tf.constant([[0, rem_h], [0, rem_w], [0, 0]])
        return tf.pad(folded, paddings)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import torch
|
|
24
|
+
import torch.nn.functional as F
|
|
25
|
+
|
|
26
|
+
from ..types import Optional
|
|
27
|
+
from ..types import Tuple
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TorchRandomPatchPermutation:
    def __init__(self, patch_size: Tuple[int, int] = (8, 8)):
        """Randomly permute the patches of an image. This transformation is used in NMD
        paper to artificially craft OOD data from ID images.

        Source (NMD paper):
            "Neural Mean Discrepancy for Efficient Out-of-Distribution Detection"
            [link](https://arxiv.org/pdf/2104.11408.pdf)

        Args:
            patch_size (Tuple[int, int], optional): Patch dimensions (h, w), should be
                divisors of the image dimensions (H, W). Defaults to (8, 8).
        """
        self.patch_size = patch_size

    def __call__(self, tensor: torch.Tensor, seed: Optional[int] = None):
        """Apply random patch permutation.

        When a dimension is not a multiple of the patch size, the pixels that do
        not fit in a whole patch are dropped by `unfold` and come back as zeros
        from `fold`, so the output keeps the input shape.

        Args:
            tensor (torch.Tensor): Tensor of shape [C, H, W].
            seed (Optional[int]): Seed number to set for the permutation if not None.

        Returns:
            torch.Tensor: Transformed tensor of shape [C, H, W].
        """
        h, w = self.patch_size
        # Channel-first input: the spatial dims are the last two.
        _, H, W = tensor.shape

        # Seeded generator. It lives on the CPU whatever the tensor's device,
        # because `torch.randperm` below runs on the CPU and requires a
        # generator of the same device type.
        g = None
        if seed is not None:
            g = torch.Generator()
            g.manual_seed(seed)

        # warn if patch dimensions are not divisors of image dimensions
        if H % h != 0:
            print(
                f"Warning! Patch height ({h}) should be a divisor of the image height"
                + f" ({H}). Zero padding will be added to get the correct output shape."
            )
        if W % w != 0:
            print(
                f"Warning! Patch width ({w}) should be a divisor of the image width"
                + f" ({W}). Zero padding will be added to get the correct output shape."
            )

        # === patch permutation ===
        # [C, H, W] => [1, C, H, W]
        x = tensor.unsqueeze(0)
        # divide the image into non-overlapping patches
        # => [1, C * h * w, num_patches]
        u = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size, padding=0)
        # permute the patches (single image, so a single permutation)
        # => [1, C * h * w, num_patches]
        perm = torch.randperm(u.shape[-1], generator=g)
        pu = u[..., perm.to(u.device)]
        # fold the permuted patches back together
        # => [1, C, H, W]
        f = F.fold(
            pu,
            x.shape[-2:],
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )
        return f.squeeze(0)
|
oodeel/types/__init__.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
"""
|
|
24
|
+
Typing module
|
|
25
|
+
"""
|
|
26
|
+
from typing import Any
|
|
27
|
+
from typing import Callable
|
|
28
|
+
from typing import Dict
|
|
29
|
+
from typing import Iterable
|
|
30
|
+
from typing import List
|
|
31
|
+
from typing import Optional
|
|
32
|
+
from typing import Sequence
|
|
33
|
+
from typing import Tuple
|
|
34
|
+
from typing import Type
|
|
35
|
+
from typing import TypeVar
|
|
36
|
+
from typing import Union
|
|
37
|
+
|
|
38
|
+
import numpy as np
|
|
39
|
+
|
|
40
|
+
avail_lib = []
|
|
41
|
+
try:
|
|
42
|
+
import tensorflow as tf
|
|
43
|
+
|
|
44
|
+
avail_lib.append("tensorflow")
|
|
45
|
+
except ImportError:
|
|
46
|
+
pass
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
import torch
|
|
50
|
+
|
|
51
|
+
avail_lib.append("torch")
|
|
52
|
+
except ImportError:
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
if len(avail_lib) == 2:
|
|
57
|
+
DatasetType = Union[tf.data.Dataset, torch.utils.data.DataLoader]
|
|
58
|
+
TensorType = Union[tf.Tensor, torch.Tensor, np.ndarray]
|
|
59
|
+
ItemType = Union[
|
|
60
|
+
tf.Tensor,
|
|
61
|
+
torch.Tensor,
|
|
62
|
+
np.ndarray,
|
|
63
|
+
tuple,
|
|
64
|
+
list,
|
|
65
|
+
dict,
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
elif "tensorflow" in avail_lib:
|
|
69
|
+
DatasetType = Type[tf.data.Dataset]
|
|
70
|
+
TensorType = Union[tf.Tensor, np.ndarray]
|
|
71
|
+
ItemType = Union[tf.Tensor, np.ndarray, tuple, list, dict]
|
|
72
|
+
elif "torch" in avail_lib:
|
|
73
|
+
DatasetType = Type[torch.utils.data.DataLoader]
|
|
74
|
+
TensorType = Union[torch.Tensor, np.ndarray]
|
|
75
|
+
ItemType = Union[torch.Tensor, np.ndarray, tuple, list, dict]
|
oodeel/utils/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
from .general_utils import import_backend_specific_stuff
|
|
24
|
+
from .general_utils import is_from
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
import tensorflow as tf
|
|
28
|
+
from .tf_operator import TFOperator
|
|
29
|
+
from .tf_training_tools import train_tf_model
|
|
30
|
+
except ImportError:
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
import torch
|
|
35
|
+
from .torch_operator import TorchOperator
|
|
36
|
+
from .torch_training_tools import train_torch_model
|
|
37
|
+
except ImportError:
|
|
38
|
+
pass
|