oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from ...types import Callable
|
|
26
|
+
from ...types import DatasetType
|
|
27
|
+
from ...types import Optional
|
|
28
|
+
from ...types import Tuple
|
|
29
|
+
from ...types import Union
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OODDataset(object):
    """Class for managing loading and processing of datasets that are to be used for
    OOD detection. The class encapsulates a dataset like object augmented with OOD
    related information, and then returns a dataset like object that is suited for
    scoring or training with the .prepare method.

    Args:
        dataset_id (Union[DatasetType, tuple, dict, str]): The dataset to load.
            Can be loaded from tensorflow or torch datasets catalog when the str
            matches one of the datasets.
        backend (str, optional): Whether the dataset is to be used for tensorflow
            or torch models. Defaults to "tensorflow". Alternative: "torch".
        keys (list, optional): keys to use for dataset elems. Default to None
        load_kwargs (dict, optional): Additional loading kwargs when loading from
            tensorflow_datasets catalog. The dict is copied and never modified in
            place. Defaults to None (no additional kwargs).
        load_from_tensorflow_datasets (bool, optional): In the case where if the
            backend is torch but the user still wants to import from
            tensorflow_datasets catalog. In that case, tf.Tensor will not be loaded
            in VRAM and converted as torch.Tensors on the fly. Defaults to False.
        input_key (str, optional): The key of the element/item to consider as the
            model input tensor. If None, taken as the first key. Defaults to None.
    """

    def __init__(
        self,
        dataset_id: Union[DatasetType, tuple, dict, str],
        backend: str = "tensorflow",
        keys: Optional[list] = None,
        load_kwargs: Optional[dict] = None,
        load_from_tensorflow_datasets: bool = False,
        input_key: Optional[str] = None,
    ):
        self.backend = backend
        self.load_from_tensorflow_datasets = load_from_tensorflow_datasets

        # The length of the dataset is kept as attribute to avoid redundant
        # iterations over self.data
        self.length = None

        # Build the loading parameters on a private copy: a shared default dict
        # (or the caller's own dict) must never be mutated, otherwise
        # "as_supervised" would leak from one instance to the next.
        load_kwargs = self._resolve_load_kwargs(
            backend, load_from_tensorflow_datasets, load_kwargs
        )

        # Select the data handler and the channel order depending on the backend.
        # Handlers are imported lazily so that only the needed framework is loaded.
        if self.backend == "torch":
            if load_from_tensorflow_datasets:
                from .DEPRECATED_tf_data_handler import TFDataHandler
                import tensorflow as tf

                # Hide the GPUs from tensorflow: it is only used for data loading
                tf.config.set_visible_devices([], "GPU")
                self._data_handler = TFDataHandler()
            else:
                from .DEPRECATED_torch_data_handler import TorchDataHandler

                self._data_handler = TorchDataHandler()
            self.channel_order = "channels_first"
        else:
            from .DEPRECATED_tf_data_handler import TFDataHandler

            self._data_handler = TFDataHandler()
            self.channel_order = "channels_last"

        self.load_params = load_kwargs
        # Load the dataset depending on the type of dataset_id
        self.data = self._data_handler.load_dataset(dataset_id, keys, load_kwargs)

        # Get the length of the elements/items in the dataset; the ood_label
        # feature, when present, is not counted.
        self.len_item = self._data_handler.get_item_length(self.data)
        if self.has_ood_label:
            self.len_item -= 1

        # Get the key of the tensor to feed the model with
        if input_key is None:
            self.input_key = self._data_handler.get_ds_feature_keys(self.data)[0]
        else:
            self.input_key = input_key

    @staticmethod
    def _resolve_load_kwargs(
        backend: str,
        load_from_tensorflow_datasets: bool,
        load_kwargs: Optional[dict],
    ) -> dict:
        """Return a fresh dict of loading kwargs for the given configuration.

        Args:
            backend (str): backend name given to the constructor.
            load_from_tensorflow_datasets (bool): whether a torch backend loads
                its data through the tensorflow data handler.
            load_kwargs (Optional[dict]): user supplied kwargs, or None.

        Returns:
            dict: a new dict holding the user kwargs, with "as_supervised" forced
                to False when backend is "tensorflow" or when a torch backend
                loads from tensorflow_datasets.
        """
        params = {} if load_kwargs is None else dict(load_kwargs)
        if backend == "tensorflow" or (
            backend == "torch" and load_from_tensorflow_datasets
        ):
            params["as_supervised"] = False
        return params

    def __len__(self) -> int:
        """get the length of the dataset.

        The value is computed by the data handler on the first call and cached
        in self.length afterwards.

        Returns:
            int: length of the dataset
        """
        if self.length is None:
            self.length = self._data_handler.get_dataset_length(self.data)
        return self.length

    @property
    def has_ood_label(self) -> bool:
        """Check if the dataset has an out-of-distribution label.

        Returns:
            bool: True if data handler has a "ood_label" feature key.
        """
        return self._data_handler.has_feature_key(self.data, "ood_label")

    def get_ood_labels(
        self,
    ) -> np.ndarray:
        """Get ood_labels from self.data if any

        Returns:
            np.ndarray: array of labels

        Raises:
            AssertionError: if self.data has no "ood_label" feature key.
        """
        assert self.has_ood_label, "The data has no ood_labels"
        labels = self._data_handler.get_feature_from_ds(self.data, "ood_label")
        return labels

    def add_out_data(
        self,
        out_dataset: Union["OODDataset", DatasetType],
        in_value: int = 0,
        out_value: int = 1,
        resize: Optional[bool] = False,
        shape: Optional[Tuple[int]] = None,
    ) -> "OODDataset":
        """Concatenate two OODDatasets. Useful for scoring on multiple datasets, or
        training with added out-of-distribution data.

        Note that self.data is reassigned with an "ood_label" feature set to
        in_value as a side effect of this call.

        Args:
            out_dataset (Union[OODDataset, DatasetType]): dataset of
                out-of-distribution data
            in_value (int): ood label value for in-distribution data. Defaults to 0
            out_value (int): ood label value for out-of-distribution data.
                Defaults to 1
            resize (Optional[bool], optional): toggles if input tensors of the
                datasets have to be resized to have the same shape.
                Defaults to False.
            shape (Optional[Tuple[int]], optional): shape to use for resizing input
                tensors. If None, the tensors are resized with the shape of the
                in_dataset input tensors. Defaults to None.

        Returns:
            OODDataset: a Dataset object with the concatenated data
        """

        # Get the underlying data of out_dataset, wrapping it in an OODDataset
        # first when a raw dataset is given.
        # NOTE(review): load_from_tensorflow_datasets and input_key are not
        # forwarded to the OODDataset objects created here - confirm intended.
        if isinstance(out_dataset, type(self)):
            out_dataset = out_dataset.data
        else:
            out_dataset = OODDataset(out_dataset, backend=self.backend).data

        # Assign the ood_label values: in_value to self.data, out_value to
        # out_dataset
        self.data = self._data_handler.assign_feature_value(
            self.data, "ood_label", in_value
        )
        out_dataset = self._data_handler.assign_feature_value(
            out_dataset, "ood_label", out_value
        )

        # Merge the two underlying Datasets; channel_order is only passed for
        # the tensorflow backend
        merge_kwargs = (
            {"channel_order": self.channel_order}
            if self.backend == "tensorflow"
            else {}
        )
        data = self._data_handler.merge(
            self.data,
            out_dataset,
            resize=resize,
            shape=shape,
            **merge_kwargs,
        )

        # Create a new OODDataset from the merged Dataset
        output_ds = OODDataset(
            dataset_id=data,
            backend=self.backend,
        )

        return output_ds

    def split_by_class(
        self,
        in_labels: Optional[Union[np.ndarray, list]] = None,
        out_labels: Optional[Union[np.ndarray, list]] = None,
    ) -> Optional[Tuple["OODDataset"]]:
        """Filter the dataset by assigning ood labels depending on labels
        value (typically, class id).

        When only one of in_labels / out_labels is given, the other split is
        made of every item whose label is not in the given set.

        Args:
            in_labels (Optional[Union[np.ndarray, list]], optional): set of labels
                to be considered as in-distribution. Defaults to None.
            out_labels (Optional[Union[np.ndarray, list]], optional): set of labels
                to be considered as out-of-distribution. Defaults to None.

        Returns:
            Optional[Tuple[OODDataset]]: Tuple of in-distribution and
                out-of-distribution OODDatasets

        Raises:
            AssertionError: if neither in_labels nor out_labels is given, or if
                the items of the dataset have fewer than two elements.
        """
        # Make sure the dataset has labels
        assert (in_labels is not None) or (
            out_labels is not None
        ), "specify labels to filter with"
        assert self.len_item >= 2, "the dataset has no labels"

        filter_by_label = self._data_handler.filter_by_feature_value

        # Filter the dataset depending on in_labels and out_labels given
        if (in_labels is not None) and (out_labels is not None):
            in_data = filter_by_label(self.data, "label", in_labels)
            out_data = filter_by_label(self.data, "label", out_labels)
        elif out_labels is None:
            # Only in_labels given: everything else is out-of-distribution
            in_data = filter_by_label(self.data, "label", in_labels)
            out_data = filter_by_label(
                self.data, "label", in_labels, excluded=True
            )
        else:
            # Only out_labels given: everything else is in-distribution
            in_data = filter_by_label(
                self.data, "label", out_labels, excluded=True
            )
            out_data = filter_by_label(self.data, "label", out_labels)

        # Return the filtered OODDatasets
        return (
            OODDataset(in_data, backend=self.backend),
            OODDataset(out_data, backend=self.backend),
        )

    def prepare(
        self,
        batch_size: int = 128,
        preprocess_fn: Optional[Callable] = None,
        augment_fn: Optional[Callable] = None,
        with_ood_labels: bool = False,
        with_labels: bool = True,
        shuffle: bool = False,
        **kwargs_prepare,
    ) -> DatasetType:
        """Prepare self.data for scoring or training

        Args:
            batch_size (int, optional): Batch_size of the returned dataset like
                object. Defaults to 128.
            preprocess_fn (Callable, optional): Preprocessing function to apply to
                the dataset. Defaults to None.
            augment_fn (Callable, optional): Augment function to be used (when the
                returned dataset is to be used for training). Defaults to None.
            with_ood_labels (bool, optional): To return the dataset with ood_labels
                or not. Defaults to False.
            with_labels (bool, optional): To return the dataset with labels or not.
                Defaults to True.
            shuffle (bool, optional): To shuffle the returned dataset or not.
                Defaults to False.
            kwargs_prepare (dict): Additional parameters to be passed to the
                data_handler.prepare_for_training method.

        Returns:
            DatasetType: prepared dataset

        Raises:
            AssertionError: if both with_labels and with_ood_labels are False, or
                if with_ood_labels is True while the data has no "ood_label".
        """
        # At least one of label and ood_label must be requested
        assert (
            with_ood_labels or with_labels
        ), "The dataset must have at least one of label and ood_label"

        # Check if the dataset has ood_labels when asked to return with_ood_labels
        if with_ood_labels:
            assert (
                self.has_ood_label
            ), "Please assign ood labels before preparing with ood_labels"

        dataset_to_prepare = self.data

        # Making the dataset channel first if the backend is torch and the data
        # was loaded through tensorflow_datasets
        if self.backend == "torch" and self.load_from_tensorflow_datasets:
            dataset_to_prepare = self._data_handler.make_channel_first(
                self.input_key, dataset_to_prepare
            )

        # Select the keys to be returned
        keys = [self.input_key, "label", "ood_label"]
        if not with_labels:
            keys.remove("label")
        if not with_ood_labels:
            keys.remove("ood_label")

        # Prepare the dataset for training or scoring
        dataset = self._data_handler.prepare_for_training(
            dataset=dataset_to_prepare,
            batch_size=batch_size,
            shuffle=shuffle,
            preprocess_fn=preprocess_fn,
            augment_fn=augment_fn,
            output_keys=keys,
            **kwargs_prepare,
        )

        return dataset
|