oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,671 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ from typing import get_args
+
+ import numpy as np
+ import tensorflow as tf
+ import tensorflow_datasets as tfds
+
+ from ...types import Callable
+ from ...types import ItemType
+ from ...types import Optional
+ from ...types import TensorType
+ from ...types import Tuple
+ from ...types import Union
+ from .DEPRECATED_data_handler import DataHandler
+
+
+ def dict_only_ds(ds_handling_method: Callable) -> Callable:
+     """Decorator to ensure that the dataset is a dict dataset and that the input key
+     matches one of the feature keys. The signature of decorated functions
+     must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+     args[0] when relevant.
+
+     Args:
+         ds_handling_method: method to decorate
+
+     Returns:
+         decorated method
+     """
+
+     def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
+         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
+
+         if "feature_key" in kwargs.keys():
+             feature_key = kwargs["feature_key"]
+         elif len(args) > 0:
+             feature_key = args[0]
+
+         # If feature_key is provided, check that it is in the dataset feature keys
+         if (len(args) > 0) or ("feature_key" in kwargs):
+             if isinstance(feature_key, str):
+                 feature_key = [feature_key]
+             for key in feature_key:
+                 assert (
+                     key in dataset.element_spec.keys()
+                 ), f"The input dataset has no feature named {key}"
+         return ds_handling_method(dataset, *args, **kwargs)
+
+     return wrapper
+
+
+ class TFDataHandler(DataHandler):
+     """
+     Class to manage tf.data.Dataset. The aim is to provide a simple interface for
+     working with tf.data.Datasets and to manage them without having to use
+     tensorflow syntax.
+     """
+
+     @classmethod
+     def load_dataset(
+         cls,
+         dataset_id: Union[tf.data.Dataset, ItemType, str],
+         keys: Optional[list] = None,
+         load_kwargs: dict = {},
+     ) -> tf.data.Dataset:
+         """Load a dataset from various sources, ensuring that a dict-based
+         tf.data.Dataset is returned.
+
+         Args:
+             dataset_id (Any): dataset identification
+             keys (list, optional): Features keys. If None, assigned as "input_i"
+                 for i-th feature. Defaults to None.
+             load_kwargs (dict, optional): Additional args for loading from
+                 tensorflow_datasets. Defaults to {}.
+
+         Returns:
+             tf.data.Dataset: A dict-based tf.data.Dataset
+         """
+         if isinstance(dataset_id, get_args(ItemType)):
+             dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+         elif isinstance(dataset_id, tf.data.Dataset):
+             dataset = cls.load_custom_dataset(dataset_id, keys)
+         elif isinstance(dataset_id, str):
+             dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
+         return dataset
+
+     @staticmethod
+     def load_dataset_from_arrays(
+         dataset_id: ItemType, keys: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
+         of np.ndarrays/tf.Tensors.
+
+         Args:
+             dataset_id (ItemType): numpy array(s) to load.
+             keys (list, optional): Features keys. If None, assigned as "input_i"
+                 for i-th feature. Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         # If dataset_id is a numpy array, convert it to a dict
+         if isinstance(dataset_id, get_args(TensorType)):
+             dataset_dict = {"input": dataset_id}
+
+         # If dataset_id is a tuple, convert it to a dict
+         elif isinstance(dataset_id, tuple):
+             len_elem = len(dataset_id)
+             if keys is None:
+                 if len_elem == 2:
+                     dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
+                 else:
+                     dataset_dict = {
+                         f"input_{i}": dataset_id[i] for i in range(len_elem - 1)
+                     }
+                     dataset_dict["label"] = dataset_id[-1]
+                 print(
+                     'Loading tf.data.Dataset with elems as dicts, assigning "input_i" '
+                     'key to the i-th tuple dimension and "label" key to the last '
+                     "tuple dimension."
+                 )
+             else:
+                 assert (
+                     len(keys) == len_elem
+                 ), "Number of keys must match the number of features"
+                 dataset_dict = {keys[i]: dataset_id[i] for i in range(len_elem)}
+
+         elif isinstance(dataset_id, dict):
+             if keys is not None:
+                 len_elem = len(dataset_id)
+                 assert (
+                     len(keys) == len_elem
+                 ), "Number of keys must match the number of features"
+                 original_keys = list(dataset_id.keys())
+                 dataset_dict = {
+                     keys[i]: dataset_id[original_keys[i]] for i in range(len_elem)
+                 }
+             else:
+                 # Keep the original dict keys when no replacement keys are given
+                 dataset_dict = dataset_id
+
+         dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
+         return dataset
+
+     @classmethod
+     def load_custom_dataset(
+         cls, dataset_id: tf.data.Dataset, keys: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Load a custom Dataset by ensuring it has the correct format (dict-based)
+
+         Args:
+             dataset_id (tf.data.Dataset): tf.data.Dataset
+             keys (list, optional): Features keys. If None, assigned as "input_i"
+                 for i-th feature. Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         # If dataset_id is a tuple-based tf.data.Dataset, convert it to a dict
+         if not isinstance(dataset_id.element_spec, dict):
+             len_elem = len(dataset_id.element_spec)
+             if keys is None:
+                 print(
+                     "Feature name not found, assigning 'input_i' "
+                     "key to the i-th tensor and 'label' key to the last"
+                 )
+                 if len_elem == 2:
+                     keys = ["input", "label"]
+                 else:
+                     keys = [f"input_{i}" for i in range(len_elem)]
+                     keys[-1] = "label"
+             else:
+                 assert (
+                     len(keys) == len_elem
+                 ), "Number of keys must match the number of features"
+
+             dataset_id = cls.tuple_to_dict(dataset_id, keys)
+
+         dataset = dataset_id
+         return dataset
+
+     @staticmethod
+     def load_from_tensorflow_datasets(
+         dataset_id: str,
+         load_kwargs: dict = {},
+     ) -> tf.data.Dataset:
+         """Load a tf.data.Dataset from the tensorflow_datasets catalog
+
+         Args:
+             dataset_id (str): Identifier of the dataset
+             load_kwargs (dict, optional): Loading kwargs to add to tfds.load().
+                 Defaults to {}.
+
+         Returns:
+             tf.data.Dataset
+         """
+         assert (
+             dataset_id in tfds.list_builders()
+         ), "Dataset not available on tensorflow datasets catalog"
+         dataset = tfds.load(dataset_id, **load_kwargs)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def dict_to_tuple(
+         dataset: tf.data.Dataset, keys: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Turn a dict-based tf.data.Dataset into a tuple-based tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): Dict-based tf.data.Dataset
+             keys (list, optional): Features to use for the tuple-based
+                 tf.data.Dataset. If None, takes all the features. Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         if keys is None:
+             keys = list(dataset.element_spec.keys())
+         dataset = dataset.map(lambda x: tuple(x[k] for k in keys))
+         return dataset
+
+     @staticmethod
+     def tuple_to_dict(dataset: tf.data.Dataset, keys: list) -> tf.data.Dataset:
+         """Turn a tuple-based tf.data.Dataset into a dict-based tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): Tuple-based tf.data.Dataset
+             keys (list): Keys to use for the dict-based tf.data.Dataset
+
+         Returns:
+             tf.data.Dataset
+         """
+         assert isinstance(
+             dataset.element_spec, tuple
+         ), "dataset elements must be tuples"
+         len_elem = len(dataset.element_spec)
+         assert len_elem == len(
+             keys
+         ), "The number of keys must be equal to the number of tuple elements"
+
+         def tuple_to_dict(*inputs):
+             return {keys[i]: inputs[i] for i in range(len_elem)}
+
+         dataset = dataset.map(tuple_to_dict)
+         return dataset
+
+     @staticmethod
+     def assign_feature_value(
+         dataset: tf.data.Dataset, feature_key: str, value: int
+     ) -> tf.data.Dataset:
+         """Assign a value to a feature for every sample in a tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to assign the value to
+             feature_key (str): Feature to assign the value to
+             value (int): Value to assign
+
+         Returns:
+             tf.data.Dataset
+         """
+         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
+
+         def assign_value_to_feature(x):
+             x[feature_key] = value
+             return x
+
+         dataset = dataset.map(assign_value_to_feature)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def get_feature_from_ds(dataset: tf.data.Dataset, feature_key: str) -> np.ndarray:
+         """Get a feature from a tf.data.Dataset
+
+         !!! note
+             This function can be a bit time consuming since it needs to iterate
+             over the whole dataset.
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to get the feature from
+             feature_key (str): Feature value to get
+
+         Returns:
+             np.ndarray: Feature values for dataset
+         """
+         features = dataset.map(lambda x: x[feature_key])
+         features = list(features.as_numpy_iterator())
+         features = np.array(features)
+         return features
+
+     @staticmethod
+     @dict_only_ds
+     def get_ds_feature_keys(dataset: tf.data.Dataset) -> list:
+         """Get the feature keys of a tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to get the feature keys from
+
+         Returns:
+             list: List of feature keys
+         """
+         return list(dataset.element_spec.keys())
+
+     @staticmethod
+     def has_feature_key(dataset: tf.data.Dataset, key: str) -> bool:
+         """Check if a tf.data.Dataset has a feature denoted by key
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to check
+             key (str): Key to check
+
+         Returns:
+             bool: Whether the tf.data.Dataset has a feature denoted by key
+         """
+         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
+         return key in dataset.element_spec.keys()
+
+     @staticmethod
+     def map_ds(
+         dataset: tf.data.Dataset,
+         map_fn: Callable,
+         num_parallel_calls: Optional[int] = None,
+     ) -> tf.data.Dataset:
+         """Map a function over a tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to map the function over
+             map_fn (Callable): Function to map
+             num_parallel_calls (Optional[int], optional): Number of parallel processes
+                 to use. Defaults to None.
+
+         Returns:
+             tf.data.Dataset: Mapped dataset
+         """
+         if num_parallel_calls is None:
+             num_parallel_calls = tf.data.experimental.AUTOTUNE
+         dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_calls)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def filter_by_feature_value(
+         dataset: tf.data.Dataset,
+         feature_key: str,
+         values: list,
+         excluded: bool = False,
+     ) -> tf.data.Dataset:
+         """Filter a tf.data.Dataset by checking that the value of a feature is
+         in 'values'
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to filter
+             feature_key (str): Feature name to check the value of
+             values (list): feature_key values to keep (if excluded is False)
+                 or to exclude
+             excluded (bool, optional): To keep (False) or exclude (True) the samples
+                 whose feature_key value is included in values. Defaults to False.
+
+         Returns:
+             tf.data.Dataset: Filtered dataset
+         """
+         # If the labels are one-hot encoded, prepare a function to get the label as int
+         if len(dataset.element_spec[feature_key].shape) > 0:
+
+             def get_label_int(elem):
+                 return int(tf.argmax(elem[feature_key]))
+
+         else:
+
+             def get_label_int(elem):
+                 return elem[feature_key]
+
+         def filter_fn(elem):
+             value = get_label_int(elem)
+             if excluded:
+                 return not tf.reduce_any(tf.equal(value, values))
+             else:
+                 return tf.reduce_any(tf.equal(value, values))
+
+         dataset_to_filter = dataset
+         dataset_to_filter = dataset_to_filter.filter(filter_fn)
+         return dataset_to_filter
+
+     @classmethod
+     def prepare_for_training(
+         cls,
+         dataset: tf.data.Dataset,
+         batch_size: int,
+         shuffle: bool = False,
+         preprocess_fn: Optional[Callable] = None,
+         augment_fn: Optional[Callable] = None,
+         output_keys: Optional[list] = None,
+         dict_based_fns: bool = False,
+         shuffle_buffer_size: Optional[int] = None,
+         prefetch_buffer_size: Optional[int] = None,
+         drop_remainder: Optional[bool] = False,
+     ) -> tf.data.Dataset:
+         """Prepare a tf.data.Dataset for training
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to prepare
+             batch_size (int): Batch size
+             shuffle (bool, optional): To shuffle the returned dataset or not.
+                 Defaults to False.
+             preprocess_fn (Callable, optional): Preprocessing function to apply to
+                 the dataset. Defaults to None.
+             augment_fn (Callable, optional): Augment function to be used (when the
+                 returned dataset is to be used for training). Defaults to None.
+             output_keys (list, optional): List of keys corresponding to the features
+                 that will be returned. Keep all features if None. Defaults to None.
+             dict_based_fns (bool, optional): Whether the augment and preprocess
+                 functions are dict-based or not. Defaults to False.
+             shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
+                 taken as the number of samples in the dataset. Defaults to None.
+             prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
+                 If None, automatically chosen using tf.data.experimental.AUTOTUNE.
+                 Defaults to None.
+             drop_remainder (Optional[bool], optional): To drop the last batch when
+                 its size is lower than batch_size. Defaults to False.
+
+         Returns:
+             tf.data.Dataset: Prepared dataset
+         """
+         # dict-based to tuple-based
+         output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+         if not dict_based_fns:
+             dataset = cls.dict_to_tuple(dataset, output_keys)
+
+         # preprocess + data augmentation
+         if preprocess_fn is not None:
+             dataset = cls.map_ds(dataset, preprocess_fn)
+         if augment_fn is not None:
+             dataset = cls.map_ds(dataset, augment_fn)
+
+         if dict_based_fns:
+             dataset = cls.dict_to_tuple(dataset, output_keys)
+
+         dataset = dataset.cache()
+
+         # shuffle
+         if shuffle:
+             num_samples = cls.get_dataset_length(dataset)
+             shuffle_buffer_size = (
+                 num_samples if shuffle_buffer_size is None else shuffle_buffer_size
+             )
+             dataset = dataset.shuffle(shuffle_buffer_size)
+         # batch
+         dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+         # prefetch, with AUTOTUNE when no buffer size is given
+         if prefetch_buffer_size is None:
+             prefetch_buffer_size = tf.data.experimental.AUTOTUNE
+         dataset = dataset.prefetch(prefetch_buffer_size)
+         return dataset
+
+     @staticmethod
+     def make_channel_first(input_key: str, dataset: tf.data.Dataset) -> tf.data.Dataset:
+         """Make a tf.data.Dataset channel-first. Make sure that the dataset is not
+         already channel-first: if it is, the tensors will end up with the format
+         (batch_size, x_size, channel, y_size).
+
+         Args:
+             input_key (str): input key of the dict-based tf.data.Dataset
+             dataset (tf.data.Dataset): tf.data.Dataset to make channel-first
+
+         Returns:
+             tf.data.Dataset: Channel-first dataset
+         """
+
+         def channel_first(x):
+             x[input_key] = tf.transpose(x[input_key], perm=[2, 0, 1])
+             return x
+
+         dataset = dataset.map(channel_first)
+         return dataset
+
+     @classmethod
+     def merge(
+         cls,
+         id_dataset: tf.data.Dataset,
+         ood_dataset: tf.data.Dataset,
+         resize: Optional[bool] = False,
+         shape: Optional[Tuple[int]] = None,
+         channel_order: Optional[str] = "channels_last",
+     ) -> tf.data.Dataset:
+         """Merge two tf.data.Datasets
+
+         Args:
+             id_dataset (tf.data.Dataset): dataset of in-distribution data
+             ood_dataset (tf.data.Dataset): dataset of out-of-distribution data
+             resize (Optional[bool], optional): whether input tensors of the
+                 datasets have to be resized to have the same shape. Defaults to
+                 False.
+             shape (Optional[Tuple[int]], optional): shape to use for resizing input
+                 tensors. If None, the tensors are resized with the shape of the
+                 id_dataset input tensors. Defaults to None.
+             channel_order (Optional[str], optional): channel order of the input
+                 tensors. Defaults to "channels_last".
+
+         Returns:
+             tf.data.Dataset: merged dataset
+         """
+         len_elem_id = cls.get_item_length(id_dataset)
+         len_elem_ood = cls.get_item_length(ood_dataset)
+         assert (
+             len_elem_id == len_elem_ood
+         ), "incompatible dataset elements (different elem dict length)"
+
+         # If a desired shape is given, trigger the resize
+         if shape is not None:
+             resize = True
+
+         id_elem_spec = id_dataset.element_spec
+         ood_elem_spec = ood_dataset.element_spec
+         assert isinstance(id_elem_spec, dict), "dataset elements must be dicts"
+         assert isinstance(ood_elem_spec, dict), "dataset elements must be dicts"
+
+         input_key_id = list(id_elem_spec.keys())[0]
+         input_key_ood = list(ood_elem_spec.keys())[0]
+         shape_id = id_dataset.element_spec[input_key_id].shape
+         shape_ood = ood_dataset.element_spec[input_key_ood].shape
+
+         # If the shapes of the two datasets differ, trigger the resize
+         if shape_id != shape_ood:
+             resize = True
+
+             if shape is None:
+                 print(
+                     "Resizing the first item of elem (usually the image)",
+                     "with the shape of id_dataset",
+                 )
+                 if channel_order == "channels_first":
+                     shape = shape_id[1:]
+                 else:
+                     shape = shape_id[:2]
+
+         if resize:
+
+             def reshape_im_id(elem):
+                 elem[input_key_id] = tf.image.resize(elem[input_key_id], shape)
+                 return elem
+
+             def reshape_im_ood(elem):
+                 elem[input_key_ood] = tf.image.resize(elem[input_key_ood], shape)
+                 return elem
+
+             id_dataset = id_dataset.map(reshape_im_id)
+             ood_dataset = ood_dataset.map(reshape_im_ood)
+
+         merged_dataset = id_dataset.concatenate(ood_dataset)
+         return merged_dataset
+
+     @staticmethod
+     def get_item_length(dataset: tf.data.Dataset) -> int:
+         """Get the length of a dataset element. If an element is a single tensor,
+         the length is one; if it is a sequence or mapping (tuple, list or dict),
+         it is len(elem).
+
+         Args:
+             dataset (tf.data.Dataset): Dataset to process
+
+         Returns:
+             int: length of the dataset elems
+         """
+         if isinstance(dataset.element_spec, (tuple, list, dict)):
+             return len(dataset.element_spec)
+         return 1
+
+     @staticmethod
+     def get_dataset_length(dataset: tf.data.Dataset) -> int:
+         """Get the length of a dataset. Try to access it with len(), and if not
+         available, with a reduce op.
+
+         Args:
+             dataset (tf.data.Dataset): Dataset to process
+
+         Returns:
+             int: length of the dataset
+         """
+         try:
+             return len(dataset)
+         except TypeError:
+             cardinality = dataset.reduce(0, lambda x, _: x + 1)
+             return int(cardinality)
+
+     @staticmethod
+     def get_feature_shape(
+         dataset: tf.data.Dataset, feature_key: Union[str, int]
+     ) -> tuple:
+         """Get the shape of a feature of a dataset identified by feature_key
+
+         Args:
+             dataset (tf.data.Dataset): a tf.data.Dataset
+             feature_key (Union[str, int]): The identifier of the feature
+
+         Returns:
+             tuple: the shape of the feature identified by feature_key
+         """
+         return tuple(dataset.element_spec[feature_key].shape)
+
+     @staticmethod
+     def get_input_from_dataset_item(elem: ItemType) -> TensorType:
+         """Get the tensor that is to be fed as input to a model from a dataset
+         element.
+
+         Args:
+             elem (ItemType): dataset element to extract the input from
+
+         Returns:
+             TensorType: Input tensor
+         """
+         if isinstance(elem, (tuple, list)):
+             tensor = elem[0]
+         elif isinstance(elem, dict):
+             tensor = elem[list(elem.keys())[0]]
+         else:
+             tensor = elem
+         return tensor
+
+     @staticmethod
+     def get_label_from_dataset_item(item: ItemType):
+         """Retrieve the label tensor from an item given as a tuple/list. The label
+         must be at index 1 in the item tuple. If one-hot encoded, labels are
+         converted to a single value.
+
+         Args:
+             item (ItemType): dataset element to extract the label from
+
+         Returns:
+             Any: Label tensor
+         """
+         label = item[1]  # labels must be at index 1 in the item tuple
+         # If labels are one-hot encoded, take the argmax
+         if tf.rank(label) > 1 and label.shape[1] > 1:
+             label = tf.reshape(label, shape=[label.shape[0], -1])
+             label = tf.argmax(label, axis=1)
+         # If labels are in two dimensions, squeeze them
+         if len(label.shape) > 1:
+             label = tf.reshape(label, [label.shape[0]])
+         return label
+
+     @staticmethod
+     def get_feature(
+         dataset: tf.data.Dataset, feature_key: Union[str, int]
+     ) -> tf.data.Dataset:
+         """Extract a feature from a dataset
+
+         Args:
+             dataset (tf.data.Dataset): Dataset to extract the feature from
+             feature_key (Union[str, int]): feature to extract
+
+         Returns:
+             tf.data.Dataset: dataset built with the extracted feature only
+         """
+
+         def _get_feature_elem(elem):
+             return elem[feature_key]
+
+         return dataset.map(_get_feature_elem)
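As a quick orientation for reviewers, here is a minimal usage sketch of the deprecated handler added above. It is not part of the diff: the import path is inferred from the oodeel/datasets/deprecated/ entries in the files-changed list, and the arrays, feature keys, and parameter values are made up for illustration.

    import numpy as np

    from oodeel.datasets.deprecated.DEPRECATED_tf_data_handler import TFDataHandler

    # Build a dict-based tf.data.Dataset from in-memory arrays; a 2-tuple is
    # assigned the default "input" and "label" feature keys by load_dataset.
    x = np.random.rand(100, 32, 32, 3).astype("float32")
    y = np.random.randint(0, 10, size=(100,))
    ds = TFDataHandler.load_dataset((x, y))

    # Keep only the samples whose label is in [0, 1, 2]
    ds_id = TFDataHandler.filter_by_feature_value(ds, "label", [0, 1, 2])

    # Shuffle, batch and prefetch; elements come back as (input, label) tuples
    train_ds = TFDataHandler.prepare_for_training(ds_id, batch_size=16, shuffle=True)
    for inputs, labels in train_ds.take(1):
        print(inputs.shape, labels.shape)  # e.g. (16, 32, 32, 3) (16,)

Note that filter() leaves the dataset with unknown cardinality, which is why get_dataset_length falls back to a reduce op when len() raises a TypeError.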