oodeel-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
oodeel/datasets/torch_data_handler.py
ADDED
@@ -0,0 +1,672 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
from typing import get_args

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset as hf_load_dataset
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import default_collate

from ..types import Any
from ..types import Callable
from ..types import ItemType
from ..types import List
from ..types import Optional
from ..types import TensorType
from ..types import Tuple
from ..types import Union
from .data_handler import DataHandler


def dict_only_ds(ds_handling_method: Callable) -> Callable:
    """Decorator to ensure that the dataset is a dict dataset and that the column_name
    given as argument matches one of the column names. The signature of decorated
    functions must be function(dataset, *args, **kwargs) with column_name either in
    kwargs or args[0] when relevant.

    Args:
        ds_handling_method: method to decorate

    Returns:
        decorated method
    """

    def wrapper(dataset: Dataset, *args, **kwargs):
        assert isinstance(
            dataset[0], dict
        ), "Dataset must be an instance of DictDataset"

        if "column_name" in kwargs:
            column_name = kwargs["column_name"]
        elif len(args) > 0:
            column_name = args[0]

        # If column_name is provided, check that it is in the dataset column names
        if (len(args) > 0) or ("column_name" in kwargs):
            if isinstance(column_name, str):
                column_name = [column_name]
            for name in column_name:
                assert (
                    name in dataset.column_names
                ), f"The input dataset has no column named {name}"
        return ds_handling_method(dataset, *args, **kwargs)

    return wrapper


def to_torch(array: TensorType) -> torch.Tensor:
    """Convert an array into a torch Tensor

    Args:
        array (TensorType): array to convert

    Returns:
        torch.Tensor: converted array
    """
    if isinstance(array, np.ndarray):
        return torch.Tensor(array)
    elif isinstance(array, torch.Tensor):
        return array
    else:
        raise TypeError("Input array must be of numpy or torch type")


class DictDataset(Dataset):
    r"""Dictionary pytorch dataset

    Wrapper to output a dictionary of tensors at the __getitem__ call of a dataset.
    Some mapping, filtering and concatenation methods are implemented to imitate
    tensorflow datasets features.

    Args:
        dataset (Dataset): Dataset to wrap.
        column_names (List[str]): Column names describing the output tensors.
    """

    def __init__(
        self, dataset: Dataset, column_names: List[str] = ["input", "label"]
    ) -> None:
        self._dataset = dataset
        self._raw_columns = column_names
        self.map_fns = []
        self._check_init_args()

    @property
    def column_names(self) -> list:
        """Get the list of columns in a dict-based item from the dataset.

        Returns:
            list: column names of the dataset.
        """
        dummy_item = self[0]
        return list(dummy_item.keys())

    def _check_init_args(self) -> None:
        """Check validity of dataset and column names provided at init"""
        dummy_item = self._dataset[0]
        assert isinstance(
            dummy_item, (tuple, dict, list, torch.Tensor)
        ), "Dataset to be wrapped needs to return tuple, list or dict of tensors"
        if isinstance(dummy_item, torch.Tensor):
            dummy_item = [dummy_item]
        assert len(dummy_item) == len(
            self._raw_columns
        ), "Length mismatch between dataset item and provided column names"

    def __getitem__(self, index: int) -> dict:
        """Return a dictionary of tensors corresponding to a specific index.

        Args:
            index (int): the index of the item to retrieve.

        Returns:
            dict: tensors for the item at the specific index.
        """
        item = self._dataset[index]

        # convert item to a list / tuple of tensors
        if isinstance(item, torch.Tensor):
            tensors = [item]
        elif isinstance(item, dict):
            tensors = list(item.values())
        else:
            tensors = item

        # build output dictionary
        output_dict = {key: tensor for (key, tensor) in zip(self._raw_columns, tensors)}

        # apply map functions
        for map_fn in self.map_fns:
            output_dict = map_fn(output_dict)
        return output_dict

    def map(self, map_fn: Callable, inplace: bool = False) -> "DictDataset":
        """Map the dataset

        Args:
            map_fn (Callable): map function f: dict -> dict
            inplace (bool): if False, applies the mapping on a copied version of\
                the dataset. Defaults to False.

        Returns:
            DictDataset: Mapped dataset
        """
        dataset = self if inplace else copy.deepcopy(self)
        dataset.map_fns.append(map_fn)
        return dataset

    def filter(self, filter_fn: Callable, inplace: bool = False) -> "DictDataset":
        """Filter the dataset

        Args:
            filter_fn (Callable): filter function f: dict -> bool
            inplace (bool): if False, applies the filtering on a copied version of\
                the dataset. Defaults to False.

        Returns:
            DictDataset: Filtered dataset
        """
        indices = [i for i in range(len(self)) if filter_fn(self[i])]
        dataset = self if inplace else copy.deepcopy(self)
        dataset._dataset = Subset(self._dataset, indices)
        return dataset

    def __len__(self) -> int:
        """Return the length of the dataset, i.e. the number of items.

        Returns:
            int: length of the dataset.
        """
        return len(self._dataset)


class TorchDataHandler(DataHandler):
    """
    Class to manage torch DictDataset. The aim is to provide a simple interface
    for working with torch datasets and manage them without having to use
    torch syntax.
    """

    def __init__(self) -> None:
        """
        Initializes the TorchDataHandler.

        Attributes:
            backend (str): The backend framework used, set to "torch".
            channel_order (str): The channel order format, set to "channels_first".
        """
        super().__init__()
        self.backend = "torch"
        self.channel_order = "channels_first"

    @staticmethod
    def _default_target_transform(y: Any) -> torch.Tensor:
        """Format int or float item target as a torch tensor

        Args:
            y (Any): dataset item target

        Returns:
            torch.Tensor: target as a torch.Tensor
        """
        return torch.tensor(y) if isinstance(y, (float, int)) else y

    @classmethod
    def load_dataset(
        cls,
        dataset_id: Union[Dataset, ItemType, str],
        columns: Optional[list] = None,
        hub: Optional[str] = "torchvision",
        load_kwargs: dict = {},
    ) -> DictDataset:
        """Load a dataset from various sources

        Args:
            dataset_id (Union[Dataset, ItemType, str]): dataset identification.
                Can be the name of a dataset from torchvision, a torch Dataset,
                or a tuple/dict of np.ndarrays/torch tensors.
            columns (list, optional): Column names. If None, assigned as "input_i"
                for i-th feature. Defaults to None.
            hub (str, optional): Hub to load the dataset from when dataset_id is a
                string, either "torchvision" or "huggingface". Defaults to
                "torchvision".
            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.

        Returns:
            DictDataset: dataset
        """
        assert hub in {
            "torchvision",
            "huggingface",
        }, "hub must be either 'torchvision' or 'huggingface'"

        if isinstance(dataset_id, str):
            if hub == "torchvision":
                assert "root" in load_kwargs.keys()
                dataset = cls.load_from_torchvision(dataset_id, load_kwargs)
            elif hub == "huggingface":
                dataset = cls.load_from_huggingface(dataset_id, load_kwargs)
        elif isinstance(dataset_id, Dataset):
            dataset = cls.load_custom_dataset(dataset_id, columns)
        elif isinstance(dataset_id, get_args(ItemType)):
            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
        return dataset

    @staticmethod
    def load_dataset_from_arrays(
        dataset_id: ItemType,
        columns: Optional[list] = None,
    ) -> DictDataset:
        """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.

        Args:
            dataset_id (ItemType):
                numpy / torch array(s) to load.
            columns (list, optional): Column names to assign. If None,
                assigned as "input_i" for i-th feature. Defaults to None.

        Returns:
            DictDataset: dataset
        """
        # If dataset_id is an array
        if isinstance(dataset_id, get_args(TensorType)):
            tensors = (to_torch(dataset_id),)
            columns = columns or ["input"]

        # If dataset_id is a tuple of arrays
        elif isinstance(dataset_id, tuple):
            len_elem = len(dataset_id)
            if columns is None:
                if len_elem == 2:
                    columns = ["input", "label"]
                else:
                    columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                    print(
                        "Loading torch.utils.data.Dataset with elems as dicts, "
                        'assigning "input_i" key to the i-th tuple dimension and'
                        ' "label" key to the last tuple dimension.'
                    )
            assert len(columns) == len(dataset_id)
            tensors = tuple(to_torch(array) for array in dataset_id)

        # If dataset_id is a dictionary of arrays
        elif isinstance(dataset_id, dict):
            columns = columns or list(dataset_id.keys())
            assert len(columns) == len(dataset_id)
            tensors = tuple(to_torch(array) for array in dataset_id.values())

        # create torch dictionary dataset from tensors tuple and columns
        dataset = DictDataset(TensorDataset(*tensors), columns)
        return dataset

    @staticmethod
    def load_custom_dataset(
        dataset_id: Dataset, columns: Optional[list] = None
    ) -> DictDataset:
        """Load a custom Dataset by ensuring it has the correct format (dict-based)

        Args:
            dataset_id (Dataset): Dataset
            columns (list, optional): Column names to use for elements if dataset_id is
                tuple based. If None, assigned as "input_i"
                for i-th column. Defaults to None.

        Returns:
            DictDataset
        """
        # If dataset_id is a tuple based Dataset, convert it to a DictDataset
        dummy_item = dataset_id[0]
        if not isinstance(dummy_item, dict):
            assert isinstance(
                dummy_item, (Tuple, torch.Tensor)
            ), "Custom dataset should be either dictionary based or tuple based"
            if columns is None:
                len_elem = len(dummy_item)
                if len_elem == 2:
                    columns = ["input", "label"]
                else:
                    columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                print(
                    "Feature name not found, assigning 'input_i' "
                    "key to the i-th tensor and 'label' key to the last"
                )
            dataset_id = DictDataset(dataset_id, columns)

        dataset = dataset_id
        return dataset

    @classmethod
    def load_from_huggingface(
        cls,
        dataset_id: str,
        load_kwargs: dict = {},
    ) -> DictDataset:
        """Load a Dataset from the Hugging Face datasets catalog

        Args:
            dataset_id (str): Identifier of the dataset
            load_kwargs (dict): Loading kwargs to add to the initialization
                of the dataset.

        Returns:
            DictDataset: dataset
        """
        if "transform" in load_kwargs.keys():
            transform = load_kwargs["transform"]
            load_kwargs.pop("transform")
        else:

            def transform(x):
                return x

        dataset = hf_load_dataset(dataset_id, **load_kwargs)

        def transform_full(examples):
            examples = transform(examples)
            examples["label"] = [
                cls._default_target_transform(example) for example in examples["label"]
            ]
            return examples

        dataset = dataset.with_transform(transform_full)
        return dataset  # HF datasets are already dict-based

    @classmethod
    def load_from_torchvision(
        cls,
        dataset_id: str,
        load_kwargs: dict = {},
    ) -> DictDataset:
        """Load a Dataset from the torchvision datasets catalog

        Args:
            dataset_id (str): Identifier of the dataset
            load_kwargs (dict): Loading kwargs passed to the dataset initialization.
                Expected keys include:
                root (str): Root directory of the dataset.
                transform (Callable, optional): Transform function to apply to the
                    input. Defaults to torchvision.transforms.PILToTensor().
                target_transform (Callable, optional): Transform function to apply
                    to the target. Defaults to cls._default_target_transform.
                download (bool): If True, downloads the dataset from the internet and
                    puts it in the root directory. If the dataset is already
                    downloaded, it is not downloaded again. Defaults to False.

        Returns:
            DictDataset: dataset
        """
        assert (
            dataset_id in torchvision.datasets.__all__
        ), "Dataset not available on torchvision datasets catalog"

        if "transform" not in load_kwargs.keys():
            load_kwargs["transform"] = torchvision.transforms.PILToTensor()
        if "target_transform" not in load_kwargs.keys():
            load_kwargs["target_transform"] = cls._default_target_transform

        dataset = getattr(torchvision.datasets, dataset_id)(
            **load_kwargs,
        )
        return cls.load_custom_dataset(dataset)

    @staticmethod
    @dict_only_ds
    def get_ds_column_names(dataset: DictDataset) -> list:
        """Get the column names of a DictDataset

        Args:
            dataset (DictDataset): Dataset to get the column names from

        Returns:
            list: List of column names
        """
        return dataset.column_names

    @staticmethod
    def map_ds(
        dataset: DictDataset,
        map_fn: Callable,
    ) -> DictDataset:
        """Map a function over a DictDataset

        Args:
            dataset (DictDataset): Dataset to map the function to
            map_fn (Callable): Function to map

        Returns:
            DictDataset: Mapped dataset
        """
        return dataset.map(map_fn)

    @staticmethod
    @dict_only_ds
    def filter_by_value(
        dataset: DictDataset,
        column_name: str,
        values: list,
        excluded: bool = False,
    ) -> DictDataset:
        """Filter the dataset by checking if the value of a column is in `values`

        !!! note
            This function can be time-consuming since it needs to iterate
            over the whole dataset.

        Args:
            dataset (DictDataset): Dataset to filter
            column_name (str): Column to filter the dataset with
            values (list): Column values to keep
            excluded (bool, optional): To keep (False) or exclude (True) the samples
                with column values included in `values`. Defaults to False.

        Returns:
            DictDataset: Filtered dataset
        """
        if len(dataset[0][column_name].shape) > 0:
            value_dim = dataset[0][column_name].shape[-1]
            values = [
                F.one_hot(torch.tensor(value).long(), value_dim) for value in values
            ]

        def filter_fn(x):
            keep = any([torch.all(x[column_name] == v) for v in values])
            return keep if not excluded else not keep

        filtered_dataset = dataset.filter(filter_fn)
        return filtered_dataset

    @classmethod
    def prepare(
        cls,
        dataset: DictDataset,
        batch_size: int,
        preprocess_fn: Optional[Callable] = None,
        augment_fn: Optional[Callable] = None,
        columns: Optional[list] = None,
        shuffle: bool = False,
        dict_based_fns: bool = True,
        return_tuple: bool = True,
        num_workers: int = 0,
    ) -> DataLoader:
        """Prepare a DataLoader for training

        Args:
            dataset (DictDataset): Dataset to prepare
            batch_size (int): Batch size
            preprocess_fn (Callable, optional): Preprocessing function to apply to
                the dataset. Defaults to None.
            augment_fn (Callable, optional): Augment function to be used (when the
                returned dataset is to be used for training). Defaults to None.
            columns (list, optional): List of column names corresponding to the columns
                that will be returned. Keep all features if None. Defaults to None.
            shuffle (bool, optional): To shuffle the returned dataset or not.
                Defaults to False.
            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
                based (if True) or as tuple based (if False). Defaults to True.
            return_tuple (bool, optional): Whether to return each dataset item
                as a tuple. Defaults to True.
            num_workers (int, optional): Number of workers to use for the dataloader.
                Defaults to 0.

        Returns:
            DataLoader: dataloader
        """
        columns = columns or cls.get_ds_column_names(dataset)

        def collate_fn(batch: List[dict]):
            if dict_based_fns:
                # preprocess + DA: List[dict] -> List[dict]
                preprocess_func = preprocess_fn or (lambda x: x)
                augment_func = augment_fn or (lambda x: x)
                batch = [augment_func(preprocess_func(d)) for d in batch]
                # to dict of batches
                if return_tuple:
                    return tuple(
                        default_collate([d[key] for d in batch]) for key in columns
                    )
                return {
                    key: default_collate([d[key] for d in batch]) for key in columns
                }
            else:
                # preprocess + DA: List[dict] -> List[tuple]
                preprocess_func = preprocess_fn or (lambda *x: x)
                augment_func = augment_fn or (lambda *x: x)
                batch = [
                    augment_func(*preprocess_func(*tuple(d[key] for key in columns)))
                    for d in batch
                ]
                # to tuple of batches
                return default_collate(batch)

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        return loader

    @staticmethod
    def get_item_length(dataset: Dataset) -> int:
        """Number of elements in a dataset item

        Args:
            dataset (Dataset): Dataset

        Returns:
            int: Item length
        """
        return len(dataset[0])

    @staticmethod
    def get_dataset_length(dataset: Dataset) -> int:
        """Number of items in a dataset

        Args:
            dataset (Dataset): Dataset

        Returns:
            int: Dataset length
        """
        return len(dataset)

    @staticmethod
    def get_column_elements_shape(
        dataset: Dataset, column_name: Union[str, int]
    ) -> tuple:
        """Get the shape of the elements of a column of dataset identified by
        column_name

        Args:
            dataset (Dataset): a Dataset
            column_name (Union[str, int]): The column name to get
                the element shape from.

        Returns:
            tuple: the shape of an element from column_name
        """
        return tuple(dataset[0][column_name].shape)

    @staticmethod
    def get_columns_shapes(dataset: Dataset) -> dict:
        """Get the shapes of the elements of all columns of a dataset

        Args:
            dataset (Dataset): a Dataset

        Returns:
            dict: dictionary of column names and their corresponding shape
        """
        shapes = {}
        for key in dataset.column_names:
            try:
                shapes[key] = tuple(dataset[0][key].shape)
            except AttributeError:
                pass
        return shapes

    @staticmethod
    def get_input_from_dataset_item(elem: ItemType) -> Any:
        """Get the tensor that is to be fed as input to a model from a dataset element.

        Args:
            elem (ItemType): dataset element to extract input from

        Returns:
            Any: Input tensor
        """
        if isinstance(elem, (tuple, list)):
            tensor = elem[0]
        elif isinstance(elem, dict):
            tensor = elem[list(elem.keys())[0]]
        else:
            tensor = elem
        return tensor

    @staticmethod
    def get_label_from_dataset_item(item: ItemType):
        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
        in the item tuple. If one-hot encoded, labels are converted to single value.

        Args:
            item (ItemType): dataset element to extract label from

        Returns:
            Any: Label tensor
        """
        label = item[1]  # labels must be at index 1 in the batch tuple
        # If labels are one-hot encoded, take the argmax
        if len(label.shape) > 1 and label.shape[1] > 1:
            label = label.view(label.size(0), -1)
            label = torch.argmax(label, dim=1)
        # If labels are in two dimensions, squeeze them
        if len(label.shape) > 1:
            label = label.view([label.shape[0]])
        return label
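
As orientation for the file above, here is a minimal usage sketch of the DictDataset wrapper (illustrative only, not part of the package; the tensors x and y are made up, assuming oodeel 0.4.0 is installed):

import torch
from torch.utils.data import TensorDataset

from oodeel.datasets.torch_data_handler import DictDataset

# Wrap a plain tuple-based dataset; items become {"input": ..., "label": ...}
x = torch.randn(10, 3, 32, 32)
y = torch.randint(0, 2, (10,))
ds = DictDataset(TensorDataset(x, y), column_names=["input", "label"])

item = ds[0]
print(item["input"].shape, item["label"])  # torch.Size([3, 32, 32]) tensor(0 or 1)

# map() registers a dict -> dict function, applied lazily at __getitem__ time
ds_scaled = ds.map(lambda d: {**d, "input": d["input"] / 255.0})

# filter() eagerly scans the dataset once and keeps matching indices in a Subset
ds_zeros = ds.filter(lambda d: d["label"] == 0)
print(len(ds_zeros))

Note the asymmetry visible in the source: map is lazy, while filter iterates the whole dataset up front, which is why filter_by_value carries a time-consumption warning in its docstring.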
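
In the same spirit, a minimal sketch of the TorchDataHandler entry points (again illustrative, with made-up shapes): building a dict-based dataset from in-memory tensors, filtering it by label value, and batching it with prepare:

import torch

from oodeel.datasets.torch_data_handler import TorchDataHandler

handler = TorchDataHandler()

# Build a DictDataset from in-memory tensors ("input"/"label" columns)
x = torch.randn(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))
ds = handler.load_dataset_from_arrays((x, y))

# Keep only the items whose label is 0 or 1
ds_in = handler.filter_by_value(ds, "label", [0, 1])

# Batch the dataset; with the defaults, items come back as (input, label) tuples
loader = handler.prepare(ds_in, batch_size=16, shuffle=True)
inputs, labels = next(iter(loader))
print(inputs.shape, labels.shape)  # e.g. torch.Size([16, 3, 32, 32]) torch.Size([16])

With dict_based_fns=True and return_tuple=True (the defaults), the custom collate_fn collates each requested column separately and yields one batch tensor per column.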
oodeel/eval/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.