oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
@@ -0,0 +1,330 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import numpy as np
24
+
25
+ from ...types import Callable
26
+ from ...types import DatasetType
27
+ from ...types import Optional
28
+ from ...types import Tuple
29
+ from ...types import Union
30
+
31
+
32
class OODDataset(object):
    """Manage the loading and processing of datasets used for OOD detection.

    The class wraps a dataset-like object together with OOD related information
    and, through the ``prepare`` method, returns a dataset-like object suited for
    scoring or training.

    Args:
        dataset_id (Union[DatasetType, tuple, dict, str]): The dataset to load.
            Can be loaded from tensorflow or torch datasets catalog when the str
            matches one of the datasets.
        backend (str, optional): Whether the dataset is to be used for tensorflow
            or torch models. Defaults to "tensorflow". Alternative: "torch".
        keys (list, optional): keys to use for dataset elems. Defaults to None.
        load_kwargs (dict, optional): Additional loading kwargs when loading from
            tensorflow_datasets catalog. The dict is copied, never modified in
            place. Defaults to None, which is treated as an empty dict.
        load_from_tensorflow_datasets (bool, optional): For the case where the
            backend is torch but the user still wants to import from the
            tensorflow_datasets catalog. In that case, tf.Tensor will not be
            loaded in VRAM and converted as torch.Tensors on the fly.
            Defaults to False.
        input_key (str, optional): The key of the element/item to consider as the
            model input tensor. If None, taken as the first key. Defaults to None.
    """

    def __init__(
        self,
        dataset_id: Union[DatasetType, tuple, dict, str],
        backend: str = "tensorflow",
        keys: Optional[list] = None,
        load_kwargs: Optional[dict] = None,
        load_from_tensorflow_datasets: bool = False,
        input_key: Optional[str] = None,
    ):
        self.backend = backend
        self.load_from_tensorflow_datasets = load_from_tensorflow_datasets

        # The length of the dataset is kept as attribute to avoid redundant
        # iterations over self.data
        self.length = None

        # Work on a private copy: "as_supervised" is written below, and that
        # write must neither leak into the caller's dict nor be shared between
        # instances through a mutable default argument.
        load_kwargs = {} if load_kwargs is None else dict(load_kwargs)

        # Set the load parameters for tfds / torchvision
        if backend == "tensorflow":
            load_kwargs["as_supervised"] = False

        # Select the data handler and the channel order depending on the backend
        if self.backend == "torch":
            if load_from_tensorflow_datasets:
                from .DEPRECATED_tf_data_handler import TFDataHandler
                import tensorflow as tf

                # Hide the GPUs from tensorflow
                tf.config.set_visible_devices([], "GPU")
                self._data_handler = TFDataHandler()
                load_kwargs["as_supervised"] = False
            else:
                from .DEPRECATED_torch_data_handler import TorchDataHandler

                self._data_handler = TorchDataHandler()
            self.channel_order = "channels_first"
        else:
            from .DEPRECATED_tf_data_handler import TFDataHandler

            self._data_handler = TFDataHandler()
            self.channel_order = "channels_last"

        self.load_params = load_kwargs
        # Load the dataset depending on the type of dataset_id
        self.data = self._data_handler.load_dataset(dataset_id, keys, load_kwargs)

        # Get the length of the elements/items in the dataset; the ood_label,
        # when present, is not counted
        self.len_item = self._data_handler.get_item_length(self.data)
        if self.has_ood_label:
            self.len_item -= 1

        # Get the key of the tensor to feed the model with
        if input_key is None:
            self.input_key = self._data_handler.get_ds_feature_keys(self.data)[0]
        else:
            self.input_key = input_key

    def __len__(self) -> int:
        """Get the length of the dataset.

        The value is computed by the data handler on the first call and cached
        in ``self.length`` afterwards.

        Returns:
            int: length of the dataset
        """
        if self.length is None:
            self.length = self._data_handler.get_dataset_length(self.data)
        return self.length

    @property
    def has_ood_label(self) -> bool:
        """Check if the dataset has an out-of-distribution label.

        Returns:
            bool: True if data handler has a "ood_label" feature key.
        """
        return self._data_handler.has_feature_key(self.data, "ood_label")

    def _wrap(self, data: DatasetType) -> "OODDataset":
        """Build a new OODDataset around ``data`` with this instance's settings.

        Both the backend and the ``load_from_tensorflow_datasets`` flag are
        forwarded, so that the new instance selects the same data handler as
        the one that produced ``data``.

        Args:
            data (DatasetType): dataset-like object to wrap.

        Returns:
            OODDataset: the new dataset object.
        """
        return OODDataset(
            dataset_id=data,
            backend=self.backend,
            load_from_tensorflow_datasets=self.load_from_tensorflow_datasets,
        )

    def get_ood_labels(
        self,
    ) -> np.ndarray:
        """Get ood_labels from self.data if any

        Returns:
            np.ndarray: array of labels

        Raises:
            AssertionError: if the data has no "ood_label" feature key.
        """
        assert self.has_ood_label, "The data has no ood_labels"
        return self._data_handler.get_feature_from_ds(self.data, "ood_label")

    def add_out_data(
        self,
        out_dataset: Union["OODDataset", DatasetType],
        in_value: int = 0,
        out_value: int = 1,
        resize: Optional[bool] = False,
        shape: Optional[Tuple[int]] = None,
    ) -> "OODDataset":
        """Concatenate two OODDatasets. Useful for scoring on multiple datasets, or
        training with added out-of-distribution data.

        Note that ``self.data`` is relabelled in place: after the call, it
        carries an "ood_label" feature set to ``in_value``.

        Args:
            out_dataset (Union[OODDataset, DatasetType]): dataset of
                out-of-distribution data
            in_value (int): ood label value for in-distribution data. Defaults to 0
            out_value (int): ood label value for out-of-distribution data.
                Defaults to 1
            resize (Optional[bool], optional): toggles if input tensors of the
                datasets have to be resized to have the same shape.
                Defaults to False.
            shape (Optional[Tuple[int]], optional): shape to use for resizing input
                tensors. If None, the tensors are resized with the shape of the
                in_dataset input tensors. Defaults to None.

        Returns:
            OODDataset: a Dataset object with the concatenated data
        """
        # Get the underlying data of out_dataset, loading it through a new
        # OODDataset (with the same settings as self) when a raw dataset is given
        if isinstance(out_dataset, OODDataset):
            out_data = out_dataset.data
        else:
            out_data = self._wrap(out_dataset).data

        # Assign in_value as ood_label of self.data and out_value as ood_label
        # of the out-of-distribution data
        self.data = self._data_handler.assign_feature_value(
            self.data, "ood_label", in_value
        )
        out_data = self._data_handler.assign_feature_value(
            out_data, "ood_label", out_value
        )

        # Merge the two underlying Datasets; the channel order is only passed
        # for the tensorflow backend
        merge_kwargs = {}
        if self.backend == "tensorflow":
            merge_kwargs["channel_order"] = self.channel_order
        merged = self._data_handler.merge(
            self.data,
            out_data,
            resize=resize,
            shape=shape,
            **merge_kwargs,
        )

        # Create a new OODDataset from the merged Dataset
        return self._wrap(merged)

    def split_by_class(
        self,
        in_labels: Optional[Union[np.ndarray, list]] = None,
        out_labels: Optional[Union[np.ndarray, list]] = None,
    ) -> Optional[Tuple["OODDataset"]]:
        """Filter the dataset by assigning ood labels depending on labels
        value (typically, class id).

        When only one of ``in_labels`` / ``out_labels`` is given, the other
        dataset is made of every item whose label is not in the given set.

        Args:
            in_labels (Optional[Union[np.ndarray, list]], optional): set of labels
                to be considered as in-distribution. Defaults to None.
            out_labels (Optional[Union[np.ndarray, list]], optional): set of labels
                to be considered as out-of-distribution. Defaults to None.

        Returns:
            Optional[Tuple[OODDataset]]: Tuple of in-distribution and
                out-of-distribution OODDatasets

        Raises:
            AssertionError: if neither in_labels nor out_labels is given, or if
                the dataset items have fewer than two elements (no labels).
        """
        # Make sure there is something to filter with and that the dataset
        # has labels
        assert (in_labels is not None) or (
            out_labels is not None
        ), "specify labels to filter with"
        assert self.len_item >= 2, "the dataset has no labels"

        filter_fn = self._data_handler.filter_by_feature_value

        # Filter the dataset depending on in_labels and out_labels given
        if out_labels is None:
            # Only in_labels given: out data is the complement of in_labels
            in_data = filter_fn(self.data, "label", in_labels)
            out_data = filter_fn(self.data, "label", in_labels, excluded=True)
        elif in_labels is None:
            # Only out_labels given: in data is the complement of out_labels
            in_data = filter_fn(self.data, "label", out_labels, excluded=True)
            out_data = filter_fn(self.data, "label", out_labels)
        else:
            # Both given: each dataset is filtered with its own set of labels
            in_data = filter_fn(self.data, "label", in_labels)
            out_data = filter_fn(self.data, "label", out_labels)

        # Return the filtered OODDatasets
        return (self._wrap(in_data), self._wrap(out_data))

    def prepare(
        self,
        batch_size: int = 128,
        preprocess_fn: Optional[Callable] = None,
        augment_fn: Optional[Callable] = None,
        with_ood_labels: bool = False,
        with_labels: bool = True,
        shuffle: bool = False,
        **kwargs_prepare,
    ) -> DatasetType:
        """Prepare self.data for scoring or training

        Args:
            batch_size (int, optional): Batch_size of the returned dataset like
                object. Defaults to 128.
            preprocess_fn (Callable, optional): Preprocessing function to apply to
                the dataset. Defaults to None.
            augment_fn (Callable, optional): Augment function to be used (when the
                returned dataset is to be used for training). Defaults to None.
            with_ood_labels (bool, optional): To return the dataset with ood_labels
                or not. Defaults to False.
            with_labels (bool, optional): To return the dataset with labels or not.
                Defaults to True.
            shuffle (bool, optional): To shuffle the returned dataset or not.
                Defaults to False.
            kwargs_prepare (dict): Additional parameters to be passed to the
                data_handler.prepare_for_training method.

        Returns:
            DatasetType: prepared dataset

        Raises:
            AssertionError: if both with_labels and with_ood_labels are False, or
                if with_ood_labels is True while the data has no ood labels.
        """
        # At least one of label and ood_label has to be returned
        assert (
            with_ood_labels or with_labels
        ), "The dataset must have at least one of label and ood_label"

        # Check if the dataset has ood_labels when asked to return with_ood_labels
        if with_ood_labels:
            assert (
                self.has_ood_label
            ), "Please assign ood labels before preparing with ood_labels"

        dataset_to_prepare = self.data

        # Make the dataset channel first when the backend is torch and the
        # data comes from tensorflow_datasets
        if self.backend == "torch" and self.load_from_tensorflow_datasets:
            dataset_to_prepare = self._data_handler.make_channel_first(
                self.input_key, dataset_to_prepare
            )

        # Select the keys to be returned
        keys = [self.input_key]
        if with_labels:
            keys.append("label")
        if with_ood_labels:
            keys.append("ood_label")

        # Prepare the dataset for training or scoring
        return self._data_handler.prepare_for_training(
            dataset=dataset_to_prepare,
            batch_size=batch_size,
            shuffle=shuffle,
            preprocess_fn=preprocess_fn,
            augment_fn=augment_fn,
            output_keys=keys,
            **kwargs_prepare,
        )