oodeel 0.1.1-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of oodeel was flagged as possibly problematic.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/oodeel/datasets/deprecated/__init__.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import warnings
+
+from .DEPRECATED_ooddataset import OODDataset
+
+warnings.warn(
+    "The 'OODDataset' object is deprecated and will be removed in a future release.",
+    DeprecationWarning,
+    stacklevel=2,
+)
```
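The shim above warns once, at import time. A minimal stdlib sketch of that mechanism (it reproduces the `warnings.warn` call shown in the hunk rather than importing oodeel itself):

```python
import warnings

# DeprecationWarning is hidden by default outside __main__, so record
# warnings explicitly to observe the module-level call.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(
        "The 'OODDataset' object is deprecated and will be removed"
        " in a future release.",
        DeprecationWarning,
        stacklevel=2,
    )

assert issubclass(caught[0].category, DeprecationWarning)
print(caught[0].message)
```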
```diff
--- a/oodeel/datasets/tf_data_handler.py
+++ b/oodeel/datasets/tf_data_handler.py
@@ -22,7 +22,6 @@
 # SOFTWARE.
 from typing import get_args
 
-import numpy as np
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
```
```diff
@@ -36,9 +35,10 @@ from .data_handler import DataHandler
 
 
 def dict_only_ds(ds_handling_method: Callable) -> Callable:
-    """Decorator to ensure that the dataset is a dict dataset and that the feature_key
-    matches one of the feature keys. The signature of decorated functions
-    must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+    """Decorator to ensure that the dataset is a dict dataset and that the column_name
+    given as argument matches one of the column names.
+    matches one of the column names. The signature of decorated functions
+    must be function(dataset, *args, **kwargs) with column_name either in kwargs or
     args[0] when relevant.
 
 
```
```diff
@@ -52,19 +52,19 @@ def dict_only_ds(ds_handling_method: Callable) -> Callable:
     def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
         assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
 
-        if "feature_key" in kwargs.keys():
-            feature_key = kwargs["feature_key"]
+        if "column_name" in kwargs.keys():
+            column_name = kwargs["column_name"]
         elif len(args) > 0:
-            feature_key = args[0]
+            column_name = args[0]
 
-        # If feature_key is provided, check that it is in the dataset feature keys
-        if (len(args) > 0) or ("feature_key" in kwargs):
-            if isinstance(feature_key, str):
-                feature_key = [feature_key]
-            for key in feature_key:
+        # If column_name is provided, check that it is in the dataset column names
+        if (len(args) > 0) or ("column_name" in kwargs):
+            if isinstance(column_name, str):
+                column_name = [column_name]
+            for name in column_name:
                 assert (
-                    key in dataset.element_spec.keys()
-                ), f"The input dataset has no feature named {key}"
+                    name in dataset.element_spec.keys()
+                ), f"The input dataset has no column named {name}"
         return ds_handling_method(dataset, *args, **kwargs)
 
     return wrapper
```
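The decorator enforces a small contract: decorated functions take `(dataset, *args, **kwargs)`, and `column_name` is validated against `dataset.element_spec` before dispatch. A condensed, standalone sketch of the same pattern (not the library code itself):

```python
import tensorflow as tf

def dict_only_ds(fn):
    # Same contract as the hunk above: validate column_name (str or list
    # of str) against the dict element_spec, then dispatch.
    def wrapper(dataset, *args, **kwargs):
        assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
        names = kwargs.get("column_name", args[0] if args else None)
        if names is not None:
            for name in [names] if isinstance(names, str) else names:
                assert name in dataset.element_spec, f"no column named {name}"
        return fn(dataset, *args, **kwargs)
    return wrapper

@dict_only_ds
def peek(dataset, column_name):
    return next(iter(dataset))[column_name]

ds = tf.data.Dataset.from_tensor_slices({"input": [[0.0, 1.0]], "label": [3]})
print(peek(ds, "input"))  # tf.Tensor([0. 1.], ...)
# peek(ds, "logits")      # AssertionError: no column named logits
```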
```diff
@@ -77,45 +77,54 @@ class TFDataHandler(DataHandler):
     tensorflow syntax.
     """
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.backend = "tensorflow"
+        self.channel_order = "channels_last"
+
     @classmethod
     def load_dataset(
         cls,
         dataset_id: Union[tf.data.Dataset, ItemType, str],
-        keys: Optional[list] = None,
+        columns: Optional[list] = None,
         load_kwargs: dict = {},
     ) -> tf.data.Dataset:
         """Load dataset from different manners, ensuring to return a dict based
         tf.data.Dataset.
 
         Args:
-            dataset_id (ItemType): dataset identification
-            keys (list, optional): Features keys. If None, assigned as "input_i"
-                for i-th feature. Defaults to None.
+            dataset_id (Union[tf.data.Dataset, ItemType, str]): dataset identification.
+                Can be the name of a dataset from tensorflow_datasets, a tf.data.Dataset,
+                or a tuple/dict of np.ndarrays/tf.Tensors.
+            columns (list, optional): Column names. If None, assigned as "input_i"
+                for i-th column. Defaults to None.
             load_kwargs (dict, optional): Additional args for loading from
                 tensorflow_datasets. Defaults to {}.
 
         Returns:
             tf.data.Dataset: A dict based tf.data.Dataset
         """
+        load_kwargs["as_supervised"] = False
+
         if isinstance(dataset_id, get_args(ItemType)):
-            dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
         elif isinstance(dataset_id, tf.data.Dataset):
-            dataset = cls.load_custom_dataset(dataset_id, keys)
+            dataset = cls.load_custom_dataset(dataset_id, columns)
         elif isinstance(dataset_id, str):
             dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
         return dataset
 
     @staticmethod
     def load_dataset_from_arrays(
-        dataset_id: ItemType, keys: Optional[list] = None
+        dataset_id: ItemType, columns: Optional[list] = None
     ) -> tf.data.Dataset:
         """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
-        of np.ndarrays/tf.Tensors
+        of np.ndarrays/tf.Tensors.
 
         Args:
             dataset_id (ItemType): numpy array(s) to load.
-            keys (list, optional): Features keys. If None, assigned as "input_i"
-                for i-th feature. Defaults to None.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for i-th column. Defaults to None.
 
         Returns:
             tf.data.Dataset
```
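With the rename, array loading takes a `columns` argument where 0.1.1 appears to have taken `keys`. A hedged usage sketch; the import path `oodeel.datasets.tf_data_handler` is an assumption based on the file list above, not something this diff confirms:

```python
import numpy as np

from oodeel.datasets.tf_data_handler import TFDataHandler  # path assumed

x = np.random.rand(10, 32, 32, 3).astype("float32")
y = np.random.randint(0, 5, size=(10,))

# Explicit column names (formerly `keys`):
ds = TFDataHandler.load_dataset_from_arrays((x, y), columns=["input", "label"])
print(ds.element_spec)  # {'input': TensorSpec(...), 'label': TensorSpec(...)}

# With columns=None, a 2-tuple falls back to {"input": ..., "label": ...}.
ds_default = TFDataHandler.load_dataset_from_arrays((x, y))
```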
```diff
@@ -127,7 +136,7 @@ class TFDataHandler(DataHandler):
         # If dataset_id is a tuple, convert it to a dict
         elif isinstance(dataset_id, tuple):
             len_elem = len(dataset_id)
-            if keys is None:
+            if columns is None:
                 if len_elem == 2:
                     dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
                 else:
```
```diff
@@ -142,19 +151,19 @@ class TFDataHandler(DataHandler):
                 )
             else:
                 assert (
-                    len(keys) == len_elem
-                ), "Number of keys mismatch with the number of features"
-                dataset_dict = {keys[i]: dataset_id[i] for i in range(len_elem)}
+                    len(columns) == len_elem
+                ), "Number of column names mismatch with the number of columns"
+                dataset_dict = {columns[i]: dataset_id[i] for i in range(len_elem)}
 
         elif isinstance(dataset_id, dict):
-            if keys is not None:
+            if columns is not None:
                 len_elem = len(dataset_id)
                 assert (
-                    len(keys) == len_elem
-                ), "Number of keys mismatch with the number of features"
-                original_keys = list(dataset_id.keys())
+                    len(columns) == len_elem
+                ), "Number of column names mismatch with the number of columns"
+                original_columns = list(dataset_id.keys())
                 dataset_dict = {
-                    keys[i]: dataset_id[original_keys[i]] for i in range(len_elem)
+                    columns[i]: dataset_id[original_columns[i]] for i in range(len_elem)
                 }
 
         dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
```
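For dict inputs, `columns` renames the original keys positionally, exactly as the comprehension above shows. In plain Python:

```python
import numpy as np

dataset_id = {"image": np.zeros((4, 8)), "target": np.arange(4)}
columns = ["input", "label"]

# Positional renaming, as in the hunk above:
original_columns = list(dataset_id.keys())
dataset_dict = {
    columns[i]: dataset_id[original_columns[i]] for i in range(len(dataset_id))
}
assert list(dataset_dict) == ["input", "label"]
```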
```diff
@@ -162,14 +171,15 @@ class TFDataHandler(DataHandler):
 
     @classmethod
     def load_custom_dataset(
-        cls, dataset_id: tf.data.Dataset, keys: Optional[list] = None
+        cls, dataset_id: tf.data.Dataset, columns: Optional[list] = None
     ) -> tf.data.Dataset:
         """Load a custom Dataset by ensuring it has the correct format (dict-based)
 
         Args:
             dataset_id (tf.data.Dataset): tf.data.Dataset
-            keys (list, optional): Feature keys to use for elements if dataset_id is
-                tuple based. If None, assigned as "input_i" for i-th feature.
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for i-th column. Defaults to None.
 
         Returns:
             tf.data.Dataset
```
```diff
@@ -177,22 +187,22 @@ class TFDataHandler(DataHandler):
         # If dataset_id is a tuple based tf.data.dataset, convert it to a dict
         if not isinstance(dataset_id.element_spec, dict):
             len_elem = len(dataset_id.element_spec)
-            if keys is None:
+            if columns is None:
                 print(
-                    "Feature keys not found, assigning 'input_i' "
+                    "Column name not found, assigning 'input_i' "
                     "key to the i-th tensor and 'label' key to the last"
                 )
                 if len_elem == 2:
-                    keys = ["input", "label"]
+                    columns = ["input", "label"]
                 else:
-                    keys = [f"input_{i}" for i in range(len_elem)]
-                    keys[-1] = "label"
+                    columns = [f"input_{i}" for i in range(len_elem)]
+                    columns[-1] = "label"
             else:
                 assert (
-                    len(keys) == len_elem
-                ), "Number of keys mismatch with the number of features"
+                    len(columns) == len_elem
+                ), "Number of column names mismatch with the number of columns"
 
-            dataset_id = cls.tuple_to_dict(dataset_id, keys)
+            dataset_id = cls.tuple_to_dict(dataset_id, columns)
 
         dataset = dataset_id
         return dataset
```
```diff
@@ -221,30 +231,30 @@ class TFDataHandler(DataHandler):
     @staticmethod
     @dict_only_ds
     def dict_to_tuple(
-        dataset: tf.data.Dataset, keys: Optional[list] = None
+        dataset: tf.data.Dataset, columns: Optional[list] = None
     ) -> tf.data.Dataset:
         """Turn a dict based tf.data.Dataset to a tuple based tf.data.Dataset
 
         Args:
             dataset (tf.data.Dataset): Dict based tf.data.Dataset
-            keys (list, optional): Keys to use for the tuples based
-                tf.data.Dataset. If None, takes all the keys. Defaults to None.
+            columns (list, optional): Columns to use for the tuples based
+                tf.data.Dataset. If None, takes all the columns. Defaults to None.
 
         Returns:
             tf.data.Dataset
         """
-        if keys is None:
-            keys = list(dataset.element_spec.keys())
-        dataset = dataset.map(lambda x: tuple(x[k] for k in keys))
+        if columns is None:
+            columns = list(dataset.element_spec.keys())
+        dataset = dataset.map(lambda x: tuple(x[k] for k in columns))
         return dataset
 
     @staticmethod
-    def tuple_to_dict(dataset: tf.data.Dataset, keys: list) -> tf.data.Dataset:
+    def tuple_to_dict(dataset: tf.data.Dataset, columns: list) -> tf.data.Dataset:
         """Turn a tuple based tf.data.Dataset to a dict based tf.data.Dataset
 
         Args:
             dataset (tf.data.Dataset): Tuple based tf.data.Dataset
-            keys (list): Keys to use for the dict based tf.data.Dataset
+            columns (list): Column names to use for the dict based tf.data.Dataset
 
         Returns:
             tf.data.Dataset
```
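The dict→tuple direction is a one-line `map`. A standalone tf.data sketch of what `dict_to_tuple` does with the selected columns:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    {"input": [[0.0, 1.0], [2.0, 3.0]], "label": [0, 1]}
)
columns = ["input", "label"]

# Same mapping as dict_to_tuple above:
tuple_ds = ds.map(lambda x: tuple(x[k] for k in columns))
print(tuple_ds.element_spec)  # (TensorSpec(shape=(2,), ...), TensorSpec(shape=(), ...))
```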
```diff
@@ -254,86 +264,28 @@ class TFDataHandler(DataHandler):
         ), "dataset elements must be tuples"
         len_elem = len(dataset.element_spec)
         assert len_elem == len(
-            keys
-        ), "The number of keys must be equal to the number of tuple elements"
+            columns
+        ), "The number of columns must be equal to the number of tuple elements"
 
         def tuple_to_dict(*inputs):
-            return {keys[i]: inputs[i] for i in range(len_elem)}
+            return {columns[i]: inputs[i] for i in range(len_elem)}
 
         dataset = dataset.map(tuple_to_dict)
         return dataset
 
-    @staticmethod
-    def assign_feature_value(
-        dataset: tf.data.Dataset, feature_key: str, value: int
-    ) -> tf.data.Dataset:
-        """Assign a value to a feature for every sample in a tf.data.Dataset
-
-        Args:
-            dataset (tf.data.Dataset): tf.data.Dataset to assign the value to
-            feature_key (str): Feature to assign the value to
-            value (int): Value to assign
-
-        Returns:
-            tf.data.Dataset
-        """
-        assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
-
-        def assign_value_to_feature(x):
-            x[feature_key] = value
-            return x
-
-        dataset = dataset.map(assign_value_to_feature)
-        return dataset
-
-    @staticmethod
-    @dict_only_ds
-    def get_feature_from_ds(dataset: tf.data.Dataset, feature_key: str) -> np.ndarray:
-        """Get a feature from a tf.data.Dataset
-
-        !!! note
-            This function can be a bit time consuming since it needs to iterate
-            over the whole dataset.
-
-        Args:
-            dataset (tf.data.Dataset): tf.data.Dataset to get the feature from
-            feature_key (str): Feature value to get
-
-        Returns:
-            np.ndarray: Feature values for dataset
-        """
-        features = dataset.map(lambda x: x[feature_key])
-        features = list(features.as_numpy_iterator())
-        features = np.array(features)
-        return features
-
     @staticmethod
     @dict_only_ds
-    def get_ds_feature_keys(dataset: tf.data.Dataset) -> list:
-        """Get the feature keys of a tf.data.Dataset
+    def get_ds_column_names(dataset: tf.data.Dataset) -> list:
+        """Get the column names of a tf.data.Dataset
 
         Args:
-            dataset (tf.data.Dataset): tf.data.Dataset to get the feature keys from
+            dataset (tf.data.Dataset): tf.data.Dataset to get the column names from
 
         Returns:
-            list: List of feature keys
+            list: List of column names
         """
         return list(dataset.element_spec.keys())
 
-    @staticmethod
-    def has_feature_key(dataset: tf.data.Dataset, key: str) -> bool:
-        """Check if a tf.data.Dataset has a feature denoted by key
-
-        Args:
-            dataset (tf.data.Dataset): tf.data.Dataset to check
-            key (str): Key to check
-
-        Returns:
-            bool: If the tf.data.Dataset has a feature denoted by key
-        """
-        assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
-        return True if (key in dataset.element_spec.keys()) else False
-
     @staticmethod
     def map_ds(
         dataset: tf.data.Dataset,
```
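Three helpers (`assign_feature_value`, `get_feature_from_ds`, `has_feature_key`) are deleted outright. For anyone migrating, their plain tf.data equivalents are thin one-liners:

```python
import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    {"input": np.zeros((4, 2)), "label": np.arange(4)}
)

# has_feature_key(ds, "label") reduces to:
print("label" in ds.element_spec)  # True

# get_feature_from_ds(ds, "label") reduces to:
labels = np.array(list(ds.map(lambda x: x["label"]).as_numpy_iterator()))
print(labels)  # [0 1 2 3]

# assign_feature_value(ds, "label", 0) reduces to:
relabeled = ds.map(lambda x: {**x, "label": 0})
```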
```diff
@@ -358,35 +310,35 @@ class TFDataHandler(DataHandler):
 
     @staticmethod
     @dict_only_ds
-    def filter_by_feature_value(
+    def filter_by_value(
         dataset: tf.data.Dataset,
-        feature_key: str,
+        column_name: str,
         values: list,
         excluded: bool = False,
     ) -> tf.data.Dataset:
-        """Filter a tf.data.Dataset by checking the value of a feature is in 'values'
+        """Filter a tf.data.Dataset by checking if the value of a column is in 'values'
 
         Args:
             dataset (tf.data.Dataset): tf.data.Dataset to filter
-            feature_key (str): Feature key to filter the dataset with
-            values (list): Feature values to keep (if excluded is False)
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep (if excluded is False)
                 or to exclude
             excluded (bool, optional): To keep (False) or exclude (True) the samples
-                with feature values included in Values. Defaults to False.
+                with Column values included in Values. Defaults to False.
 
         Returns:
             tf.data.Dataset: Filtered dataset
         """
         # If the labels are one-hot encoded, prepare a function to get the label as int
-        if len(dataset.element_spec[feature_key].shape) > 0:
+        if len(dataset.element_spec[column_name].shape) > 0:
 
             def get_label_int(elem):
-                return int(tf.argmax(elem[feature_key]))
+                return int(tf.argmax(elem[column_name]))
 
         else:
 
             def get_label_int(elem):
-                return elem[feature_key]
+                return elem[column_name]
 
         def filter_fn(elem):
             value = get_label_int(elem)
```
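A hedged usage sketch of the renamed filter (import path assumed from the file list above; one-hot labels are handled through the `tf.argmax` branch shown in the hunk):

```python
import tensorflow as tf

from oodeel.datasets.tf_data_handler import TFDataHandler  # path assumed

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.random.uniform((6, 2)), "label": [0, 1, 2, 0, 1, 2]}
)

# Keep only samples whose "label" value is 0 or 1:
in_dist = TFDataHandler.filter_by_value(ds, "label", [0, 1])
# excluded=True inverts the selection:
out_dist = TFDataHandler.filter_by_value(ds, "label", [0, 1], excluded=True)

print(in_dist.reduce(0, lambda count, _: count + 1).numpy())  # 4
```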
```diff
@@ -400,15 +352,16 @@ class TFDataHandler(DataHandler):
         return dataset_to_filter
 
     @classmethod
-    def prepare_for_training(
+    def prepare(
         cls,
         dataset: tf.data.Dataset,
         batch_size: int,
-        shuffle: bool = False,
         preprocess_fn: Optional[Callable] = None,
         augment_fn: Optional[Callable] = None,
-        output_keys: Optional[list] = None,
-        dict_based_fns: bool = False,
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
         shuffle_buffer_size: Optional[int] = None,
         prefetch_buffer_size: Optional[int] = None,
         drop_remainder: Optional[bool] = False,
```
```diff
@@ -418,16 +371,18 @@ class TFDataHandler(DataHandler):
         Args:
             dataset (tf.data.Dataset): tf.data.Dataset to prepare
             batch_size (int): Batch size
-            shuffle (bool, optional): To shuffle the returned dataset or not.
-                Defaults to False.
-            preprocess_fn (Callable, optional): Preprocessing function to apply to\
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
                 the dataset. Defaults to None.
-            augment_fn (Callable, optional): Augment function to be used (when the
+            augment_fn (Callable, optional): Augment function to be used (when the
                 returned dataset is to be used for training). Defaults to None.
-            output_keys (list, optional): List of keys corresponding to the features
-                that will be returned. Keep all features if None. Defaults to None.
-            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
-                based (if True) or as tuple based (if False). Defaults to False.
+            columns (list, optional): List of column names corresponding to the columns
+                that will be returned. Keep all columns if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
+            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
+                based (if True) or as tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
             shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
                 taken as the number of samples in the dataset. Defaults to None.
             prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
```
```diff
@@ -440,9 +395,9 @@ class TFDataHandler(DataHandler):
             tf.data.Dataset: Prepared dataset
         """
         # dict based to tuple based
-        output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+        columns = columns or cls.get_ds_column_names(dataset)
         if not dict_based_fns:
-            dataset = cls.dict_to_tuple(dataset, output_keys)
+            dataset = cls.dict_to_tuple(dataset, columns)
 
         # preprocess + DA
         if preprocess_fn is not None:
```
```diff
@@ -450,8 +405,8 @@ class TFDataHandler(DataHandler):
         if augment_fn is not None:
             dataset = cls.map_ds(dataset, augment_fn)
 
-        if dict_based_fns:
-            dataset = cls.dict_to_tuple(dataset, output_keys)
+        if dict_based_fns and return_tuple:
+            dataset = cls.dict_to_tuple(dataset, columns)
 
         dataset = dataset.cache()
 
```
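Putting the renamed `prepare` together, a hedged sketch; the import path is assumed from the file list, and the batching/prefetch tail of `prepare` is assumed from its `batch_size`/`drop_remainder` parameters since this diff does not show it:

```python
import tensorflow as tf

from oodeel.datasets.tf_data_handler import TFDataHandler  # path assumed

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.random.uniform((8, 4)), "label": tf.range(8)}
)

# dict_based_fns now defaults to True, so preprocess_fn receives dict items:
def scale(item):
    item["input"] = item["input"] / 2.0
    return item

batched = TFDataHandler.prepare(ds, batch_size=4, preprocess_fn=scale, shuffle=True)

# With return_tuple=True (the default), items come back as (input, label) tuples:
for x, y in batched.take(1):
    print(x.shape, y.shape)  # (4, 4) (4,)
```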
```diff
@@ -598,19 +553,21 @@ class TFDataHandler(DataHandler):
         return int(cardinality)
 
     @staticmethod
-    def get_feature_shape(
-        dataset: tf.data.Dataset, feature_key: Union[str, int]
+    def get_column_elements_shape(
+        dataset: tf.data.Dataset, column_name: Union[str, int]
     ) -> tuple:
-        """Get the shape of a feature of dataset identified by feature_key
+        """Get the shape of the elements of a column of dataset identified by
+        column_name
 
         Args:
             dataset (tf.data.Dataset): a tf.data.dataset
-            feature_key (Union[str, int]): The feature key to get the shape from
+            column_name (Union[str, int]): The column name to get
+                the element shape from.
 
         Returns:
-            tuple: the shape of feature_key elements
+            tuple: the shape of an element from column_name
         """
-        return tuple(dataset.element_spec[feature_key].shape)
+        return tuple(dataset.element_spec[column_name].shape)
 
     @staticmethod
     def get_input_from_dataset_item(elem: ItemType) -> TensorType:
```
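The renamed shape helper reads straight off `element_spec`, so it returns per-element shapes without a batch dimension. A short sketch (import path assumed as before):

```python
import tensorflow as tf

from oodeel.datasets.tf_data_handler import TFDataHandler  # path assumed

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((10, 32, 32, 3)), "label": tf.range(10)}
)
print(TFDataHandler.get_column_elements_shape(ds, "input"))  # (32, 32, 3)
print(TFDataHandler.get_column_elements_shape(ds, "label"))  # ()
```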
```diff
@@ -650,22 +607,3 @@ class TFDataHandler(DataHandler):
         if len(label.shape) > 1:
             label = tf.reshape(label, [label.shape[0]])
         return label
-
-    @staticmethod
-    def get_feature(
-        dataset: tf.data.Dataset, feature_key: Union[str, int]
-    ) -> tf.data.Dataset:
-        """Extract a feature from a dataset
-
-        Args:
-            dataset (tf.data.Dataset): Dataset to extract the feature from
-            feature_key (Union[str, int]): feature to extract
-
-        Returns:
-            tf.data.Dataset: dataset built with the extracted feature only
-        """
-
-        def _get_feature_elem(elem):
-            return elem[feature_key]
-
-        return dataset.map(_get_feature_elem)
```
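The deleted `get_feature` helper was a thin wrapper over `dataset.map`; single-column extraction stays a one-liner in plain tf.data:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    {"input": tf.zeros((4, 2)), "label": tf.range(4)}
)

# Single-column extraction without the removed helper:
inputs_only = ds.map(lambda elem: elem["input"])
print(inputs_only.element_spec)  # TensorSpec(shape=(2,), ...)
```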