oodeel-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py
@@ -0,0 +1,769 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import copy
+ from typing import get_args
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from torch.utils.data import ConcatDataset
+ from torch.utils.data import DataLoader
+ from torch.utils.data import Dataset
+ from torch.utils.data import Subset
+ from torch.utils.data import TensorDataset
+ from torch.utils.data.dataloader import default_collate
+
+ from ...types import Any
+ from ...types import Callable
+ from ...types import ItemType
+ from ...types import List
+ from ...types import Optional
+ from ...types import TensorType
+ from ...types import Tuple
+ from ...types import Union
+ from .DEPRECATED_data_handler import DataHandler
+
+
+ def dict_only_ds(ds_handling_method: Callable) -> Callable:
+     """Decorator to ensure that the dataset is a dict dataset and that the input key
+     matches one of the feature keys. The signature of decorated functions
+     must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+     args[0] when relevant.
+
+     Args:
+         ds_handling_method: method to decorate
+
+     Returns:
+         decorated method
+     """
+
+     def wrapper(dataset: Dataset, *args, **kwargs):
+         assert isinstance(
+             dataset, DictDataset
+         ), "Dataset must be an instance of DictDataset"
+
+         if "feature_key" in kwargs:
+             feature_key = kwargs["feature_key"]
+         elif len(args) > 0:
+             feature_key = args[0]
+
+         # If feature_key is provided, check that it is in the dataset feature keys
+         if (len(args) > 0) or ("feature_key" in kwargs):
+             if isinstance(feature_key, str):
+                 feature_key = [feature_key]
+             for key in feature_key:
+                 assert (
+                     key in dataset.output_keys
+                 ), f"The input dataset has no feature named {key}"
+         return ds_handling_method(dataset, *args, **kwargs)
+
+     return wrapper
+
+
+ def to_torch(array: TensorType) -> torch.Tensor:
+     """Convert an array into a torch Tensor
+
+     Args:
+         array (TensorType): array to convert
+
+     Returns:
+         torch.Tensor: converted array
+     """
+     if isinstance(array, np.ndarray):
+         return torch.Tensor(array)
+     elif isinstance(array, torch.Tensor):
+         return array
+     else:
+         raise TypeError("Input array must be of numpy or torch type")
+
+
+ class DictDataset(Dataset):
+     r"""Dictionary pytorch dataset
+
+     Wrapper to output a dictionary of tensors at the __getitem__ call of a dataset.
+     Some mapping, filtering and concatenation methods are implemented to imitate
+     tensorflow datasets features.
+
+     Args:
+         dataset (Dataset): Dataset to wrap.
+         output_keys (List[str]): Keys describing the output tensors.
+     """
+
+     def __init__(
+         self, dataset: Dataset, output_keys: List[str] = ["input", "label"]
+     ) -> None:
+         self._dataset = dataset
+         self._raw_output_keys = output_keys
+         self.map_fns = []
+         self._check_init_args()
+
+     @property
+     def output_keys(self) -> list:
+         """Get the list of keys in a dict-based item from the dataset.
+
+         Returns:
+             list: feature keys of the dataset.
+         """
+         dummy_item = self[0]
+         return list(dummy_item.keys())
+
+     @property
+     def output_shapes(self) -> list:
+         """Get a list of the tensor shapes in an item from the dataset.
+
+         Returns:
+             list: tensor shapes of a dataset item.
+         """
+         dummy_item = self[0]
+         return [dummy_item[key].shape for key in self.output_keys]
+
+     def _check_init_args(self) -> None:
+         """Check validity of dataset and output keys provided at init"""
+         dummy_item = self._dataset[0]
+         assert isinstance(
+             dummy_item, (tuple, dict, list, torch.Tensor)
+         ), "Dataset to be wrapped needs to return a tuple, list or dict of tensors"
+         if isinstance(dummy_item, torch.Tensor):
+             dummy_item = [dummy_item]
+         assert len(dummy_item) == len(
+             self._raw_output_keys
+         ), "Length mismatch between dataset item and provided keys"
+
+     def __getitem__(self, index: int) -> dict:
+         """Return a dictionary of tensors corresponding to a specific index.
+
+         Args:
+             index (int): the index of the item to retrieve.
+
+         Returns:
+             dict: tensors for the item at the specific index.
+         """
+         item = self._dataset[index]
+
+         # convert item to a list / tuple of tensors
+         if isinstance(item, torch.Tensor):
+             tensors = [item]
+         elif isinstance(item, dict):
+             tensors = list(item.values())
+         else:
+             tensors = item
+
+         # build output dictionary
+         output_dict = {
+             key: tensor for (key, tensor) in zip(self._raw_output_keys, tensors)
+         }
+
+         # apply map functions
+         for map_fn in self.map_fns:
+             output_dict = map_fn(output_dict)
+         return output_dict
+
+     def map(self, map_fn: Callable, inplace: bool = False) -> "DictDataset":
+         """Map the dataset
+
+         Args:
+             map_fn (Callable): map function f: dict -> dict
+             inplace (bool): if False, applies the mapping on a copied version of
+                 the dataset. Defaults to False.
+
+         Returns:
+             DictDataset: Mapped dataset
+         """
+         dataset = self if inplace else copy.deepcopy(self)
+         dataset.map_fns.append(map_fn)
+         return dataset
+
+     def filter(self, filter_fn: Callable, inplace: bool = False) -> "DictDataset":
+         """Filter the dataset
+
+         Args:
+             filter_fn (Callable): filter function f: dict -> bool
+             inplace (bool): if False, applies the filtering on a copied version of
+                 the dataset. Defaults to False.
+
+         Returns:
+             DictDataset: Filtered dataset
+         """
+         indices = [i for i in range(len(self)) if filter_fn(self[i])]
+         dataset = self if inplace else copy.deepcopy(self)
+         dataset._dataset = Subset(self._dataset, indices)
+         return dataset
+
+     def concatenate(
+         self, other_dataset: Dataset, inplace: bool = False
+     ) -> "DictDataset":
+         """Concatenate with another dataset
+
+         Args:
+             other_dataset (DictDataset): Dataset to concatenate with
+             inplace (bool): if False, applies the concatenation on a copied version
+                 of the dataset. Defaults to False.
+
+         Returns:
+             DictDataset: Concatenated dataset
+         """
+         assert isinstance(
+             other_dataset, DictDataset
+         ), "Second dataset should be an instance of DictDataset"
+         assert (
+             self.output_keys == other_dataset.output_keys
+         ), "Incompatible dataset elements (different dict keys)"
+         if inplace:
+             dataset_copy = copy.deepcopy(self)
+             self._raw_output_keys = self.output_keys
+             self.map_fns = []
+             self._dataset = ConcatDataset([dataset_copy, other_dataset])
+             dataset = self
+         else:
+             dataset = DictDataset(
+                 ConcatDataset([self, other_dataset]), self.output_keys
+             )
+         return dataset
+
+     def __len__(self) -> int:
+         """Return the length of the dataset, i.e. the number of items.
+
+         Returns:
+             int: length of the dataset.
+         """
+         return len(self._dataset)
+
+
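A minimal usage sketch of DictDataset (illustrative, not content of the wheel; the shapes and lambdas are arbitrary): wrap a tuple-based TensorDataset, then chain the filter/map helpers, which return copies rather than mutating in place.

    import torch
    from torch.utils.data import TensorDataset

    base = TensorDataset(torch.randn(10, 3), torch.arange(10))
    ds = DictDataset(base, output_keys=["input", "label"])
    evens = ds.filter(lambda item: item["label"] % 2 == 0)  # keeps 5 items
    doubled = evens.map(lambda item: {**item, "input": item["input"] * 2})
    print(doubled[0]["input"].shape)  # torch.Size([3])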
+ class TorchDataHandler(DataHandler):
+     """
+     Class to manage torch DictDataset. The aim is to provide a simple interface
+     for working with torch datasets and to manage them without having to use
+     torch syntax.
+     """
+
+     @staticmethod
+     def _default_target_transform(y: Any) -> torch.Tensor:
+         """Format an int or float item target as a torch tensor
+
+         Args:
+             y (Any): dataset item target
+
+         Returns:
+             torch.Tensor: target as a torch.Tensor
+         """
+         return torch.tensor(y) if isinstance(y, (float, int)) else y
+
+     DEFAULT_TRANSFORM = torchvision.transforms.PILToTensor()
+     DEFAULT_TARGET_TRANSFORM = _default_target_transform.__func__
+
+     @classmethod
+     def load_dataset(
+         cls,
+         dataset_id: Union[Dataset, ItemType, str],
+         keys: Optional[list] = None,
+         load_kwargs: dict = {},
+     ) -> DictDataset:
+         """Load a dataset from various sources
+
+         Args:
+             dataset_id (Union[Dataset, ItemType, str]): dataset identification
+             keys (list, optional): Feature keys. If None, assigned as "input_i"
+                 for the i-th feature. Defaults to None.
+             load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+
+         Returns:
+             DictDataset: dataset
+         """
+         if isinstance(dataset_id, str):
+             assert "root" in load_kwargs.keys()
+             dataset = cls.load_from_torchvision(dataset_id, **load_kwargs)
+         elif isinstance(dataset_id, Dataset):
+             dataset = cls.load_custom_dataset(dataset_id, keys)
+         elif isinstance(dataset_id, get_args(ItemType)):
+             dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+         return dataset
+
+     @staticmethod
+     def load_dataset_from_arrays(
+         dataset_id: ItemType,
+         keys: Optional[list] = None,
+     ) -> DictDataset:
+         """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.
+
+         Args:
+             dataset_id (ItemType):
+                 numpy / torch array(s) to load.
+             keys (list, optional): Feature keys. If None, assigned as "input_i"
+                 for the i-th feature. Defaults to None.
+
+         Returns:
+             DictDataset: dataset
+         """
+         # If dataset_id is an array
+         if isinstance(dataset_id, get_args(TensorType)):
+             tensors = (to_torch(dataset_id),)
+             output_keys = keys or ["input"]
+
+         # If dataset_id is a tuple of arrays
+         elif isinstance(dataset_id, tuple):
+             len_elem = len(dataset_id)
+             output_keys = keys
+             if output_keys is None:
+                 if len_elem == 2:
+                     output_keys = ["input", "label"]
+                 else:
+                     output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
+                         "label"
+                     ]
+                     print(
+                         "Loading torch.utils.data.Dataset with elems as dicts, "
+                         'assigning "input_i" key to the i-th tuple dimension and'
+                         ' "label" key to the last tuple dimension.'
+                     )
+             assert len(output_keys) == len(dataset_id)
+             tensors = tuple(to_torch(array) for array in dataset_id)
+
+         # If dataset_id is a dictionary of arrays
+         elif isinstance(dataset_id, dict):
+             output_keys = keys or list(dataset_id.keys())
+             assert len(output_keys) == len(dataset_id)
+             tensors = tuple(to_torch(array) for array in dataset_id.values())
+
+         # create torch dictionary dataset from tensors tuple and keys
+         dataset = DictDataset(TensorDataset(*tensors), output_keys)
+         return dataset
+
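A sketch of the array-loading path (illustrative; the array shapes are arbitrary). A two-element tuple gets the default "input"/"label" keys:

    import numpy as np

    x = np.random.rand(100, 3, 32, 32).astype("float32")
    y = np.random.randint(0, 10, size=100)
    ds = TorchDataHandler.load_dataset_from_arrays((x, y))
    item = ds[0]  # dict with keys "input" and "label"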
+     @staticmethod
+     def load_custom_dataset(
+         dataset_id: Dataset, keys: Optional[list] = None
+     ) -> DictDataset:
+         """Load a custom Dataset by ensuring it has the correct format (dict-based)
+
+         Args:
+             dataset_id (Dataset): Dataset
+             keys (list, optional): Keys to use for features if dataset_id is
+                 tuple based. Defaults to None.
+
+         Returns:
+             DictDataset
+         """
+         # If dataset_id is a tuple based Dataset, convert it to a DictDataset
+         dummy_item = dataset_id[0]
+         if not isinstance(dummy_item, dict):
+             assert isinstance(
+                 dummy_item, (Tuple, torch.Tensor)
+             ), "Custom dataset should be either dictionary based or tuple based"
+             output_keys = keys
+             if output_keys is None:
+                 len_elem = len(dummy_item)
+                 if len_elem == 2:
+                     output_keys = ["input", "label"]
+                 else:
+                     output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
+                         "label"
+                     ]
+                     print(
+                         "Feature name not found, assigning 'input_i' "
+                         "key to the i-th tensor and 'label' key to the last"
+                     )
+             dataset_id = DictDataset(dataset_id, output_keys)
+
+         dataset = dataset_id
+         return dataset
+
+     @classmethod
+     def load_from_torchvision(
+         cls,
+         dataset_id: str,
+         root: str,
+         transform: Callable = DEFAULT_TRANSFORM,
+         target_transform: Callable = DEFAULT_TARGET_TRANSFORM,
+         download: bool = False,
+         **load_kwargs,
+     ) -> DictDataset:
+         """Load a Dataset from the torchvision datasets catalog
+
+         Args:
+             dataset_id (str): Identifier of the dataset
+             root (str): Root directory of the dataset
+             transform (Callable, optional): Transform function to apply to the input.
+                 Defaults to DEFAULT_TRANSFORM.
+             target_transform (Callable, optional): Transform function to apply
+                 to the target. Defaults to DEFAULT_TARGET_TRANSFORM.
+             download (bool): If True, downloads the dataset from the internet and
+                 puts it in the root directory. If the dataset is already downloaded,
+                 it is not downloaded again. Defaults to False.
+             load_kwargs (dict): Loading kwargs to add to the initialization
+                 of the dataset.
+
+         Returns:
+             DictDataset: dataset
+         """
+         assert (
+             dataset_id in torchvision.datasets.__all__
+         ), "Dataset not available on torchvision datasets catalog"
+         dataset = getattr(torchvision.datasets, dataset_id)(
+             root=root,
+             download=download,
+             transform=transform,
+             target_transform=target_transform,
+             **load_kwargs,
+         )
+         return cls.load_custom_dataset(dataset)
+
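A sketch of the torchvision path (illustrative; MNIST and the ./data root are arbitrary choices):

    ds = TorchDataHandler.load_from_torchvision("MNIST", root="./data", download=True)
    # or, through the generic entry point:
    ds = TorchDataHandler.load_dataset(
        "MNIST", load_kwargs={"root": "./data", "download": True}
    )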
+     @staticmethod
+     def assign_feature_value(
+         dataset: DictDataset, feature_key: str, value: int
+     ) -> DictDataset:
+         """Assign a value to a feature for every sample in a DictDataset
+
+         Args:
+             dataset (DictDataset): DictDataset to assign the value to
+             feature_key (str): Feature to assign the value to
+             value (int): Value to assign
+
+         Returns:
+             DictDataset
+         """
+         assert isinstance(
+             dataset, DictDataset
+         ), "Dataset must be an instance of DictDataset"
+
+         def assign_value_to_feature(x):
+             x[feature_key] = torch.tensor(value)
+             return x
+
+         dataset = dataset.map(assign_value_to_feature)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def get_feature_from_ds(dataset: DictDataset, feature_key: str) -> np.ndarray:
+         """Get a feature from a DictDataset
+
+         !!! note
+             This function can be a bit time consuming since it needs to iterate
+             over the whole dataset.
+
+         Args:
+             dataset (DictDataset): Dataset to get the feature from
+             feature_key (str): Key of the feature values to get
+
+         Returns:
+             np.ndarray: Feature values for the dataset
+         """
+         features = dataset.map(lambda x: x[feature_key])
+         features = np.stack([f.numpy() for f in features])
+         return features
+
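For instance (illustrative, reusing the ds from the sketches above), collecting every label into a numpy array:

    labels = TorchDataHandler.get_feature_from_ds(ds, "label")  # shape (len(ds),)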
+     @staticmethod
+     @dict_only_ds
+     def get_ds_feature_keys(dataset: DictDataset) -> list:
+         """Get the feature keys of a DictDataset
+
+         Args:
+             dataset (DictDataset): Dataset to get the feature keys from
+
+         Returns:
+             list: List of feature keys
+         """
+         return dataset.output_keys
+
+     @staticmethod
+     def has_feature_key(dataset: DictDataset, key: str) -> bool:
+         """Check if a DictDataset has a feature denoted by key
+
+         Args:
+             dataset (DictDataset): Dataset to check
+             key (str): Key to check
+
+         Returns:
+             bool: whether the dataset has a feature denoted by key
+         """
+         assert isinstance(
+             dataset, DictDataset
+         ), "Dataset must be an instance of DictDataset"
+
+         return key in dataset.output_keys
+
+     @staticmethod
+     def map_ds(
+         dataset: DictDataset,
+         map_fn: Callable,
+     ) -> DictDataset:
+         """Map a function over a DictDataset
+
+         Args:
+             dataset (DictDataset): Dataset to map the function over
+             map_fn (Callable): Function to map
+
+         Returns:
+             DictDataset: Mapped dataset
+         """
+         return dataset.map(map_fn)
+
+     @staticmethod
+     @dict_only_ds
+     def filter_by_feature_value(
+         dataset: DictDataset,
+         feature_key: str,
+         values: list,
+         excluded: bool = False,
+     ) -> DictDataset:
+         """Filter the dataset by checking that the value of a feature is in `values`
+
+         !!! note
+             This function can be a bit time consuming since it needs to iterate
+             over the whole dataset.
+
+         Args:
+             dataset (DictDataset): Dataset to filter
+             feature_key (str): Feature name whose value is checked
+             values (list): feature_key values to keep
+             excluded (bool, optional): whether to keep (False) or exclude (True) the
+                 samples whose feature_key value is included in values.
+                 Defaults to False.
+
+         Returns:
+             DictDataset: Filtered dataset
+         """
+         if len(dataset[0][feature_key].shape) > 0:
+             value_dim = dataset[0][feature_key].shape[-1]
+             values = [
+                 F.one_hot(torch.tensor(value).long(), value_dim) for value in values
+             ]
+
+         def filter_fn(x):
+             keep = any([torch.all(x[feature_key] == v) for v in values])
+             return keep if not excluded else not keep
+
+         filtered_dataset = dataset.filter(filter_fn)
+         return filtered_dataset
+
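An illustrative in-distribution / out-of-distribution split of a 10-class dataset by label, reusing the ds from the sketches above:

    id_ds = TorchDataHandler.filter_by_feature_value(ds, "label", values=[0, 1, 2, 3, 4])
    ood_ds = TorchDataHandler.filter_by_feature_value(
        ds, "label", values=[0, 1, 2, 3, 4], excluded=True
    )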
+     @classmethod
+     def prepare_for_training(
+         cls,
+         dataset: DictDataset,
+         batch_size: int,
+         shuffle: bool = False,
+         preprocess_fn: Optional[Callable] = None,
+         augment_fn: Optional[Callable] = None,
+         output_keys: Optional[list] = None,
+         dict_based_fns: bool = False,
+         shuffle_buffer_size: Optional[int] = None,
+         num_workers: int = 8,
+     ) -> DataLoader:
+         """Prepare a DataLoader for training
+
+         Args:
+             dataset (DictDataset): Dataset to prepare
+             batch_size (int): Batch size
+             shuffle (bool): Whether to shuffle the dataloader or not
+             preprocess_fn (Callable, optional): Preprocessing function to apply to
+                 the dataset. Defaults to None.
+             augment_fn (Callable, optional): Augment function to be used (when the
+                 returned dataset is to be used for training). Defaults to None.
+             output_keys (list): List of keys corresponding to the features that will
+                 be returned. Keep all features if None. Defaults to None.
+             dict_based_fns (bool): Whether to use preprocess and augment functions
+                 as dict based (if True) or as tuple based (if False).
+                 Defaults to False.
+             shuffle_buffer_size (int, optional): Size of the shuffle buffer. Not used
+                 in torch because we only rely on map-style datasets. Kept as an
+                 argument for API consistency. Defaults to None.
+             num_workers (int, optional): Number of workers to use for the dataloader.
+
+         Returns:
+             DataLoader: dataloader
+         """
+         output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+
+         def collate_fn(batch: List[dict]):
+             if dict_based_fns:
+                 # preprocess + augmentation: List[dict] -> List[dict]
+                 preprocess_func = preprocess_fn or (lambda x: x)
+                 augment_func = augment_fn or (lambda x: x)
+                 batch = [augment_func(preprocess_func(d)) for d in batch]
+                 # to tuple of batches
+                 return tuple(
+                     default_collate([d[key] for d in batch]) for key in output_keys
+                 )
+             else:
+                 # preprocess + augmentation: List[dict] -> List[tuple]
+                 preprocess_func = preprocess_fn or (lambda *x: x)
+                 augment_func = augment_fn or (lambda *x: x)
+                 batch = [
+                     augment_func(
+                         *preprocess_func(*tuple(d[key] for key in output_keys))
+                     )
+                     for d in batch
+                 ]
+                 # to tuple of batches
+                 return default_collate(batch)
+
+         loader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             collate_fn=collate_fn,
+             num_workers=num_workers,
+         )
+         return loader
+
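An illustrative call producing (input, label) batches; num_workers=0 keeps the sketch single-process:

    loader = TorchDataHandler.prepare_for_training(
        ds, batch_size=64, shuffle=True, num_workers=0
    )
    inputs, labels = next(iter(loader))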
+     @staticmethod
+     def merge(
+         id_dataset: DictDataset,
+         ood_dataset: DictDataset,
+         resize: Optional[bool] = False,
+         shape: Optional[Tuple[int]] = None,
+     ) -> DictDataset:
+         """Merge two instances of DictDataset
+
+         Args:
+             id_dataset (DictDataset): dataset of in-distribution data
+             ood_dataset (DictDataset): dataset of out-of-distribution data
+             resize (Optional[bool], optional): whether the input tensors of the two
+                 datasets have to be resized to the same shape. Defaults to False.
+             shape (Optional[Tuple[int]], optional): shape to use for resizing input
+                 tensors. If None, the tensors are resized with the shape of the
+                 id_dataset input tensors. Defaults to None.
+
+         Returns:
+             DictDataset: merged dataset
+         """
+         # If a desired shape is given, trigger the resize
+         if shape is not None:
+             resize = True
+
+         # If the shapes of the two datasets differ, trigger the resize
+         if id_dataset.output_shapes[0] != ood_dataset.output_shapes[0]:
+             resize = True
+             if shape is None:
+                 print(
+                     "Resizing the first item of elem (usually the image)",
+                     " with the shape of id_dataset",
+                 )
+                 shape = id_dataset.output_shapes[0][1:]
+
+         if resize:
+             resize_fn = torchvision.transforms.Resize(shape)
+
+             def reshape_fn(item_dict):
+                 item_dict["input"] = resize_fn(item_dict["input"])
+                 return item_dict
+
+             id_dataset = id_dataset.map(reshape_fn)
+             ood_dataset = ood_dataset.map(reshape_fn)
+
+         merged_dataset = id_dataset.concatenate(ood_dataset)
+         return merged_dataset
+
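A plausible end-to-end use (illustrative, reusing id_ds and ood_ds from the filtering sketch; the "ood_label" key is an arbitrary name), tagging each split before merging so the origin of every sample survives concatenation:

    id_ds = TorchDataHandler.assign_feature_value(id_ds, "ood_label", 0)
    ood_ds = TorchDataHandler.assign_feature_value(ood_ds, "ood_label", 1)
    merged = TorchDataHandler.merge(id_ds, ood_ds)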
+     @staticmethod
+     def get_item_length(dataset: Dataset) -> int:
+         """Number of elements in a dataset item
+
+         Args:
+             dataset (DictDataset): Dataset
+
+         Returns:
+             int: Item length
+         """
+         return len(dataset[0])
+
+     @staticmethod
+     def get_dataset_length(dataset: Dataset) -> int:
+         """Number of items in a dataset
+
+         Args:
+             dataset (DictDataset): Dataset
+
+         Returns:
+             int: Dataset length
+         """
+         return len(dataset)
+
+     @staticmethod
+     def get_feature_shape(dataset: Dataset, feature_key: Union[str, int]) -> tuple:
+         """Get the shape of a feature of the dataset identified by feature_key
+
+         Args:
+             dataset (Dataset): a Dataset
+             feature_key (Union[str, int]): The identifier of the feature
+
+         Returns:
+             tuple: the shape of the feature
+         """
+         return tuple(dataset[0][feature_key].shape)
+
+     @staticmethod
+     def get_input_from_dataset_item(elem: ItemType) -> Any:
+         """Get the tensor that is to be fed as input to a model from a dataset element.
+
+         Args:
+             elem (ItemType): dataset element to extract the input from
+
+         Returns:
+             Any: Input tensor
+         """
+         if isinstance(elem, (tuple, list)):
+             tensor = elem[0]
+         elif isinstance(elem, dict):
+             tensor = elem[list(elem.keys())[0]]
+         else:
+             tensor = elem
+         return tensor
+
+     @staticmethod
+     def get_label_from_dataset_item(item: ItemType):
+         """Retrieve the label tensor from an item as a tuple/list. The label must be
+         at index 1 in the item tuple. If one-hot encoded, labels are converted to a
+         single value.
+
+         Args:
+             item (ItemType): dataset element to extract the label from
+
+         Returns:
+             Any: Label tensor
+         """
+         label = item[1]  # labels must be at index 1 in the batch tuple
+         # If labels are one-hot encoded, take the argmax
+         if len(label.shape) > 1 and label.shape[1] > 1:
+             label = label.view(label.size(0), -1)
+             label = torch.argmax(label, dim=1)
+         # If labels are in two dimensions, squeeze them
+         if len(label.shape) > 1:
+             label = label.view([label.shape[0]])
+         return label
+
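For example (illustrative), a batch with one-hot labels collapses to class indices:

    item = (torch.randn(4, 3), torch.eye(4))  # labels one-hot over 4 classes
    labels = TorchDataHandler.get_label_from_dataset_item(item)  # tensor([0, 1, 2, 3])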
+     @staticmethod
+     def get_feature(dataset: DictDataset, feature_key: Union[str, int]) -> DictDataset:
+         """Extract a feature from a dataset
+
+         Args:
+             dataset (DictDataset): Dataset to extract the feature from
+             feature_key (Union[str, int]): feature to extract
+
+         Returns:
+             DictDataset: dataset built with the extracted feature only
+         """
+
+         def _get_feature_item(item):
+             return item[feature_key]
+
+         return dataset.map(_get_feature_item)
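Finally, an illustrative input-only view (reusing the ds from the sketches above; each mapped item becomes a bare tensor rather than a dict):

    inputs_only = TorchDataHandler.get_feature(ds, "input")
    first_input = inputs_only[0]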