oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of oodeel has been flagged as possibly problematic.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py
@@ -0,0 +1,769 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import copy
+from typing import get_args
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch.utils.data import ConcatDataset
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torch.utils.data import Subset
+from torch.utils.data import TensorDataset
+from torch.utils.data.dataloader import default_collate
+
+from ...types import Any
+from ...types import Callable
+from ...types import ItemType
+from ...types import List
+from ...types import Optional
+from ...types import TensorType
+from ...types import Tuple
+from ...types import Union
+from .DEPRECATED_data_handler import DataHandler
+
+
+def dict_only_ds(ds_handling_method: Callable) -> Callable:
+    """Decorator to ensure that the dataset is a dict dataset and that the input key
+    matches one of the feature keys. The signature of decorated functions
+    must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+    args[0] when relevant.
+
+
+    Args:
+        ds_handling_method: method to decorate
+
+    Returns:
+        decorated method
+    """
+
+    def wrapper(dataset: Dataset, *args, **kwargs):
+        assert isinstance(
+            dataset, DictDataset
+        ), "Dataset must be an instance of DictDataset"
+
+        if "feature_key" in kwargs:
+            feature_key = kwargs["feature_key"]
+        elif len(args) > 0:
+            feature_key = args[0]
+
+        # If feature_key is provided, check that it is in the dataset feature keys
+        if (len(args) > 0) or ("feature_key" in kwargs):
+            if isinstance(feature_key, str):
+                feature_key = [feature_key]
+            for key in feature_key:
+                assert (
+                    key in dataset.output_keys
+                ), f"The input dataset has no feature named {key}"
+        return ds_handling_method(dataset, *args, **kwargs)
+
+    return wrapper
+
+
+def to_torch(array: TensorType) -> torch.Tensor:
+    """Convert an array into a torch Tensor
+
+    Args:
+        array (TensorType): array to convert
+
+    Returns:
+        torch.Tensor: converted array
+    """
+    if isinstance(array, np.ndarray):
+        return torch.Tensor(array)
+    elif isinstance(array, torch.Tensor):
+        return array
+    else:
+        raise TypeError("Input array must be of numpy or torch type")
+
+
+class DictDataset(Dataset):
+    r"""Dictionary pytorch dataset
+
+    Wrapper to output a dictionary of tensors at the __getitem__ call of a dataset.
+    Some mapping, filtering and concatenation methods are implemented to imitate
+    tensorflow datasets features.
+
+    Args:
+        dataset (Dataset): Dataset to wrap.
+        output_keys (List[str]): Keys describing the output tensors.
+    """
+
+    def __init__(
+        self, dataset: Dataset, output_keys: List[str] = ["input", "label"]
+    ) -> None:
+        self._dataset = dataset
+        self._raw_output_keys = output_keys
+        self.map_fns = []
+        self._check_init_args()
+
+    @property
+    def output_keys(self) -> list:
+        """Get the list of keys in a dict-based item from the dataset.
+
+        Returns:
+            list: feature keys of the dataset.
+        """
+        dummy_item = self[0]
+        return list(dummy_item.keys())
+
+    @property
+    def output_shapes(self) -> list:
+        """Get a list of the tensor shapes in an item from the dataset.
+
+        Returns:
+            list: tensor shapes of a dataset item.
+        """
+        dummy_item = self[0]
+        return [dummy_item[key].shape for key in self.output_keys]
+
+    def _check_init_args(self) -> None:
+        """Check validity of dataset and output keys provided at init"""
+        dummy_item = self._dataset[0]
+        assert isinstance(
+            dummy_item, (tuple, dict, list, torch.Tensor)
+        ), "Dataset to be wrapped needs to return tuple, list or dict of tensors"
+        if isinstance(dummy_item, torch.Tensor):
+            dummy_item = [dummy_item]
+        assert len(dummy_item) == len(
+            self._raw_output_keys
+        ), "Length mismatch between dataset item and provided keys"
+
+    def __getitem__(self, index: int) -> dict:
+        """Return a dictionary of tensors corresponding to a specific index.
+
+        Args:
+            index (int): the index of the item to retrieve.
+
+        Returns:
+            dict: tensors for the item at the specific index.
+        """
+        item = self._dataset[index]
+
+        # convert item to a list / tuple of tensors
+        if isinstance(item, torch.Tensor):
+            tensors = [item]
+        elif isinstance(item, dict):
+            tensors = list(item.values())
+        else:
+            tensors = item
+
+        # build output dictionary
+        output_dict = {
+            key: tensor for (key, tensor) in zip(self._raw_output_keys, tensors)
+        }
+
+        # apply map functions
+        for map_fn in self.map_fns:
+            output_dict = map_fn(output_dict)
+        return output_dict
+
+    def map(self, map_fn: Callable, inplace: bool = False) -> "DictDataset":
+        """Map the dataset
+
+        Args:
+            map_fn (Callable): map function f: dict -> dict
+            inplace (bool): if False, applies the mapping on a copied version of\
+                the dataset. Defaults to False.
+
+        Returns:
+            DictDataset: Mapped dataset
+        """
+        dataset = self if inplace else copy.deepcopy(self)
+        dataset.map_fns.append(map_fn)
+        return dataset
+
+    def filter(self, filter_fn: Callable, inplace: bool = False) -> "DictDataset":
+        """Filter the dataset
+
+        Args:
+            filter_fn (Callable): filter function f: dict -> bool
+            inplace (bool): if False, applies the filtering on a copied version of\
+                the dataset. Defaults to False.
+
+        Returns:
+            DictDataset: Filtered dataset
+        """
+        indices = [i for i in range(len(self)) if filter_fn(self[i])]
+        dataset = self if inplace else copy.deepcopy(self)
+        dataset._dataset = Subset(self._dataset, indices)
+        return dataset
+
+    def concatenate(
+        self, other_dataset: Dataset, inplace: bool = False
+    ) -> "DictDataset":
+        """Concatenate with another dataset
+
+        Args:
+            other_dataset (DictDataset): Dataset to concatenate with
+            inplace (bool): if False, applies the concatenation on a copied version of\
+                the dataset. Defaults to False.
+
+        Returns:
+            DictDataset: Concatenated dataset
+        """
+        assert isinstance(
+            other_dataset, DictDataset
+        ), "Second dataset should be an instance of DictDataset"
+        assert (
+            self.output_keys == other_dataset.output_keys
+        ), "Incompatible dataset elements (different dict keys)"
+        if inplace:
+            dataset_copy = copy.deepcopy(self)
+            self._raw_output_keys = self.output_keys
+            self.map_fns = []
+            self._dataset = ConcatDataset([dataset_copy, other_dataset])
+            dataset = self
+        else:
+            dataset = DictDataset(
+                ConcatDataset([self, other_dataset]), self.output_keys
+            )
+        return dataset
+
+    def __len__(self) -> int:
+        """Return the length of the dataset, i.e. the number of items.
+
+        Returns:
+            int: length of the dataset.
+        """
+        return len(self._dataset)
+
+
+class TorchDataHandler(DataHandler):
+    """
+    Class to manage torch DictDataset. The aim is to provide a simple interface
+    for working with torch datasets and managing them without having to use
+    torch syntax.
+    """
+
+    @staticmethod
+    def _default_target_transform(y: Any) -> torch.Tensor:
+        """Format an int or float item target as a torch tensor
+
+        Args:
+            y (Any): dataset item target
+
+        Returns:
+            torch.Tensor: target as a torch.Tensor
+        """
+        return torch.tensor(y) if isinstance(y, (float, int)) else y
+
+    DEFAULT_TRANSFORM = torchvision.transforms.PILToTensor()
+    DEFAULT_TARGET_TRANSFORM = _default_target_transform.__func__
+
+    @classmethod
+    def load_dataset(
+        cls,
+        dataset_id: Union[Dataset, ItemType, str],
+        keys: Optional[list] = None,
+        load_kwargs: dict = {},
+    ) -> DictDataset:
+        """Load a dataset from different types of sources
+
+        Args:
+            dataset_id (Union[Dataset, ItemType, str]): dataset identification
+            keys (list, optional): Features keys. If None, assigned as "input_i"
+                for the i-th feature. Defaults to None.
+            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+
+        Returns:
+            DictDataset: dataset
+        """
+        if isinstance(dataset_id, str):
+            assert "root" in load_kwargs.keys()
+            dataset = cls.load_from_torchvision(dataset_id, **load_kwargs)
+        elif isinstance(dataset_id, Dataset):
+            dataset = cls.load_custom_dataset(dataset_id, keys)
+        elif isinstance(dataset_id, get_args(ItemType)):
+            dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+        return dataset
+
+    @staticmethod
+    def load_dataset_from_arrays(
+        dataset_id: ItemType,
+        keys: Optional[list] = None,
+    ) -> DictDataset:
+        """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.
+
+        Args:
+            dataset_id (ItemType):
+                numpy / torch array(s) to load.
+            keys (list, optional): Features keys. If None, assigned as "input_i"
+                for the i-th feature. Defaults to None.
+
+        Returns:
+            DictDataset: dataset
+        """
+        # If dataset_id is an array
+        if isinstance(dataset_id, get_args(TensorType)):
+            tensors = tuple(to_torch(dataset_id))
+            output_keys = keys or ["input"]
+
+        # If dataset_id is a tuple of arrays
+        elif isinstance(dataset_id, tuple):
+            len_elem = len(dataset_id)
+            output_keys = keys
+            if output_keys is None:
+                if len_elem == 2:
+                    output_keys = ["input", "label"]
+                else:
+                    output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
+                        "label"
+                    ]
+                    print(
+                        "Loading torch.utils.data.Dataset with elems as dicts, "
+                        'assigning "input_i" key to the i-th tuple dimension and'
+                        ' "label" key to the last tuple dimension.'
+                    )
+            assert len(output_keys) == len(dataset_id)
+            tensors = tuple(to_torch(array) for array in dataset_id)
+
+        # If dataset_id is a dictionary of arrays
+        elif isinstance(dataset_id, dict):
+            output_keys = keys or list(dataset_id.keys())
+            assert len(output_keys) == len(dataset_id)
+            tensors = tuple(to_torch(array) for array in dataset_id.values())
+
+        # create torch dictionary dataset from tensors tuple and keys
+        dataset = DictDataset(TensorDataset(*tensors), output_keys)
+        return dataset
+
+    @staticmethod
+    def load_custom_dataset(
+        dataset_id: Dataset, keys: Optional[list] = None
+    ) -> DictDataset:
+        """Load a custom Dataset by ensuring it has the correct format (dict-based)
+
+        Args:
+            dataset_id (Dataset): Dataset
+            keys (list, optional): Keys to use for features if dataset_id is
+                tuple based. Defaults to None.
+
+        Returns:
+            DictDataset
+        """
+        # If dataset_id is a tuple based Dataset, convert it to a DictDataset
+        dummy_item = dataset_id[0]
+        if not isinstance(dummy_item, dict):
+            assert isinstance(
+                dummy_item, (Tuple, torch.Tensor)
+            ), "Custom dataset should be either dictionary based or tuple based"
+            output_keys = keys
+            if output_keys is None:
+                len_elem = len(dummy_item)
+                if len_elem == 2:
+                    output_keys = ["input", "label"]
+                else:
+                    output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
+                        "label"
+                    ]
+                    print(
+                        "Feature name not found, assigning 'input_i' "
+                        "key to the i-th tensor and 'label' key to the last"
+                    )
+            dataset_id = DictDataset(dataset_id, output_keys)
+
+        dataset = dataset_id
+        return dataset
+
+    @classmethod
+    def load_from_torchvision(
+        cls,
+        dataset_id: str,
+        root: str,
+        transform: Callable = DEFAULT_TRANSFORM,
+        target_transform: Callable = DEFAULT_TARGET_TRANSFORM,
+        download: bool = False,
+        **load_kwargs,
+    ) -> DictDataset:
+        """Load a Dataset from the torchvision datasets catalog
+
+        Args:
+            dataset_id (str): Identifier of the dataset
+            root (str): Root directory of the dataset
+            transform (Callable, optional): Transform function to apply to the input.
+                Defaults to DEFAULT_TRANSFORM.
+            target_transform (Callable, optional): Transform function to apply
+                to the target. Defaults to DEFAULT_TARGET_TRANSFORM.
+            download (bool): If true, downloads the dataset from the internet and puts
+                it in the root directory. If the dataset is already downloaded, it is
+                not downloaded again. Defaults to False.
+            load_kwargs (dict): Loading kwargs to add to the initialization
+                of the dataset.
+
+        Returns:
+            DictDataset: dataset
+        """
+        assert (
+            dataset_id in torchvision.datasets.__all__
+        ), "Dataset not available in the torchvision datasets catalog"
+        dataset = getattr(torchvision.datasets, dataset_id)(
+            root=root,
+            download=download,
+            transform=transform,
+            target_transform=target_transform,
+            **load_kwargs,
+        )
+        return cls.load_custom_dataset(dataset)
+
+    @staticmethod
+    def assign_feature_value(
+        dataset: DictDataset, feature_key: str, value: int
+    ) -> DictDataset:
+        """Assign a value to a feature for every sample in a DictDataset
+
+        Args:
+            dataset (DictDataset): DictDataset to assign the value to
+            feature_key (str): Feature to assign the value to
+            value (int): Value to assign
+
+        Returns:
+            DictDataset
+        """
+        assert isinstance(
+            dataset, DictDataset
+        ), "Dataset must be an instance of DictDataset"
+
+        def assign_value_to_feature(x):
+            x[feature_key] = torch.tensor(value)
+            return x
+
+        dataset = dataset.map(assign_value_to_feature)
+        return dataset
+
+    @staticmethod
+    @dict_only_ds
+    def get_feature_from_ds(dataset: DictDataset, feature_key: str) -> np.ndarray:
+        """Get a feature from a DictDataset
+
+        !!! note
+            This function can be a bit time consuming since it needs to iterate
+            over the whole dataset.
+
+        Args:
+            dataset (DictDataset): Dataset to get the feature from
+            feature_key (str): Feature value to get
+
+        Returns:
+            np.ndarray: Feature values for the dataset
+        """
+        features = dataset.map(lambda x: x[feature_key])
+        features = np.stack([f.numpy() for f in features])
+        return features
+
+    @staticmethod
+    @dict_only_ds
+    def get_ds_feature_keys(dataset: DictDataset) -> list:
+        """Get the feature keys of a DictDataset
+
+        Args:
+            dataset (DictDataset): Dataset to get the feature keys from
+
+        Returns:
+            list: List of feature keys
+        """
+        return dataset.output_keys
+
+    @staticmethod
+    def has_feature_key(dataset: DictDataset, key: str) -> bool:
+        """Check if a DictDataset has a feature denoted by key
+
+        Args:
+            dataset (DictDataset): Dataset to check
+            key (str): Key to check
+
+        Returns:
+            bool: whether the dataset has a feature denoted by key
+        """
+        assert isinstance(
+            dataset, DictDataset
+        ), "Dataset must be an instance of DictDataset"
+
+        return key in dataset.output_keys
+
+    @staticmethod
+    def map_ds(
+        dataset: DictDataset,
+        map_fn: Callable,
+    ) -> DictDataset:
+        """Map a function over a DictDataset
+
+        Args:
+            dataset (DictDataset): Dataset to map the function over
+            map_fn (Callable): Function to map
+
+        Returns:
+            DictDataset: Mapped dataset
+        """
+        return dataset.map(map_fn)
+
+    @staticmethod
+    @dict_only_ds
+    def filter_by_feature_value(
+        dataset: DictDataset,
+        feature_key: str,
+        values: list,
+        excluded: bool = False,
+    ) -> DictDataset:
+        """Filter the dataset by checking that the value of a feature is in `values`
+
+        !!! note
+            This function can be a bit time consuming since it needs to iterate
+            over the whole dataset.
+
+        Args:
+            dataset (DictDataset): Dataset to filter
+            feature_key (str): Feature name whose value to check
+            values (list): feature_key values to keep
+            excluded (bool, optional): To keep (False) or exclude (True) the samples
+                with a feature_key value included in values. Defaults to False.
+
+        Returns:
+            DictDataset: Filtered dataset
+        """
+        if len(dataset[0][feature_key].shape) > 0:
+            value_dim = dataset[0][feature_key].shape[-1]
+            values = [
+                F.one_hot(torch.tensor(value).long(), value_dim) for value in values
+            ]
+
+        def filter_fn(x):
+            keep = any([torch.all(x[feature_key] == v) for v in values])
+            return keep if not excluded else not keep
+
+        filtered_dataset = dataset.filter(filter_fn)
+        return filtered_dataset
+
+    @classmethod
+    def prepare_for_training(
+        cls,
+        dataset: DictDataset,
+        batch_size: int,
+        shuffle: bool = False,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        output_keys: Optional[list] = None,
+        dict_based_fns: bool = False,
+        shuffle_buffer_size: Optional[int] = None,
+        num_workers: int = 8,
+    ) -> DataLoader:
+        """Prepare a DataLoader for training
+
+        Args:
+            dataset (DictDataset): Dataset to prepare
+            batch_size (int): Batch size
+            shuffle (bool): Whether to shuffle the dataloader or not
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            output_keys (list): List of keys corresponding to the features that will be
+                returned. Keep all features if None. Defaults to None.
+            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
+                based (if True) or as tuple based (if False). Defaults to False.
+            shuffle_buffer_size (int, optional): Size of the shuffle buffer. Not used
+                in torch because we only rely on map-style datasets. Kept as an
+                argument for API consistency. Defaults to None.
+            num_workers (int, optional): Number of workers to use for the dataloader.
+
+        Returns:
+            DataLoader: dataloader
+        """
+        output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+
+        def collate_fn(batch: List[dict]):
+            if dict_based_fns:
+                # preprocess + DA: List[dict] -> List[dict]
+                preprocess_func = preprocess_fn or (lambda x: x)
+                augment_func = augment_fn or (lambda x: x)
+                batch = [augment_func(preprocess_func(d)) for d in batch]
+                # to tuple of batches
+                return tuple(
+                    default_collate([d[key] for d in batch]) for key in output_keys
+                )
+            else:
+                # preprocess + DA: List[dict] -> List[tuple]
+                preprocess_func = preprocess_fn or (lambda *x: x)
+                augment_func = augment_fn or (lambda *x: x)
+                batch = [
+                    augment_func(
+                        *preprocess_func(*tuple(d[key] for key in output_keys))
+                    )
+                    for d in batch
+                ]
+                # to tuple of batches
+                return default_collate(batch)
+
+        loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+        )
+        return loader
+
+    @staticmethod
+    def merge(
+        id_dataset: DictDataset,
+        ood_dataset: DictDataset,
+        resize: Optional[bool] = False,
+        shape: Optional[Tuple[int]] = None,
+    ) -> DictDataset:
+        """Merge two instances of DictDataset
+
+        Args:
+            id_dataset (DictDataset): dataset of in-distribution data
+            ood_dataset (DictDataset): dataset of out-of-distribution data
+            resize (Optional[bool], optional): toggles whether the input tensors of the
+                datasets have to be resized to have the same shape. Defaults to False.
+            shape (Optional[Tuple[int]], optional): shape to use for resizing input
+                tensors. If None, the tensors are resized with the shape of the
+                id_dataset input tensors. Defaults to None.
+
+        Returns:
+            DictDataset: merged dataset
+        """
+        # If a desired shape is given, trigger the resize
+        if shape is not None:
+            resize = True
+
+        # If the shapes of the two datasets are different, trigger the resize
+        if id_dataset.output_shapes[0] != ood_dataset.output_shapes[0]:
+            resize = True
+            if shape is None:
+                print(
+                    "Resizing the first item of elem (usually the image)",
+                    " with the shape of id_dataset",
+                )
+                shape = id_dataset.output_shapes[0][1:]
+
+        if resize:
+            resize_fn = torchvision.transforms.Resize(shape)
+
+            def reshape_fn(item_dict):
+                item_dict["input"] = resize_fn(item_dict["input"])
+                return item_dict
+
+            id_dataset = id_dataset.map(reshape_fn)
+            ood_dataset = ood_dataset.map(reshape_fn)
+
+        merged_dataset = id_dataset.concatenate(ood_dataset)
+        return merged_dataset
+
+    @staticmethod
+    def get_item_length(dataset: Dataset) -> int:
+        """Number of elements in a dataset item
+
+        Args:
+            dataset (DictDataset): Dataset
+
+        Returns:
+            int: Item length
+        """
+        return len(dataset[0])
+
+    @staticmethod
+    def get_dataset_length(dataset: Dataset) -> int:
+        """Number of items in a dataset
+
+        Args:
+            dataset (DictDataset): Dataset
+
+        Returns:
+            int: Dataset length
+        """
+        return len(dataset)
+
+    @staticmethod
+    def get_feature_shape(dataset: Dataset, feature_key: Union[str, int]) -> tuple:
+        """Get the shape of a feature of the dataset identified by feature_key
+
+        Args:
+            dataset (Dataset): a Dataset
+            feature_key (Union[str, int]): The identifier of the feature
+
+        Returns:
+            tuple: the shape of the feature
+        """
+        return tuple(dataset[0][feature_key].shape)
+
+    @staticmethod
+    def get_input_from_dataset_item(elem: ItemType) -> Any:
+        """Get the tensor that is to be fed as input to a model from a dataset element.
+
+        Args:
+            elem (ItemType): dataset element to extract the input from
+
+        Returns:
+            Any: Input tensor
+        """
+        if isinstance(elem, (tuple, list)):
+            tensor = elem[0]
+        elif isinstance(elem, dict):
+            tensor = elem[list(elem.keys())[0]]
+        else:
+            tensor = elem
+        return tensor
+
+    @staticmethod
+    def get_label_from_dataset_item(item: ItemType):
+        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
+        in the item tuple. If one-hot encoded, labels are converted to a single value.
+
+        Args:
+            item (ItemType): dataset element to extract the label from
+
+        Returns:
+            Any: Label tensor
+        """
+        label = item[1]  # labels must be at index 1 in the batch tuple
+        # If labels are one-hot encoded, take the argmax
+        if len(label.shape) > 1 and label.shape[1] > 1:
+            label = label.view(label.size(0), -1)
+            label = torch.argmax(label, dim=1)
+        # If labels are in two dimensions, squeeze them
+        if len(label.shape) > 1:
+            label = label.view([label.shape[0]])
+        return label
+
+    @staticmethod
+    def get_feature(dataset: DictDataset, feature_key: Union[str, int]) -> DictDataset:
+        """Extract a feature from a dataset
+
+        Args:
+            dataset (DictDataset): Dataset to extract the feature from
+            feature_key (Union[str, int]): feature to extract
+
+        Returns:
+            DictDataset: dataset built with the extracted feature only
+        """
+
+        def _get_feature_item(item):
+            return item[feature_key]
+
+        return dataset.map(_get_feature_item)