oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
@@ -0,0 +1,672 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import copy
24
+ from typing import get_args
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torchvision
30
+ from datasets import load_dataset as hf_load_dataset
31
+ from torch.utils.data import DataLoader
32
+ from torch.utils.data import Dataset
33
+ from torch.utils.data import Subset
34
+ from torch.utils.data import TensorDataset
35
+ from torch.utils.data.dataloader import default_collate
36
+
37
+ from ..types import Any
38
+ from ..types import Callable
39
+ from ..types import ItemType
40
+ from ..types import List
41
+ from ..types import Optional
42
+ from ..types import TensorType
43
+ from ..types import Tuple
44
+ from ..types import Union
45
+ from .data_handler import DataHandler
46
+
47
+
48
def dict_only_ds(ds_handling_method: Callable) -> Callable:
    """Decorator checking that a dataset is dict-based and that the requested
    column name(s) exist in it.

    The decorated function must have the signature
    ``function(dataset, *args, **kwargs)``, with ``column_name`` given either
    as a keyword argument or as ``args[0]`` when relevant.

    Args:
        ds_handling_method: method to decorate

    Returns:
        decorated method

    Raises:
        AssertionError: (at call time) if the first dataset item is not a dict,
            or if a requested column is not among ``dataset.column_names``.
    """
    # Local import so that the block does not depend on a file-level import.
    import functools

    # functools.wraps keeps the name / docstring of the decorated method.
    @functools.wraps(ds_handling_method)
    def wrapper(dataset: Dataset, *args, **kwargs):
        assert isinstance(
            dataset[0], dict
        ), "Dataset must be an instance of DictDataset"

        # column_name is looked up in kwargs first, then as first positional arg
        if "column_name" in kwargs:
            column_name = kwargs["column_name"]
        elif args:
            column_name = args[0]
        else:
            column_name = None

        # If column_name is provided, check that it is in the dataset column names
        if column_name is not None:
            requested = [column_name] if isinstance(column_name, str) else column_name
            # read the column names once rather than once per requested name
            available = dataset.column_names
            for name in requested:
                assert (
                    name in available
                ), f"The input dataset has no column named {name}"
        return ds_handling_method(dataset, *args, **kwargs)

    return wrapper
84
+
85
+
86
def to_torch(array: TensorType) -> torch.Tensor:
    """Turn a numpy array or a torch tensor into a torch tensor.

    Args:
        array (TensorType): numpy array or torch tensor

    Returns:
        torch.Tensor: the input itself when it is already a torch tensor,
            otherwise ``torch.Tensor(array)`` built from the numpy array.

    Raises:
        TypeError: if the input is neither a numpy array nor a torch tensor.
    """
    # torch tensors are returned untouched (same object)
    if isinstance(array, torch.Tensor):
        return array
    if isinstance(array, np.ndarray):
        return torch.Tensor(array)
    raise TypeError("Input array must be of numpy or torch type")
101
+
102
+
103
class DictDataset(Dataset):
    r"""Dictionary pytorch dataset

    Wrapper to output a dictionary of tensors at the __getitem__ call of a dataset.
    Mapping and filtering methods are implemented to imitate tensorflow datasets
    features.

    Args:
        dataset (Dataset): Dataset to wrap.
        column_names (List[str], optional): Column names describing the output
            tensors. Defaults to ["input", "label"].
    """

    def __init__(
        self, dataset: Dataset, column_names: Optional[List[str]] = None
    ) -> None:
        self._dataset = dataset
        # Store a private copy so that neither a shared default nor later
        # mutations of the caller's list can alter this dataset's columns.
        self._raw_columns = (
            ["input", "label"] if column_names is None else list(column_names)
        )
        # functions dict -> dict applied, in order, to every retrieved item
        self.map_fns = []
        self._check_init_args()

    @property
    def column_names(self) -> list:
        """Get the list of columns in a dict-based item from the dataset.

        The names are read from the keys of the first item, after the map
        functions have been applied.

        Returns:
            list: column names of the dataset.
        """
        dummy_item = self[0]
        return list(dummy_item.keys())

    def _check_init_args(self) -> None:
        """Check validity of dataset and column names provided at init

        Raises:
            AssertionError: if the first item of the wrapped dataset is not a
                tuple, list, dict or tensor, or if its length differs from the
                number of column names.
        """
        dummy_item = self._dataset[0]
        assert isinstance(
            dummy_item, (tuple, dict, list, torch.Tensor)
        ), "Dataset to be wrapped needs to return tuple, list or dict of tensors"
        # a bare tensor item counts as a single column
        if isinstance(dummy_item, torch.Tensor):
            dummy_item = [dummy_item]
        assert len(dummy_item) == len(
            self._raw_columns
        ), "Length mismatch between dataset item and provided column names"

    def __getitem__(self, index: int) -> dict:
        """Return a dictionary of tensors corresponding to a specific index.

        Args:
            index (int): the index of the item to retrieve.

        Returns:
            dict: tensors for the item at the specific index.
        """
        item = self._dataset[index]

        # convert item to a list / tuple of tensors
        if isinstance(item, torch.Tensor):
            tensors = [item]
        elif isinstance(item, dict):
            tensors = list(item.values())
        else:
            tensors = item

        # build output dictionary
        output_dict = dict(zip(self._raw_columns, tensors))

        # apply map functions
        for map_fn in self.map_fns:
            output_dict = map_fn(output_dict)
        return output_dict

    def map(self, map_fn: Callable, inplace: bool = False) -> "DictDataset":
        """Map the dataset

        The function is only registered here; it is applied lazily each time
        an item is retrieved.

        Args:
            map_fn (Callable): map function f: dict -> dict
            inplace (bool): if False, applies the mapping on a copied version of
                the dataset. Defaults to False.

        Returns:
            DictDataset: Mapped dataset
        """
        dataset = self if inplace else copy.deepcopy(self)
        dataset.map_fns.append(map_fn)
        return dataset

    def filter(self, filter_fn: Callable, inplace: bool = False) -> "DictDataset":
        """Filter the dataset

        Every item is retrieved once (map functions included) to evaluate the
        filter function.

        Args:
            filter_fn (Callable): filter function f: dict -> bool
            inplace (bool): if False, applies the filtering on a copied version of
                the dataset. Defaults to False.

        Returns:
            DictDataset: Filtered dataset
        """
        indices = [i for i in range(len(self)) if filter_fn(self[i])]
        if inplace:
            dataset = self
        else:
            # The wrapped dataset is replaced by a Subset right below, so
            # deep-copying it would be wasted work: pre-registering it in the
            # memo makes deepcopy reuse it while still copying everything else.
            dataset = copy.deepcopy(
                self, memo={id(self._dataset): self._dataset}
            )
        dataset._dataset = Subset(self._dataset, indices)
        return dataset

    def __len__(self) -> int:
        """Return the length of the dataset, i.e. the number of items.

        Returns:
            int: length of the dataset.
        """
        return len(self._dataset)
210
+
211
+
212
class TorchDataHandler(DataHandler):
    """
    Class to manage torch DictDataset. The aim is to provide a simple interface
    for working with torch datasets and manage them without having to use
    torch syntax.
    """

    def __init__(self) -> None:
        """
        Initializes the TorchDataHandler.

        Attributes:
            backend (str): The backend framework used, set to "torch".
            channel_order (str): The channel order format, set to "channels_first".
        """

        super().__init__()
        self.backend = "torch"
        self.channel_order = "channels_first"

    @staticmethod
    def _default_target_transform(y: Any) -> torch.Tensor:
        """Format int or float item target as a torch tensor

        Args:
            y (Any): dataset item target

        Returns:
            torch.Tensor: target as a torch.Tensor when y is an int or a float,
                otherwise y unchanged.
        """
        return torch.tensor(y) if isinstance(y, (float, int)) else y

    # Declared as a classmethod (like the other loaders) since it receives
    # ``cls``; calls made through an instance keep working.
    @classmethod
    def load_dataset(
        cls,
        dataset_id: Union[Dataset, ItemType, str],
        columns: Optional[list] = None,
        hub: Optional[str] = "torchvision",
        load_kwargs: Optional[dict] = None,
    ) -> DictDataset:
        """Load dataset from different manners

        Args:
            dataset_id (Union[Dataset, ItemType, str]): dataset identification.
                Can be the name of a dataset from the selected hub, a torch
                Dataset, or a tuple/dict of np.ndarrays/torch tensors.
            columns (list, optional): Column names. If None, assigned as "input_i"
                for i-th feature. Defaults to None.
            hub (str, optional): Catalog to load from when dataset_id is a str,
                either "torchvision" or "huggingface". Defaults to "torchvision".
            load_kwargs (dict, optional): Additional loading kwargs. Must contain
                a "root" key when loading from torchvision. Defaults to None
                (no additional kwargs).

        Returns:
            DictDataset: dataset

        Raises:
            AssertionError: if hub is not supported, or if "root" is missing
                from load_kwargs when loading from torchvision.
            TypeError: if dataset_id is of an unsupported type.
        """
        load_kwargs = {} if load_kwargs is None else load_kwargs

        assert hub in {
            "torchvision",
            "huggingface",
        }, "hub must be either 'torchvision' or 'huggingface'"

        if isinstance(dataset_id, str):
            if hub == "torchvision":
                assert "root" in load_kwargs.keys()
                return cls.load_from_torchvision(dataset_id, load_kwargs)
            return cls.load_from_huggingface(dataset_id, load_kwargs)
        if isinstance(dataset_id, Dataset):
            return cls.load_custom_dataset(dataset_id, columns)
        if isinstance(dataset_id, get_args(ItemType)):
            return cls.load_dataset_from_arrays(dataset_id, columns)
        # Fail with an explicit message instead of an UnboundLocalError
        raise TypeError(
            "dataset_id must be a str, a torch Dataset, or an array / "
            f"tuple / dict of arrays, got {type(dataset_id).__name__}"
        )

    @staticmethod
    def load_dataset_from_arrays(
        dataset_id: ItemType,
        columns: Optional[list] = None,
    ) -> DictDataset:
        """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.

        Args:
            dataset_id (ItemType):
                numpy / torch array(s) to load.
            columns (list, optional): Column names to assign. If None,
                assigned as "input_i" for i-th feature. Defaults to None.

        Returns:
            DictDataset: dataset

        Raises:
            AssertionError: if the number of columns does not match the number
                of arrays.
            TypeError: if dataset_id is not an array, a tuple or a dict.
        """
        # If dataset_id is an array
        if isinstance(dataset_id, get_args(TensorType)):
            # One-element tuple: the whole array is a single column.
            # (tuple(tensor) would split it along its first dimension.)
            tensors = (to_torch(dataset_id),)
            columns = columns or ["input"]

        # If dataset_id is a tuple of arrays
        elif isinstance(dataset_id, tuple):
            len_elem = len(dataset_id)
            if columns is None:
                if len_elem == 2:
                    columns = ["input", "label"]
                else:
                    columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                    print(
                        "Loading torch.utils.data.Dataset with elems as dicts, "
                        'assigning "input_i" key to the i-th tuple dimension and'
                        ' "label" key to the last tuple dimension.'
                    )
            assert len(columns) == len(dataset_id)
            tensors = tuple(to_torch(array) for array in dataset_id)

        # If dataset_id is a dictionary of arrays
        elif isinstance(dataset_id, dict):
            columns = columns or list(dataset_id.keys())
            assert len(columns) == len(dataset_id)
            tensors = tuple(to_torch(array) for array in dataset_id.values())

        else:
            raise TypeError(
                "dataset_id must be an array, a tuple of arrays or a dict of "
                f"arrays, got {type(dataset_id).__name__}"
            )

        # create torch dictionary dataset from tensors tuple and columns
        return DictDataset(TensorDataset(*tensors), columns)

    @staticmethod
    def load_custom_dataset(
        dataset_id: Dataset, columns: Optional[list] = None
    ) -> DictDataset:
        """Load a custom Dataset by ensuring it has the correct format (dict-based)

        Args:
            dataset_id (Dataset): Dataset
            columns (list, optional): Column names to use for elements if dataset_id
                is tuple based. If None, assigned as "input_i" for i-th column and
                "label" for the last one (or "input" when items are single
                tensors). Defaults to None.

        Returns:
            DictDataset: the wrapped dataset, or dataset_id itself when its items
                are already dictionaries.

        Raises:
            AssertionError: if items are neither dicts, tuples nor tensors.
        """
        dummy_item = dataset_id[0]
        # Dict-based datasets are returned as they are
        if isinstance(dummy_item, dict):
            return dataset_id

        # If dataset_id is a tuple based Dataset, convert it to a DictDataset
        assert isinstance(
            dummy_item, (tuple, torch.Tensor)
        ), "Custom dataset should be either dictionary based or tuple based"
        if columns is None:
            if isinstance(dummy_item, torch.Tensor):
                # A bare tensor item is one single column; its len() is the
                # size of its first dimension, not a number of columns.
                columns = ["input"]
            elif len(dummy_item) == 2:
                columns = ["input", "label"]
            else:
                len_elem = len(dummy_item)
                columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                print(
                    "Feature name not found, assigning 'input_i' "
                    "key to the i-th tensor and 'label' key to the last"
                )
        return DictDataset(dataset_id, columns)

    @classmethod
    def load_from_huggingface(
        cls,
        dataset_id: str,
        load_kwargs: Optional[dict] = None,
    ) -> DictDataset:
        """Load a Dataset from the Hugging Face datasets catalog

        Args:
            dataset_id (str): Identifier of the dataset
            load_kwargs (dict, optional): Loading kwargs forwarded to the Hugging
                Face loader. An optional "transform" entry is not forwarded: it
                is applied on the fly to the retrieved examples instead.
                Defaults to None (no additional kwargs).

        Returns:
            DictDataset: dataset
        """
        # Work on a copy: the caller's dict must not lose its "transform" key
        load_kwargs = dict(load_kwargs or {})

        def identity(x):
            return x

        transform = load_kwargs.pop("transform", identity)

        dataset = hf_load_dataset(dataset_id, **load_kwargs)

        def transform_full(examples):
            examples = transform(examples)
            examples["label"] = [
                cls._default_target_transform(example) for example in examples["label"]
            ]
            return examples

        dataset = dataset.with_transform(transform_full)
        return dataset  # HF datasets are already dict-based

    @classmethod
    def load_from_torchvision(
        cls,
        dataset_id: str,
        load_kwargs: Optional[dict] = None,
    ) -> DictDataset:
        """Load a Dataset from the torchvision datasets catalog

        Args:
            dataset_id (str): Identifier of the dataset
            load_kwargs (dict, optional): Loading kwargs passed to the constructor
                of the torchvision dataset (e.g. root, download, ...). When
                absent, "transform" is set to torchvision's PILToTensor and
                "target_transform" to the default target transform of this
                class. Defaults to None (no additional kwargs).

        Returns:
            DictDataset: dataset

        Raises:
            AssertionError: if dataset_id is not in the torchvision catalog.
        """
        assert (
            dataset_id in torchvision.datasets.__all__
        ), "Dataset not available on torchvision datasets catalog"

        # Work on a copy so that the defaults added below do not leak into
        # the caller's dict.
        load_kwargs = dict(load_kwargs or {})
        if "transform" not in load_kwargs:
            load_kwargs["transform"] = torchvision.transforms.PILToTensor()
        if "target_transform" not in load_kwargs:
            load_kwargs["target_transform"] = cls._default_target_transform

        dataset = getattr(torchvision.datasets, dataset_id)(
            **load_kwargs,
        )
        return cls.load_custom_dataset(dataset)

    @staticmethod
    @dict_only_ds
    def get_ds_column_names(dataset: DictDataset) -> list:
        """Get the column names of a DictDataset

        Args:
            dataset (DictDataset): Dataset to get the column names from

        Returns:
            list: List of column names
        """
        return dataset.column_names

    @staticmethod
    def map_ds(
        dataset: DictDataset,
        map_fn: Callable,
    ) -> DictDataset:
        """Map a function to a DictDataset

        Args:
            dataset (DictDataset): Dataset to map the function to
            map_fn (Callable): Function to map

        Returns:
            DictDataset: Mapped dataset
        """
        return dataset.map(map_fn)

    @staticmethod
    @dict_only_ds
    def filter_by_value(
        dataset: DictDataset,
        column_name: str,
        values: list,
        excluded: bool = False,
    ) -> DictDataset:
        """Filter the dataset by checking if the value of a column is in `values`

        !!! note
            This function can be a bit of time consuming since it needs to iterate
            over the whole dataset.

        Args:
            dataset (DictDataset): Dataset to filter
            column_name (str): Column to filter the dataset with
            values (list): Column values to keep
            excluded (bool, optional): To keep (False) or exclude (True) the samples
                with column values included in Values. Defaults to False.

        Returns:
            DictDataset: Filtered dataset
        """
        # When the column elements are not scalars, the given values are one-hot
        # encoded over the last dimension before being compared.
        # NOTE(review): presumably such columns hold one-hot labels - confirm.
        sample_value = dataset[0][column_name]
        if len(sample_value.shape) > 0:
            value_dim = sample_value.shape[-1]
            values = [
                F.one_hot(torch.tensor(value).long(), value_dim) for value in values
            ]

        def filter_fn(x):
            keep = any(torch.all(x[column_name] == v) for v in values)
            return keep if not excluded else not keep

        return dataset.filter(filter_fn)

    @classmethod
    def prepare(
        cls,
        dataset: DictDataset,
        batch_size: int,
        preprocess_fn: Optional[Callable] = None,
        augment_fn: Optional[Callable] = None,
        columns: Optional[list] = None,
        shuffle: bool = False,
        dict_based_fns: bool = True,
        return_tuple: bool = True,
        num_workers: int = 0,
    ) -> DataLoader:
        """Prepare a DataLoader for training

        Args:
            dataset (DictDataset): Dataset to prepare
            batch_size (int): Batch size
            preprocess_fn (Callable, optional): Preprocessing function to apply to
                the dataset. Defaults to None.
            augment_fn (Callable, optional): Augment function to be used (when the
                returned dataset is to be used for training). Defaults to None.
            columns (list, optional): List of column names corresponding to the
                columns that will be returned. Keep all features if None.
                Defaults to None.
            shuffle (bool, optional): To shuffle the returned dataset or not.
                Defaults to False.
            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
                based (if True) or as tuple based (if False). Defaults to True.
            return_tuple (bool, optional): Whether to return each dataset item
                as a tuple. Only read when dict_based_fns is True; tuple based
                functions always yield tuples. Defaults to True.
            num_workers (int, optional): Number of workers to use for the
                dataloader. Defaults to 0.

        Returns:
            DataLoader: dataloader
        """
        columns = columns or cls.get_ds_column_names(dataset)

        # Resolve the (possibly identity) functions once, not once per batch
        if dict_based_fns:
            preprocess_func = preprocess_fn or (lambda x: x)
            augment_func = augment_fn or (lambda x: x)
        else:
            preprocess_func = preprocess_fn or (lambda *x: x)
            augment_func = augment_fn or (lambda *x: x)

        def collate_fn(batch: List[dict]):
            if dict_based_fns:
                # preprocess + DA: List[dict] -> List[dict]
                batch = [augment_func(preprocess_func(d)) for d in batch]
                # to tuple / dict of batches
                if return_tuple:
                    return tuple(
                        default_collate([d[key] for d in batch]) for key in columns
                    )
                return {
                    key: default_collate([d[key] for d in batch]) for key in columns
                }
            # preprocess + DA: List[dict] -> List[tuple]
            batch = [
                augment_func(*preprocess_func(*tuple(d[key] for key in columns)))
                for d in batch
            ]
            # to tuple of batches
            return default_collate(batch)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )

    @staticmethod
    def get_item_length(dataset: Dataset) -> int:
        """Number of elements in a dataset item

        Args:
            dataset (DictDataset): Dataset

        Returns:
            int: Item length
        """
        return len(dataset[0])

    @staticmethod
    def get_dataset_length(dataset: Dataset) -> int:
        """Number of items in a dataset

        Args:
            dataset (DictDataset): Dataset

        Returns:
            int: Dataset length
        """
        return len(dataset)

    @staticmethod
    def get_column_elements_shape(
        dataset: Dataset, column_name: Union[str, int]
    ) -> tuple:
        """Get the shape of the elements of a column of dataset identified by
        column_name

        Args:
            dataset (Dataset): a Dataset
            column_name (Union[str, int]): The column name to get
                the element shape from.

        Returns:
            tuple: the shape of an element from column_name
        """
        return tuple(dataset[0][column_name].shape)

    @staticmethod
    def get_columns_shapes(dataset: Dataset) -> dict:
        """Get the shapes of the elements of all columns of a dataset

        Columns whose elements have no ``shape`` attribute are left out.

        Args:
            dataset (Dataset): a Dataset

        Returns:
            dict: dictionary of column names and their corresponding shape
        """
        # fetch the first item once instead of once per column
        first_item = dataset[0]
        shapes = {}
        for key in dataset.column_names:
            try:
                shapes[key] = tuple(first_item[key].shape)
            except AttributeError:
                # elements without a shape are deliberately skipped
                pass
        return shapes

    @staticmethod
    def get_input_from_dataset_item(elem: ItemType) -> Any:
        """Get the tensor that is to be fed as input to a model from a dataset
        element.

        Args:
            elem (ItemType): dataset element to extract input from

        Returns:
            Any: Input tensor, i.e. the first element of a tuple / list, the
                value of the first key of a dict, or elem itself otherwise.
        """
        if isinstance(elem, (tuple, list)):
            return elem[0]
        if isinstance(elem, dict):
            return elem[list(elem.keys())[0]]
        return elem

    @staticmethod
    def get_label_from_dataset_item(item: ItemType):
        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
        in the item tuple. If one-hot encoded, labels are converted to single value.

        Args:
            item (ItemType): dataset element to extract label from

        Returns:
            Any: Label tensor
        """
        label = item[1]  # labels must be at index 1 in the batch tuple
        # If labels are one-hot encoded, take the argmax
        if len(label.shape) > 1 and label.shape[1] > 1:
            label = label.view(label.size(0), -1)
            label = torch.argmax(label, dim=1)
        # If labels are in two dimensions, squeeze them
        if len(label.shape) > 1:
            label = label.view([label.shape[0]])
        return label
@@ -0,0 +1,22 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.