oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of oodeel might be problematic.

Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/datasets/deprecated/__init__.py (new file)
@@ -0,0 +1,31 @@
+ # -*- coding: utf-8 -*-
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+ # CRIAQ and ANITI - https://www.deel.ai/
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ import warnings
+
+ from .DEPRECATED_ooddataset import OODDataset
+
+ warnings.warn(
+     "The 'OODDataset' object is deprecated and will be removed in a future release.",
+     DeprecationWarning,
+     stacklevel=2,
+ )
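
For illustration, a minimal sketch of what this shim means for downstream code (assuming oodeel 0.3.0 is installed and the subpackage has not been imported yet in the interpreter; the warning-capture dance is standard library behavior):

    import warnings

    # Importing the legacy class from the new 'deprecated' subpackage triggers
    # the module-level DeprecationWarning added above; surface it explicitly,
    # since Python hides DeprecationWarning by default outside of __main__.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from oodeel.datasets.deprecated import OODDataset  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
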
oodeel/datasets/tf_data_handler.py
@@ -22,7 +22,6 @@
  # SOFTWARE.
  from typing import get_args

- import numpy as np
  import tensorflow as tf
  import tensorflow_datasets as tfds

@@ -36,9 +35,10 @@ from .data_handler import DataHandler


  def dict_only_ds(ds_handling_method: Callable) -> Callable:
-     """Decorator to ensure that the dataset is a dict dataset and that the input key
-     matches one of the feature keys. The signature of decorated functions
-     must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+     """Decorator to ensure that the dataset is a dict dataset and that the column_name
+     given as argument matches one of the column names.
+     matches one of the column names. The signature of decorated functions
+     must be function(dataset, *args, **kwargs) with column_name either in kwargs or
      args[0] when relevant.


@@ -52,19 +52,19 @@ def dict_only_ds(ds_handling_method: Callable) -> Callable:
      def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
          assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"

-         if "feature_key" in kwargs.keys():
-             feature_key = kwargs["feature_key"]
+         if "column_name" in kwargs.keys():
+             column_name = kwargs["column_name"]
          elif len(args) > 0:
-             feature_key = args[0]
+             column_name = args[0]

-         # If feature_key is provided, check that it is in the dataset feature keys
-         if (len(args) > 0) or ("feature_key" in kwargs):
-             if isinstance(feature_key, str):
-                 feature_key = [feature_key]
-             for key in feature_key:
+         # If column_name is provided, check that it is in the dataset column names
+         if (len(args) > 0) or ("column_name" in kwargs):
+             if isinstance(column_name, str):
+                 column_name = [column_name]
+             for name in column_name:
                  assert (
-                     key in dataset.element_spec.keys()
-                 ), f"The input dataset has no feature names {key}"
+                     name in dataset.element_spec.keys()
+                 ), f"The input dataset has no column named {name}"
          return ds_handling_method(dataset, *args, **kwargs)

      return wrapper
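
A hedged sketch of what the renamed check buys a decorated helper. It assumes dict_only_ds is importable from oodeel.datasets.tf_data_handler, which the hunk above suggests but the wheel's __init__ may not re-export; first_values is an invented example function:

    import tensorflow as tf
    from oodeel.datasets.tf_data_handler import dict_only_ds

    @dict_only_ds
    def first_values(dataset, column_name):
        # Return the first element of one column of a dict-based dataset.
        return next(iter(dataset.map(lambda x: x[column_name])))

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.zeros((4, 2)), "label": tf.range(4)}
    )
    first_values(ds, "label")    # passes: "label" is a column name
    # first_values(ds, "score")  # would fail the assert: no column named "score"
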
@@ -77,45 +77,54 @@ class TFDataHandler(DataHandler):
      tensorflow syntax.
      """

+     def __init__(self) -> None:
+         super().__init__()
+         self.backend = "tensorflow"
+         self.channel_order = "channels_last"
+
      @classmethod
      def load_dataset(
          cls,
          dataset_id: Union[tf.data.Dataset, ItemType, str],
-         keys: Optional[list] = None,
+         columns: Optional[list] = None,
          load_kwargs: dict = {},
      ) -> tf.data.Dataset:
          """Load dataset from different manners, ensuring to return a dict based
          tf.data.Dataset.

          Args:
-             dataset_id (Any): dataset identification
-             keys (list, optional): Features keys. If None, assigned as "input_i"
-                 for i-th feature. Defaults to None.
+             dataset_id (Union[tf.data.Dataset, ItemType, str]): dataset identification.
+                 Can be the name of a dataset from tensorflow_datasets, a tf.data.Dataset,
+                 or a tuple/dict of np.ndarrays/tf.Tensors.
+             columns (list, optional): Column names. If None, assigned as "input_i"
+                 for i-th column. Defaults to None.
              load_kwargs (dict, optional): Additional args for loading from
                  tensorflow_datasets. Defaults to {}.

          Returns:
              tf.data.Dataset: A dict based tf.data.Dataset
          """
+         load_kwargs["as_supervised"] = False
+
          if isinstance(dataset_id, get_args(ItemType)):
-             dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+             dataset = cls.load_dataset_from_arrays(dataset_id, columns)
          elif isinstance(dataset_id, tf.data.Dataset):
-             dataset = cls.load_custom_dataset(dataset_id, keys)
+             dataset = cls.load_custom_dataset(dataset_id, columns)
          elif isinstance(dataset_id, str):
              dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
          return dataset

      @staticmethod
      def load_dataset_from_arrays(
-         dataset_id: ItemType, keys: Optional[list] = None
+         dataset_id: ItemType, columns: Optional[list] = None
      ) -> tf.data.Dataset:
          """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
-         of np.ndarrays/td.Tensors.
+         of np.ndarrays/tf.Tensors.

          Args:
              dataset_id (ItemType): numpy array(s) to load.
-             keys (list, optional): Features keys. If None, assigned as "input_i"
-                 for i-th feature. Defaults to None.
+             columns (list, optional): Column names to assign. If None,
+                 assigned as "input_i" for i-th column. Defaults to None.

          Returns:
              tf.data.Dataset
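
A usage sketch of the renamed keys -> columns argument (the toy arrays are invented; load_dataset is a classmethod per the hunk above):

    import numpy as np
    from oodeel.datasets.tf_data_handler import TFDataHandler

    x = np.random.rand(100, 32, 32, 3).astype("float32")
    y = np.random.randint(0, 10, size=100)

    # Tuple input: with columns=None the two columns would be named
    # "input" and "label"; here they are named explicitly instead.
    ds = TFDataHandler.load_dataset((x, y), columns=["input", "label"])
    print(ds.element_spec.keys())  # dict_keys(['input', 'label'])
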
@@ -127,7 +136,7 @@ class TFDataHandler(DataHandler):
          # If dataset_id is a tuple, convert it to a dict
          elif isinstance(dataset_id, tuple):
              len_elem = len(dataset_id)
-             if keys is None:
+             if columns is None:
                  if len_elem == 2:
                      dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
                  else:
@@ -142,19 +151,19 @@ class TFDataHandler(DataHandler):
                  )
              else:
                  assert (
-                     len(keys) == len_elem
-                 ), "Number of keys mismatch with the number of features"
-                 dataset_dict = {keys[i]: dataset_id[i] for i in range(len_elem)}
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"
+                 dataset_dict = {columns[i]: dataset_id[i] for i in range(len_elem)}

          elif isinstance(dataset_id, dict):
-             if keys is not None:
+             if columns is not None:
                  len_elem = len(dataset_id)
                  assert (
-                     len(keys) == len_elem
-                 ), "Number of keys mismatch with the number of features"
-                 original_keys = list(dataset_id.keys())
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"
+                 original_columns = list(dataset_id.keys())
                  dataset_dict = {
-                     keys[i]: dataset_id[original_keys[i]] for i in range(len_elem)
+                     columns[i]: dataset_id[original_columns[i]] for i in range(len_elem)
                  }

          dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
@@ -162,14 +171,15 @@ class TFDataHandler(DataHandler):

      @classmethod
      def load_custom_dataset(
-         cls, dataset_id: tf.data.Dataset, keys: Optional[list] = None
+         cls, dataset_id: tf.data.Dataset, columns: Optional[list] = None
      ) -> tf.data.Dataset:
          """Load a custom Dataset by ensuring it has the correct format (dict-based)

          Args:
              dataset_id (tf.data.Dataset): tf.data.Dataset
-             keys (list, optional): Features keys. If None, assigned as "input_i"
-                 for i-th feature. Defaults to None.
+             columns (list, optional): Column names to use for elements if dataset_id is
+                 tuple based. If None, assigned as "input_i"
+                 for i-th column. Defaults to None.

          Returns:
              tf.data.Dataset
@@ -177,22 +187,22 @@ class TFDataHandler(DataHandler):
          # If dataset_id is a tuple based tf.data.dataset, convert it to a dict
          if not isinstance(dataset_id.element_spec, dict):
              len_elem = len(dataset_id.element_spec)
-             if keys is None:
+             if columns is None:
                  print(
-                     "Feature name not found, assigning 'input_i' "
+                     "Column name not found, assigning 'input_i' "
                      "key to the i-th tensor and 'label' key to the last"
                  )
                  if len_elem == 2:
-                     keys = ["input", "label"]
+                     columns = ["input", "label"]
                  else:
-                     keys = [f"input_{i}" for i in range(len_elem)]
-                     keys[-1] = "label"
+                     columns = [f"input_{i}" for i in range(len_elem)]
+                     columns[-1] = "label"
              else:
                  assert (
-                     len(keys) == len_elem
-                 ), "Number of keys mismatch with the number of features"
+                     len(columns) == len_elem
+                 ), "Number of column names mismatch with the number of columns"

-             dataset_id = cls.tuple_to_dict(dataset_id, keys)
+             dataset_id = cls.tuple_to_dict(dataset_id, columns)

          dataset = dataset_id
          return dataset
@@ -221,30 +231,30 @@ class TFDataHandler(DataHandler):
      @staticmethod
      @dict_only_ds
      def dict_to_tuple(
-         dataset: tf.data.Dataset, keys: Optional[list] = None
+         dataset: tf.data.Dataset, columns: Optional[list] = None
      ) -> tf.data.Dataset:
          """Turn a dict based tf.data.Dataset to a tuple based tf.data.Dataset

          Args:
              dataset (tf.data.Dataset): Dict based tf.data.Dataset
-             keys (list, optional): Features to use for the tuples based
-                 tf.data.Dataset. If None, takes all the features. Defaults to None.
+             columns (list, optional): Columns to use for the tuples based
+                 tf.data.Dataset. If None, takes all the columns. Defaults to None.

          Returns:
              tf.data.Dataset
          """
-         if keys is None:
-             keys = list(dataset.element_spec.keys())
-         dataset = dataset.map(lambda x: tuple(x[k] for k in keys))
+         if columns is None:
+             columns = list(dataset.element_spec.keys())
+         dataset = dataset.map(lambda x: tuple(x[k] for k in columns))
          return dataset

      @staticmethod
-     def tuple_to_dict(dataset: tf.data.Dataset, keys: list) -> tf.data.Dataset:
+     def tuple_to_dict(dataset: tf.data.Dataset, columns: list) -> tf.data.Dataset:
          """Turn a tuple based tf.data.Dataset to a dict based tf.data.Dataset

          Args:
              dataset (tf.data.Dataset): Tuple based tf.data.Dataset
-             keys (list): Keys to use for the dict based tf.data.Dataset
+             columns (list): Column names to use for the dict based tf.data.Dataset

          Returns:
              tf.data.Dataset
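
A round-trip sketch of the two renamed converters (toy tensors; both are staticmethods per the hunks above):

    import tensorflow as tf
    from oodeel.datasets.tf_data_handler import TFDataHandler

    tuple_ds = tf.data.Dataset.from_tensor_slices((tf.zeros((10, 4)), tf.range(10)))

    # tuple -> dict: column names are mandatory here
    dict_ds = TFDataHandler.tuple_to_dict(tuple_ds, ["input", "label"])
    assert isinstance(dict_ds.element_spec, dict)

    # dict -> tuple: pass a subset of the columns, or None for all of them
    back = TFDataHandler.dict_to_tuple(dict_ds, ["input", "label"])
    print(back.element_spec)
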
@@ -254,86 +264,28 @@ class TFDataHandler(DataHandler):
          ), "dataset elements must be tuples"
          len_elem = len(dataset.element_spec)
          assert len_elem == len(
-             keys
-         ), "The number of keys must be equal to the number of tuple elements"
+             columns
+         ), "The number of columns must be equal to the number of tuple elements"

          def tuple_to_dict(*inputs):
-             return {keys[i]: inputs[i] for i in range(len_elem)}
+             return {columns[i]: inputs[i] for i in range(len_elem)}

          dataset = dataset.map(tuple_to_dict)
          return dataset

-     @staticmethod
-     def assign_feature_value(
-         dataset: tf.data.Dataset, feature_key: str, value: int
-     ) -> tf.data.Dataset:
-         """Assign a value to a feature for every sample in a tf.data.Dataset
-
-         Args:
-             dataset (tf.data.Dataset): tf.data.Dataset to assign the value to
-             feature_key (str): Feature to assign the value to
-             value (int): Value to assign
-
-         Returns:
-             tf.data.Dataset
-         """
-         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
-
-         def assign_value_to_feature(x):
-             x[feature_key] = value
-             return x
-
-         dataset = dataset.map(assign_value_to_feature)
-         return dataset
-
-     @staticmethod
-     @dict_only_ds
-     def get_feature_from_ds(dataset: tf.data.Dataset, feature_key: str) -> np.ndarray:
-         """Get a feature from a tf.data.Dataset
-
-         !!! note
-             This function can be a bit time consuming since it needs to iterate
-             over the whole dataset.
-
-         Args:
-             dataset (tf.data.Dataset): tf.data.Dataset to get the feature from
-             feature_key (str): Feature value to get
-
-         Returns:
-             np.ndarray: Feature values for dataset
-         """
-         features = dataset.map(lambda x: x[feature_key])
-         features = list(features.as_numpy_iterator())
-         features = np.array(features)
-         return features
-
      @staticmethod
      @dict_only_ds
-     def get_ds_feature_keys(dataset: tf.data.Dataset) -> list:
-         """Get the feature keys of a tf.data.Dataset
+     def get_ds_column_names(dataset: tf.data.Dataset) -> list:
+         """Get the column names of a tf.data.Dataset

          Args:
-             dataset (tf.data.Dataset): tf.data.Dataset to get the feature keys from
+             dataset (tf.data.Dataset): tf.data.Dataset to get the column names from

          Returns:
-             list: List of feature keys
+             list: List of column names
          """
          return list(dataset.element_spec.keys())

-     @staticmethod
-     def has_feature_key(dataset: tf.data.Dataset, key: str) -> bool:
-         """Check if a tf.data.Dataset has a feature denoted by key
-
-         Args:
-             dataset (tf.data.Dataset): tf.data.Dataset to check
-             key (str): Key to check
-
-         Returns:
-             bool: If the tf.data.Dataset has a feature denoted by key
-         """
-         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
-         return True if (key in dataset.element_spec.keys()) else False
-
      @staticmethod
      def map_ds(
          dataset: tf.data.Dataset,
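
Code that relied on the three removed helpers (assign_feature_value, get_feature_from_ds, has_feature_key) can fall back on plain tf.data operations; a hedged sketch of rough equivalents, with invented column names:

    import numpy as np
    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.zeros((8, 2)), "label": tf.range(8)}
    )

    # assign_feature_value: add or overwrite a constant column on every element
    flagged = ds.map(lambda x: {**x, "ood_label": -1})

    # has_feature_key: inspect the element_spec directly
    has_label = "label" in ds.element_spec

    # get_feature_from_ds: materialize one column as a numpy array
    labels = np.array(list(ds.map(lambda x: x["label"]).as_numpy_iterator()))
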
@@ -358,35 +310,35 @@ class TFDataHandler(DataHandler):

      @staticmethod
      @dict_only_ds
-     def filter_by_feature_value(
+     def filter_by_value(
          dataset: tf.data.Dataset,
-         feature_key: str,
+         column_name: str,
          values: list,
          excluded: bool = False,
      ) -> tf.data.Dataset:
-         """Filter a tf.data.Dataset by checking the value of a feature is in 'values'
+         """Filter a tf.data.Dataset by checking if the value of a column is in 'values'

          Args:
              dataset (tf.data.Dataset): tf.data.Dataset to filter
-             feature_key (str): Feature name to check the value
-             values (list): Feature_key values to keep (if excluded is False)
+             column_name (str): Column to filter the dataset with
+             values (list): Column values to keep (if excluded is False)
                  or to exclude
              excluded (bool, optional): To keep (False) or exclude (True) the samples
-                 with Feature_key value included in Values. Defaults to False.
+                 with Column values included in Values. Defaults to False.

          Returns:
              tf.data.Dataset: Filtered dataset
          """
          # If the labels are one-hot encoded, prepare a function to get the label as int
-         if len(dataset.element_spec[feature_key].shape) > 0:
+         if len(dataset.element_spec[column_name].shape) > 0:

              def get_label_int(elem):
-                 return int(tf.argmax(elem[feature_key]))
+                 return int(tf.argmax(elem[column_name]))

          else:

              def get_label_int(elem):
-                 return elem[feature_key]
+                 return elem[column_name]

          def filter_fn(elem):
              value = get_label_int(elem)
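
A sketch of the renamed filter in its typical ID/OOD split role (toy labels; the signature follows the hunk above):

    import tensorflow as tf
    from oodeel.datasets.tf_data_handler import TFDataHandler

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.zeros((6, 2)), "label": tf.constant([0, 1, 2, 0, 1, 2])}
    )

    # Keep classes 0 and 1 as in-distribution ...
    id_ds = TFDataHandler.filter_by_value(ds, "label", [0, 1])
    # ... and take the complement as out-of-distribution
    ood_ds = TFDataHandler.filter_by_value(ds, "label", [0, 1], excluded=True)
    print(len(list(id_ds)), len(list(ood_ds)))  # 4 2
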
@@ -400,15 +352,16 @@ class TFDataHandler(DataHandler):
          return dataset_to_filter

      @classmethod
-     def prepare_for_training(
+     def prepare(
          cls,
          dataset: tf.data.Dataset,
          batch_size: int,
-         shuffle: bool = False,
          preprocess_fn: Optional[Callable] = None,
          augment_fn: Optional[Callable] = None,
-         output_keys: Optional[list] = None,
-         dict_based_fns: bool = False,
+         columns: Optional[list] = None,
+         shuffle: bool = False,
+         dict_based_fns: bool = True,
+         return_tuple: bool = True,
          shuffle_buffer_size: Optional[int] = None,
          prefetch_buffer_size: Optional[int] = None,
          drop_remainder: Optional[bool] = False,
@@ -418,16 +371,18 @@ class TFDataHandler(DataHandler):
          Args:
              dataset (tf.data.Dataset): tf.data.Dataset to prepare
              batch_size (int): Batch size
-             shuffle (bool, optional): To shuffle the returned dataset or not.
-                 Defaults to False.
-             preprocess_fn (Callable, optional): Preprocessing function to apply to\
+             preprocess_fn (Callable, optional): Preprocessing function to apply to
                  the dataset. Defaults to None.
-             augment_fn (Callable, optional): Augment function to be used (when the\
+             augment_fn (Callable, optional): Augment function to be used (when the
                  returned dataset is to be used for training). Defaults to None.
-             output_keys (list, optional): List of keys corresponding to the features
-                 that will be returned. Keep all features if None. Defaults to None.
-             dict_based_fns (bool, optional): If the augment and preprocess functions are
-                 dict based or not. Defaults to False.
+             columns (list, optional): List of column names corresponding to the columns
+                 that will be returned. Keep all columns if None. Defaults to None.
+             shuffle (bool, optional): To shuffle the returned dataset or not.
+                 Defaults to False.
+             dict_based_fns (bool): Whether to use preprocess and DA functions as dict
+                 based (if True) or as tuple based (if False). Defaults to True.
+             return_tuple (bool, optional): Whether to return each dataset item
+                 as a tuple. Defaults to True.
              shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
                  taken as the number of samples in the dataset. Defaults to None.
              prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
@@ -440,9 +395,9 @@
              tf.data.Dataset: Prepared dataset
          """
          # dict based to tuple based
-         output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+         columns = columns or cls.get_ds_column_names(dataset)
          if not dict_based_fns:
-             dataset = cls.dict_to_tuple(dataset, output_keys)
+             dataset = cls.dict_to_tuple(dataset, columns)

          # preprocess + DA
          if preprocess_fn is not None:
@@ -450,8 +405,8 @@
          if augment_fn is not None:
              dataset = cls.map_ds(dataset, augment_fn)

-         if dict_based_fns:
-             dataset = cls.dict_to_tuple(dataset, output_keys)
+         if dict_based_fns and return_tuple:
+             dataset = cls.dict_to_tuple(dataset, columns)

          dataset = dataset.cache()

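Putting the three prepare hunks together, a hedged usage sketch of the new signature (dict-based preprocess_fn is now the default, and batched tuples come back when return_tuple=True; the toy dataset is invented):

    import tensorflow as tf
    from oodeel.datasets.tf_data_handler import TFDataHandler

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.random.uniform((64, 32, 32, 3)), "label": tf.range(64) % 10}
    )

    def preprocess(item):
        # dict-based (the new default): receive and return the dict
        item["input"] = item["input"] / 255.0
        return item

    train_ds = TFDataHandler.prepare(
        ds, batch_size=16, preprocess_fn=preprocess, shuffle=True
    )
    for x, y in train_ds.take(1):
        print(x.shape, y.shape)  # (16, 32, 32, 3) (16,)
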
@@ -598,19 +553,21 @@ class TFDataHandler(DataHandler):
          return int(cardinality)

      @staticmethod
-     def get_feature_shape(
-         dataset: tf.data.Dataset, feature_key: Union[str, int]
+     def get_column_elements_shape(
+         dataset: tf.data.Dataset, column_name: Union[str, int]
      ) -> tuple:
-         """Get the shape of a feature of dataset identified by feature_key
+         """Get the shape of the elements of a column of dataset identified by
+         column_name

          Args:
              dataset (tf.data.Dataset): a tf.data.dataset
-             feature_key (Union[str, int]): The identifier of the feature
+             column_name (Union[str, int]): The column name to get
+                 the element shape from.

          Returns:
-             tuple: the shape of feature_id
+             tuple: the shape of an element from column_name
          """
-         return tuple(dataset.element_spec[feature_key].shape)
+         return tuple(dataset.element_spec[column_name].shape)

      @staticmethod
      def get_input_from_dataset_item(elem: ItemType) -> TensorType:
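
A one-liner sketch of the renamed shape helper (shapes are read from element_spec, so nothing is iterated):

    import tensorflow as tf
    from oodeel.datasets.tf_data_handler import TFDataHandler

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.zeros((5, 28, 28, 1)), "label": tf.range(5)}
    )
    print(TFDataHandler.get_column_elements_shape(ds, "input"))  # (28, 28, 1)
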
@@ -650,22 +607,3 @@ class TFDataHandler(DataHandler):
          if len(label.shape) > 1:
              label = tf.reshape(label, [label.shape[0]])
          return label
-
-     @staticmethod
-     def get_feature(
-         dataset: tf.data.Dataset, feature_key: Union[str, int]
-     ) -> tf.data.Dataset:
-         """Extract a feature from a dataset
-
-         Args:
-             dataset (tf.data.Dataset): Dataset to extract the feature from
-             feature_key (Union[str, int]): feature to extract
-
-         Returns:
-             tf.data.Dataset: dataset built with the extracted feature only
-         """
-
-         def _get_feature_elem(elem):
-             return elem[feature_key]
-
-         return dataset.map(_get_feature_elem)
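
The removed get_feature has no direct replacement in this file; a plain map over the dict-based dataset does the same job, e.g. (a sketch):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(
        {"input": tf.zeros((5, 2)), "label": tf.range(5)}
    )
    inputs_only = ds.map(lambda elem: elem["input"])  # single-column dataset
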