oodeel-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
oodeel/datasets/deprecated/__init__.py
@@ -0,0 +1,31 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import warnings
+
+ from .DEPRECATED_ooddataset import OODDataset
+
+ warnings.warn(
+     "The 'OODDataset' object is deprecated and will be removed in a future release.",
+     DeprecationWarning,
+     stacklevel=2,
+ )
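
For reference, the deprecation shim above emits its warning at import time. A minimal sketch of how a consumer would observe it (hypothetical usage, not part of the diff; it assumes a fresh interpreter in which the shim module has not been imported yet):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The first import of the shim runs the module-level warnings.warn(...)
    from oodeel.datasets.deprecated import OODDataset  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)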
oodeel/datasets/tf_data_handler.py
@@ -0,0 +1,600 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ from typing import get_args
+
+ import tensorflow as tf
+ import tensorflow_datasets as tfds
+ from datasets import load_dataset as hf_load_dataset
+
+ from ..types import Callable
+ from ..types import ItemType
+ from ..types import Optional
+ from ..types import TensorType
+ from ..types import Union
+ from .data_handler import DataHandler
+
+
+ def dict_only_ds(ds_handling_method: Callable) -> Callable:
+     """Decorator to ensure that the dataset is a dict-based dataset and that the
+     column_name given as argument matches one of its column names. The signature of
+     decorated functions must be function(dataset, *args, **kwargs) with column_name
+     either in kwargs or args[0] when relevant.
+
+     Args:
+         ds_handling_method: method to decorate
+
+     Returns:
+         decorated method
+     """
+
+     def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
+         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
+
+         if "column_name" in kwargs.keys():
+             column_name = kwargs["column_name"]
+         elif len(args) > 0:
+             column_name = args[0]
+
+         # If column_name is provided, check that it is in the dataset column names
+         if (len(args) > 0) or ("column_name" in kwargs):
+             if isinstance(column_name, str):
+                 column_name = [column_name]
+             for name in column_name:
+                 assert (
+                     name in dataset.element_spec.keys()
+                 ), f"The input dataset has no column named {name}"
+         return ds_handling_method(dataset, *args, **kwargs)
+
+     return wrapper
+
+
+ class TFDataHandler(DataHandler):
+     """
+     Class to manage tf.data.Dataset. The aim is to provide a simple interface for
+     working with tf.data.Datasets and to manage them without having to use
+     TensorFlow syntax.
+     """
+
+     def __init__(self) -> None:
+         """
+         Initializes the TFDataHandler instance.
+
+         Attributes:
+             backend (str): The backend framework used, set to "tensorflow".
+             channel_order (str): The channel order format, set to "channels_last".
+         """
+         super().__init__()
+         self.backend = "tensorflow"
+         self.channel_order = "channels_last"
+
+     @classmethod
+     def load_dataset(
+         cls,
+         dataset_id: Union[tf.data.Dataset, ItemType, str],
+         columns: Optional[list] = None,
+         hub: Optional[str] = "tensorflow-datasets",
+         load_kwargs: dict = {},
+     ) -> tf.data.Dataset:
+         """Load a dataset from different sources, ensuring that a dict-based
+         tf.data.Dataset is returned.
+
+         Args:
+             dataset_id (Union[tf.data.Dataset, ItemType, str]): dataset identification.
+                 Can be the name of a dataset from tensorflow_datasets, a
+                 tf.data.Dataset, or a tuple/dict of np.ndarrays/tf.Tensors.
+             columns (list, optional): Column names. If None, assigned as "input_i"
+                 for the i-th column. Defaults to None.
+             hub (str, optional): The hub to load the dataset from. Can be either
+                 "tensorflow-datasets" or "huggingface".
+                 Defaults to "tensorflow-datasets".
+             load_kwargs (dict, optional): Additional args for loading from
+                 tensorflow_datasets. Defaults to {}.
+
+         Returns:
+             tf.data.Dataset: A dict-based tf.data.Dataset
+         """
+
+         assert hub in {
+             "tensorflow-datasets",
+             "huggingface",
+         }, "hub must be either 'tensorflow-datasets' or 'huggingface'"
+
+         if isinstance(dataset_id, get_args(ItemType)):
+             dataset = cls.load_dataset_from_arrays(dataset_id, columns)
+         elif isinstance(dataset_id, tf.data.Dataset):
+             dataset = cls.load_custom_dataset(dataset_id, columns)
+         elif isinstance(dataset_id, str):
+             if hub == "tensorflow-datasets":
+                 load_kwargs["as_supervised"] = False
+                 dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
+             elif hub == "huggingface":
+                 dataset = cls.load_from_huggingface(dataset_id, load_kwargs)
+         return dataset
+
+     @staticmethod
+     def load_dataset_from_arrays(
+         dataset_id: ItemType, columns: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
+         of np.ndarrays/tf.Tensors.
+
+         Args:
+             dataset_id (ItemType): numpy array(s) to load.
+             columns (list, optional): Column names to assign. If None,
+                 assigned as "input_i" for the i-th column. Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         # If dataset_id is a numpy array, convert it to a dict
+         if isinstance(dataset_id, get_args(TensorType)):
+             dataset_dict = {"input": dataset_id}
+
+         # If dataset_id is a tuple, convert it to a dict
+         elif isinstance(dataset_id, tuple):
+             len_elem = len(dataset_id)
+             if columns is None:
+                 if len_elem == 2:
+                     dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
+                 else:
+                     dataset_dict = {
+                         f"input_{i}": dataset_id[i] for i in range(len_elem - 1)
+                     }
+                     dataset_dict["label"] = dataset_id[-1]
+                     print(
+                         'Loading tf.data.Dataset with elems as dicts, assigning "input_i" '
+                         'key to the i-th tuple dimension and "label" key to the last '
+                         "tuple dimension."
+                     )
+             else:
+                 assert (
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"
+                 dataset_dict = {columns[i]: dataset_id[i] for i in range(len_elem)}
+
+         elif isinstance(dataset_id, dict):
+             if columns is not None:
+                 len_elem = len(dataset_id)
+                 assert (
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"
+                 original_columns = list(dataset_id.keys())
+                 dataset_dict = {
+                     columns[i]: dataset_id[original_columns[i]] for i in range(len_elem)
+                 }
+             else:
+                 # Keep the dict as-is when no column renaming is requested
+                 dataset_dict = dataset_id
+
+         dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
+         return dataset
+
+     @classmethod
+     def load_custom_dataset(
+         cls, dataset_id: tf.data.Dataset, columns: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Load a custom Dataset by ensuring it has the correct format (dict-based)
+
+         Args:
+             dataset_id (tf.data.Dataset): tf.data.Dataset
+             columns (list, optional): Column names to use for elements if dataset_id
+                 is tuple based. If None, assigned as "input_i" for the i-th column.
+                 Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         # If dataset_id is a tuple based tf.data.Dataset, convert it to a dict
+         if not isinstance(dataset_id.element_spec, dict):
+             len_elem = len(dataset_id.element_spec)
+             if columns is None:
+                 print(
+                     "Column names not provided, assigning 'input_i' "
+                     "key to the i-th tensor and 'label' key to the last"
+                 )
+                 if len_elem == 2:
+                     columns = ["input", "label"]
+                 else:
+                     columns = [f"input_{i}" for i in range(len_elem)]
+                     columns[-1] = "label"
+             else:
+                 assert (
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"
+
+             dataset_id = cls.tuple_to_dict(dataset_id, columns)
+
+         dataset = dataset_id
+         return dataset
+
+     @staticmethod
+     def load_from_huggingface(
+         dataset_id: str,
+         load_kwargs: dict = {},
+     ) -> tf.data.Dataset:
+         """Load a Dataset from the Hugging Face datasets catalog
+
+         Args:
+             dataset_id (str): Identifier of the dataset
+             load_kwargs (dict): Loading kwargs to add to the initialization
+                 of the dataset.
+
+         Returns:
+             tf.data.Dataset: dataset
+         """
+         dataset = hf_load_dataset(dataset_id, **load_kwargs)
+         dataset = dataset.to_tf_dataset()
+         return dataset
+
+     @staticmethod
+     def load_from_tensorflow_datasets(
+         dataset_id: str,
+         load_kwargs: dict = {},
+     ) -> tf.data.Dataset:
+         """Load a tf.data.Dataset from the tensorflow_datasets catalog
+
+         Args:
+             dataset_id (str): Identifier of the dataset
+             load_kwargs (dict, optional): Loading kwargs to add to tfds.load().
+                 Defaults to {}.
+
+         Returns:
+             tf.data.Dataset
+         """
+         assert (
+             dataset_id in tfds.list_builders()
+         ), "Dataset not available on tensorflow datasets catalog"
+         dataset = tfds.load(dataset_id, **load_kwargs)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def dict_to_tuple(
+         dataset: tf.data.Dataset, columns: Optional[list] = None
+     ) -> tf.data.Dataset:
+         """Turn a dict-based tf.data.Dataset into a tuple-based tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): Dict-based tf.data.Dataset
+             columns (list, optional): Columns to use for the tuple-based
+                 tf.data.Dataset. If None, takes all the columns. Defaults to None.
+
+         Returns:
+             tf.data.Dataset
+         """
+         if columns is None:
+             columns = list(dataset.element_spec.keys())
+         dataset = dataset.map(lambda x: tuple(x[k] for k in columns))
+         return dataset
+
+     @staticmethod
+     def tuple_to_dict(dataset: tf.data.Dataset, columns: list) -> tf.data.Dataset:
+         """Turn a tuple-based tf.data.Dataset into a dict-based tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): Tuple-based tf.data.Dataset
+             columns (list): Column names to use for the dict-based tf.data.Dataset
+
+         Returns:
+             tf.data.Dataset
+         """
+         assert isinstance(
+             dataset.element_spec, tuple
+         ), "dataset elements must be tuples"
+         len_elem = len(dataset.element_spec)
+         assert len_elem == len(
+             columns
+         ), "The number of columns must be equal to the number of tuple elements"
+
+         def tuple_to_dict(*inputs):
+             return {columns[i]: inputs[i] for i in range(len_elem)}
+
+         dataset = dataset.map(tuple_to_dict)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def get_ds_column_names(dataset: tf.data.Dataset) -> list:
+         """Get the column names of a tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to get the column names from
+
+         Returns:
+             list: List of column names
+         """
+         return list(dataset.element_spec.keys())
+
+     @staticmethod
+     def map_ds(
+         dataset: tf.data.Dataset,
+         map_fn: Callable,
+         num_parallel_calls: Optional[int] = None,
+     ) -> tf.data.Dataset:
+         """Map a function over a tf.data.Dataset
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to map the function to
+             map_fn (Callable): Function to map
+             num_parallel_calls (Optional[int], optional): Number of parallel calls
+                 to use. Defaults to None.
+
+         Returns:
+             tf.data.Dataset: Mapped dataset
+         """
+         if num_parallel_calls is None:
+             num_parallel_calls = tf.data.experimental.AUTOTUNE
+         dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_calls)
+         return dataset
+
+     @staticmethod
+     @dict_only_ds
+     def filter_by_value(
+         dataset: tf.data.Dataset,
+         column_name: str,
+         values: list,
+         excluded: bool = False,
+     ) -> tf.data.Dataset:
+         """Filter a tf.data.Dataset by checking if the value of a column is in 'values'
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to filter
+             column_name (str): Column to filter the dataset with
+             values (list): Column values to keep (if excluded is False)
+                 or to exclude
+             excluded (bool, optional): Whether to keep (False) or exclude (True) the
+                 samples whose column values are included in values. Defaults to False.
+
+         Returns:
+             tf.data.Dataset: Filtered dataset
+         """
+         # If the labels are one-hot encoded, prepare a function to get the label as int
+         if len(dataset.element_spec[column_name].shape) > 0:
+
+             def get_label_int(elem):
+                 return int(tf.argmax(elem[column_name]))
+
+         else:
+
+             def get_label_int(elem):
+                 return elem[column_name]
+
+         def filter_fn(elem):
+             value = get_label_int(elem)
+             if excluded:
+                 return not tf.reduce_any(tf.equal(value, values))
+             else:
+                 return tf.reduce_any(tf.equal(value, values))
+
+         dataset_to_filter = dataset
+         dataset_to_filter = dataset_to_filter.filter(filter_fn)
+         return dataset_to_filter
+
+     @classmethod
+     def prepare(
+         cls,
+         dataset: tf.data.Dataset,
+         batch_size: int,
+         preprocess_fn: Optional[Callable] = None,
+         augment_fn: Optional[Callable] = None,
+         columns: Optional[list] = None,
+         shuffle: bool = False,
+         dict_based_fns: bool = True,
+         return_tuple: bool = True,
+         shuffle_buffer_size: Optional[int] = None,
+         prefetch_buffer_size: Optional[int] = None,
+         drop_remainder: Optional[bool] = False,
+     ) -> tf.data.Dataset:
+         """Prepare a tf.data.Dataset for training
+
+         Args:
+             dataset (tf.data.Dataset): tf.data.Dataset to prepare
+             batch_size (int): Batch size
+             preprocess_fn (Callable, optional): Preprocessing function to apply to
+                 the dataset. Defaults to None.
+             augment_fn (Callable, optional): Augmentation function to be used (when
+                 the returned dataset is to be used for training). Defaults to None.
+             columns (list, optional): List of column names corresponding to the
+                 columns that will be returned. Keeps all columns if None.
+                 Defaults to None.
+             shuffle (bool, optional): Whether to shuffle the returned dataset.
+                 Defaults to False.
+             dict_based_fns (bool): Whether the preprocessing and data augmentation
+                 functions are dict-based (if True) or tuple-based (if False).
+                 Defaults to True.
+             return_tuple (bool, optional): Whether to return each dataset item
+                 as a tuple. Defaults to True.
+             shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
+                 taken as the number of samples in the dataset. Defaults to None.
+             prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
+                 If None, automatically chosen using tf.data.experimental.AUTOTUNE.
+                 Defaults to None.
+             drop_remainder (Optional[bool], optional): Whether to drop the last batch
+                 when its size is lower than batch_size. Defaults to False.
+
+         Returns:
+             tf.data.Dataset: Prepared dataset
+         """
+         # dict based to tuple based
+         columns = columns or cls.get_ds_column_names(dataset)
+         if not dict_based_fns:
+             dataset = cls.dict_to_tuple(dataset, columns)
+
+         # preprocess + data augmentation
+         if preprocess_fn is not None:
+             dataset = cls.map_ds(dataset, preprocess_fn)
+         if augment_fn is not None:
+             dataset = cls.map_ds(dataset, augment_fn)
+
+         if dict_based_fns and return_tuple:
+             dataset = cls.dict_to_tuple(dataset, columns)
+
+         dataset = dataset.cache()
+
+         # shuffle
+         if shuffle:
+             num_samples = cls.get_dataset_length(dataset)
+             shuffle_buffer_size = (
+                 num_samples if shuffle_buffer_size is None else shuffle_buffer_size
+             )
+             dataset = dataset.shuffle(shuffle_buffer_size)
+         # batch
+         dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+         # prefetch (AUTOTUNE when no buffer size is given)
+         if prefetch_buffer_size is None:
+             prefetch_buffer_size = tf.data.experimental.AUTOTUNE
+         dataset = dataset.prefetch(prefetch_buffer_size)
+         return dataset
+
+     @staticmethod
+     def make_channel_first(input_key: str, dataset: tf.data.Dataset) -> tf.data.Dataset:
+         """Make a tf.data.Dataset channel-first. Make sure that the dataset is not
+         already channel-first: if it is, the transposed tensor will end up with the
+         format (batch_size, x_size, channel, y_size).
+
+         Args:
+             input_key (str): input key of the dict-based tf.data.Dataset
+             dataset (tf.data.Dataset): tf.data.Dataset to make channel-first
+
+         Returns:
+             tf.data.Dataset: Channel-first dataset
+         """
+
+         def channel_first(x):
+             x[input_key] = tf.transpose(x[input_key], perm=[2, 0, 1])
+             return x
+
+         dataset = dataset.map(channel_first)
+         return dataset
+
+     @staticmethod
+     def get_item_length(dataset: tf.data.Dataset) -> int:
+         """Get the length of a dataset element. If an element is a tensor, the length
+         is one; if it is a sequence or mapping (tuple, list or dict), it is len(elem).
+
+         Args:
+             dataset (tf.data.Dataset): Dataset to process
+
+         Returns:
+             int: length of the dataset elems
+         """
+         if isinstance(dataset.element_spec, (tuple, list, dict)):
+             return len(dataset.element_spec)
+         return 1
+
+     @staticmethod
+     def get_dataset_length(dataset: tf.data.Dataset) -> int:
+         """Get the length of a dataset. Try to access it with len(), and if not
+         available, with a reduce op.
+
+         Args:
+             dataset (tf.data.Dataset): Dataset to process
+
+         Returns:
+             int: Number of samples in the dataset
+         """
+         try:
+             return len(dataset)
+         except TypeError:
+             cardinality = dataset.reduce(0, lambda x, _: x + 1)
+             return int(cardinality)
+
+     @staticmethod
+     def get_column_elements_shape(
+         dataset: tf.data.Dataset, column_name: Union[str, int]
+     ) -> tuple:
+         """Get the shape of the elements of a column of dataset identified by
+         column_name
+
+         Args:
+             dataset (tf.data.Dataset): a tf.data.Dataset
+             column_name (Union[str, int]): The column name to get
+                 the element shape from.
+
+         Returns:
+             tuple: the shape of an element from column_name
+         """
+         return tuple(dataset.element_spec[column_name].shape)
+
+     @staticmethod
+     def get_columns_shapes(dataset: tf.data.Dataset) -> dict:
+         """Get the shapes of the elements of all columns of a dataset
+
+         Args:
+             dataset (Dataset): a Dataset
+
+         Returns:
+             dict: dictionary of column names and their corresponding shape (a tuple
+                 of shapes when the dataset is tuple-based)
+         """
+         if isinstance(dataset.element_spec, tuple):
+             shapes = [None for _ in range(len(dataset.element_spec))]
+             for i in range(len(dataset.element_spec)):
+                 try:
+                     shapes[i] = tuple(dataset.element_spec[i].shape)
+                 except AttributeError:
+                     pass
+             shapes = tuple(shapes)
+         elif isinstance(dataset.element_spec, dict):
+             shapes = {}
+             for key in dataset.element_spec.keys():
+                 try:
+                     shapes[key] = tuple(dataset.element_spec[key].shape)
+                 except AttributeError:
+                     pass
+         return shapes
+
+     @staticmethod
+     def get_input_from_dataset_item(elem: ItemType) -> TensorType:
+         """Get the tensor that is to be fed as input to a model from a dataset element.
+
+         Args:
+             elem (ItemType): dataset element to extract input from
+
+         Returns:
+             TensorType: Input tensor
+         """
+         if isinstance(elem, (tuple, list)):
+             tensor = elem[0]
+         elif isinstance(elem, dict):
+             tensor = elem[list(elem.keys())[0]]
+         else:
+             tensor = elem
+         return tensor
+
+     @staticmethod
+     def get_label_from_dataset_item(item: ItemType):
+         """Retrieve the label tensor from an item given as a tuple/list. The label
+         must be at index 1 in the item tuple. If one-hot encoded, labels are
+         converted to a single value.
+
+         Args:
+             item (ItemType): dataset element to extract the label from
+
+         Returns:
+             Any: Label tensor
+         """
+         label = item[1]  # labels must be at index 1 in the item tuple
+         # If labels are one-hot encoded, take the argmax
+         if tf.rank(label) > 1 and label.shape[1] > 1:
+             label = tf.reshape(label, shape=[label.shape[0], -1])
+             label = tf.argmax(label, axis=1)
+         # If labels are in two dimensions, squeeze them
+         if len(label.shape) > 1:
+             label = tf.reshape(label, [label.shape[0]])
+         return label
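
Taken together, the methods above support a load / filter / prepare pipeline. A minimal sketch of the intended flow (hypothetical usage, not part of the diff; the "mnist" identifier, the split and the column names are illustrative):

from oodeel.datasets.tf_data_handler import TFDataHandler

handler = TFDataHandler()

# Returns a dict-based tf.data.Dataset with "image" and "label" columns
ds = handler.load_dataset(
    "mnist", hub="tensorflow-datasets", load_kwargs={"split": "test"}
)

# Keep only the samples whose "label" value is in [0, 1, 2]
ds = TFDataHandler.filter_by_value(ds, "label", [0, 1, 2])

# Cache, batch and prefetch; items come back as (image, label) tuples
ds = TFDataHandler.prepare(ds, batch_size=128, columns=["image", "label"])

for images, labels in ds.take(1):
    print(images.shape, labels.shape)  # e.g. (128, 28, 28, 1) (128,)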
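
The in-memory path goes through load_dataset_from_arrays: a 2-tuple of arrays becomes a dict-based dataset with "input" and "label" columns, which can then be converted back and forth between dict-based and tuple-based form (a sketch with made-up arrays):

import numpy as np

from oodeel.datasets.tf_data_handler import TFDataHandler

x = np.random.rand(16, 32, 32, 3).astype("float32")
y = np.random.randint(0, 10, size=(16,))

# A 2-tuple is mapped to {"input": ..., "label": ...}
ds = TFDataHandler.load_dataset(dataset_id=(x, y))
assert TFDataHandler.get_ds_column_names(ds) == ["input", "label"]

# Round-trip between dict-based and tuple-based representations
ds_tuple = TFDataHandler.dict_to_tuple(ds)
ds_dict = TFDataHandler.tuple_to_dict(ds_tuple, ["input", "label"])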