oodeel-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
oodeel/datasets/deprecated/__init__.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import warnings
+
+from .DEPRECATED_ooddataset import OODDataset
+
+warnings.warn(
+    "The 'OODDataset' object is deprecated and will be removed in a future release.",
+    DeprecationWarning,
+    stacklevel=2,
+)
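Importing anything from this deprecated subpackage triggers the warning at module import time. A minimal sketch of how a consumer would observe it (assuming oodeel 0.4.0 is installed; the filter is widened because DeprecationWarning is hidden by default outside __main__, and the warning only fires on the first import of the module):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is usually filtered out
    from oodeel.datasets.deprecated import OODDataset  # noqa: F401

# One of the recorded warnings is the deprecation notice emitted above
assert any(issubclass(w.category, DeprecationWarning) for w in caught)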
oodeel/datasets/tf_data_handler.py
@@ -0,0 +1,600 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from typing import get_args
+
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from datasets import load_dataset as hf_load_dataset
+
+from ..types import Callable
+from ..types import ItemType
+from ..types import Optional
+from ..types import TensorType
+from ..types import Union
+from .data_handler import DataHandler
+
+
+def dict_only_ds(ds_handling_method: Callable) -> Callable:
+    """Decorator to ensure that the dataset is a dict dataset and that the column_name
+    given as argument matches one of the column names. The signature of decorated
+    functions must be function(dataset, *args, **kwargs) with column_name either in
+    kwargs or args[0] when relevant.
+
+    Args:
+        ds_handling_method: method to decorate
+
+    Returns:
+        decorated method
+    """
+
+    def wrapper(dataset: tf.data.Dataset, *args, **kwargs):
+        assert isinstance(dataset.element_spec, dict), "dataset elements must be dicts"
+
+        if "column_name" in kwargs.keys():
+            column_name = kwargs["column_name"]
+        elif len(args) > 0:
+            column_name = args[0]
+
+        # If column_name is provided, check that it is in the dataset column names
+        if (len(args) > 0) or ("column_name" in kwargs):
+            if isinstance(column_name, str):
+                column_name = [column_name]
+            for name in column_name:
+                assert (
+                    name in dataset.element_spec.keys()
+                ), f"The input dataset has no column named {name}"
+        return ds_handling_method(dataset, *args, **kwargs)
+
+    return wrapper
+
+
+class TFDataHandler(DataHandler):
+    """
+    Class to manage tf.data.Dataset. The aim is to provide a simple interface for
+    working with tf.data.Datasets and to manage them without requiring
+    tensorflow syntax.
+    """
+
+    def __init__(self) -> None:
+        """
+        Initializes the TFDataHandler instance.
+
+        Attributes:
+            backend (str): The backend framework used, set to "tensorflow".
+            channel_order (str): The channel order format, set to "channels_last".
+        """
+        super().__init__()
+        self.backend = "tensorflow"
+        self.channel_order = "channels_last"
+
+    @classmethod
+    def load_dataset(
+        cls,
+        dataset_id: Union[tf.data.Dataset, ItemType, str],
+        columns: Optional[list] = None,
+        hub: Optional[str] = "tensorflow-datasets",
+        load_kwargs: dict = {},
+    ) -> tf.data.Dataset:
+        """Load a dataset from various sources, ensuring that a dict-based
+        tf.data.Dataset is returned.
+
+        Args:
+            dataset_id (Union[tf.data.Dataset, ItemType, str]): dataset identification.
+                Can be the name of a dataset from tensorflow_datasets, a
+                tf.data.Dataset, or a tuple/dict of np.ndarrays/tf.Tensors.
+            columns (list, optional): Column names. If None, assigned as "input_i"
+                for the i-th column. Defaults to None.
+            hub (str, optional): The hub to load the dataset from. Can be either
+                "tensorflow-datasets" or "huggingface".
+                Defaults to "tensorflow-datasets".
+            load_kwargs (dict, optional): Additional args for loading from
+                tensorflow_datasets. Defaults to {}.
+
+        Returns:
+            tf.data.Dataset: A dict-based tf.data.Dataset
+        """
+        assert hub in {
+            "tensorflow-datasets",
+            "huggingface",
+        }, "hub must be either 'tensorflow-datasets' or 'huggingface'"
+
+        if isinstance(dataset_id, get_args(ItemType)):
+            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
+        elif isinstance(dataset_id, tf.data.Dataset):
+            dataset = cls.load_custom_dataset(dataset_id, columns)
+        elif isinstance(dataset_id, str):
+            if hub == "tensorflow-datasets":
+                # Copy before overriding to avoid mutating the caller's dict
+                # (or the shared default argument)
+                load_kwargs = {**load_kwargs, "as_supervised": False}
+                dataset = cls.load_from_tensorflow_datasets(dataset_id, load_kwargs)
+            elif hub == "huggingface":
+                dataset = cls.load_from_huggingface(dataset_id, load_kwargs)
+        return dataset
+
+    @staticmethod
+    def load_dataset_from_arrays(
+        dataset_id: ItemType, columns: Optional[list] = None
+    ) -> tf.data.Dataset:
+        """Load a tf.data.Dataset from a np.ndarray, a tf.Tensor or a tuple/dict
+        of np.ndarrays/tf.Tensors.
+
+        Args:
+            dataset_id (ItemType): numpy array(s) to load.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for the i-th column. Defaults to None.
+
+        Returns:
+            tf.data.Dataset
+        """
+        # If dataset_id is a numpy array, convert it to a dict
+        if isinstance(dataset_id, get_args(TensorType)):
+            dataset_dict = {"input": dataset_id}
+
+        # If dataset_id is a tuple, convert it to a dict
+        elif isinstance(dataset_id, tuple):
+            len_elem = len(dataset_id)
+            if columns is None:
+                if len_elem == 2:
+                    dataset_dict = {"input": dataset_id[0], "label": dataset_id[1]}
+                else:
+                    dataset_dict = {
+                        f"input_{i}": dataset_id[i] for i in range(len_elem - 1)
+                    }
+                    dataset_dict["label"] = dataset_id[-1]
+                print(
+                    'Loading tf.data.Dataset with elems as dicts, assigning "input_i" '
+                    'key to the i-th tuple dimension and "label" key to the last '
+                    "tuple dimension."
+                )
+            else:
+                assert (
+                    len(columns) == len_elem
+                ), "Number of column names does not match the number of columns"
+                dataset_dict = {columns[i]: dataset_id[i] for i in range(len_elem)}
+
+        elif isinstance(dataset_id, dict):
+            if columns is not None:
+                len_elem = len(dataset_id)
+                assert (
+                    len(columns) == len_elem
+                ), "Number of column names does not match the number of columns"
+                original_columns = list(dataset_id.keys())
+                dataset_dict = {
+                    columns[i]: dataset_id[original_columns[i]] for i in range(len_elem)
+                }
+            else:
+                # Keep the original column names when none are provided
+                dataset_dict = dataset_id
+
+        dataset = tf.data.Dataset.from_tensor_slices(dataset_dict)
+        return dataset
+
+    @classmethod
+    def load_custom_dataset(
+        cls, dataset_id: tf.data.Dataset, columns: Optional[list] = None
+    ) -> tf.data.Dataset:
+        """Load a custom Dataset, ensuring it has the correct (dict-based) format.
+
+        Args:
+            dataset_id (tf.data.Dataset): tf.data.Dataset
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for the i-th column. Defaults to None.
+
+        Returns:
+            tf.data.Dataset
+        """
+        # If dataset_id is a tuple based tf.data.Dataset, convert it to a dict
+        if not isinstance(dataset_id.element_spec, dict):
+            len_elem = len(dataset_id.element_spec)
+            if columns is None:
+                print(
+                    "Column names not provided, assigning 'input_i' "
+                    "key to the i-th tensor and 'label' key to the last"
+                )
+                if len_elem == 2:
+                    columns = ["input", "label"]
+                else:
+                    columns = [f"input_{i}" for i in range(len_elem)]
+                    columns[-1] = "label"
+            else:
+                assert (
+                    len(columns) == len_elem
+                ), "Number of column names does not match the number of columns"
+
+            dataset_id = cls.tuple_to_dict(dataset_id, columns)
+
+        dataset = dataset_id
+        return dataset
+
+    @staticmethod
+    def load_from_huggingface(
+        dataset_id: str,
+        load_kwargs: dict = {},
+    ) -> tf.data.Dataset:
+        """Load a Dataset from the Hugging Face datasets catalog.
+
+        Args:
+            dataset_id (str): Identifier of the dataset
+            load_kwargs (dict): Loading kwargs to add to the initialization
+                of the dataset.
+
+        Returns:
+            tf.data.Dataset: dataset
+        """
+        dataset = hf_load_dataset(dataset_id, **load_kwargs)
+        dataset = dataset.to_tf_dataset()
+        return dataset
+
+    @staticmethod
+    def load_from_tensorflow_datasets(
+        dataset_id: str,
+        load_kwargs: dict = {},
+    ) -> tf.data.Dataset:
+        """Load a tf.data.Dataset from the tensorflow_datasets catalog.
+
+        Args:
+            dataset_id (str): Identifier of the dataset
+            load_kwargs (dict, optional): Loading kwargs to add to tfds.load().
+                Defaults to {}.
+
+        Returns:
+            tf.data.Dataset
+        """
+        assert (
+            dataset_id in tfds.list_builders()
+        ), "Dataset not available on tensorflow datasets catalog"
+        dataset = tfds.load(dataset_id, **load_kwargs)
+        return dataset
+
+    @staticmethod
+    @dict_only_ds
+    def dict_to_tuple(
+        dataset: tf.data.Dataset, columns: Optional[list] = None
+    ) -> tf.data.Dataset:
+        """Turn a dict-based tf.data.Dataset into a tuple-based tf.data.Dataset.
+
+        Args:
+            dataset (tf.data.Dataset): Dict-based tf.data.Dataset
+            columns (list, optional): Columns to use for the tuple-based
+                tf.data.Dataset. If None, takes all the columns. Defaults to None.
+
+        Returns:
+            tf.data.Dataset
+        """
+        if columns is None:
+            columns = list(dataset.element_spec.keys())
+        dataset = dataset.map(lambda x: tuple(x[k] for k in columns))
+        return dataset
+
+    @staticmethod
+    def tuple_to_dict(dataset: tf.data.Dataset, columns: list) -> tf.data.Dataset:
+        """Turn a tuple-based tf.data.Dataset into a dict-based tf.data.Dataset.
+
+        Args:
+            dataset (tf.data.Dataset): Tuple-based tf.data.Dataset
+            columns (list): Column names to use for the dict-based tf.data.Dataset
+
+        Returns:
+            tf.data.Dataset
+        """
+        assert isinstance(
+            dataset.element_spec, tuple
+        ), "dataset elements must be tuples"
+        len_elem = len(dataset.element_spec)
+        assert len_elem == len(
+            columns
+        ), "The number of columns must be equal to the number of tuple elements"
+
+        def tuple_to_dict(*inputs):
+            return {columns[i]: inputs[i] for i in range(len_elem)}
+
+        dataset = dataset.map(tuple_to_dict)
+        return dataset
+
+    @staticmethod
+    @dict_only_ds
+    def get_ds_column_names(dataset: tf.data.Dataset) -> list:
+        """Get the column names of a tf.data.Dataset.
+
+        Args:
+            dataset (tf.data.Dataset): tf.data.Dataset to get the column names from
+
+        Returns:
+            list: List of column names
+        """
+        return list(dataset.element_spec.keys())
+
+    @staticmethod
+    def map_ds(
+        dataset: tf.data.Dataset,
+        map_fn: Callable,
+        num_parallel_calls: Optional[int] = None,
+    ) -> tf.data.Dataset:
+        """Map a function over a tf.data.Dataset.
+
+        Args:
+            dataset (tf.data.Dataset): tf.data.Dataset to map the function over
+            map_fn (Callable): Function to map
+            num_parallel_calls (Optional[int], optional): Number of parallel calls
+                to use. Defaults to None.
+
+        Returns:
+            tf.data.Dataset: Mapped dataset
+        """
+        if num_parallel_calls is None:
+            num_parallel_calls = tf.data.experimental.AUTOTUNE
+        dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_calls)
+        return dataset
+
+    @staticmethod
+    @dict_only_ds
+    def filter_by_value(
+        dataset: tf.data.Dataset,
+        column_name: str,
+        values: list,
+        excluded: bool = False,
+    ) -> tf.data.Dataset:
+        """Filter a tf.data.Dataset by checking whether the value of a column is
+        in 'values'.
+
+        Args:
+            dataset (tf.data.Dataset): tf.data.Dataset to filter
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep (if excluded is False)
+                or to exclude
+            excluded (bool, optional): To keep (False) or exclude (True) the samples
+                with column values included in values. Defaults to False.
+
+        Returns:
+            tf.data.Dataset: Filtered dataset
+        """
+        # If the labels are one-hot encoded, prepare a function to get the label
+        # as an integer index
+        if len(dataset.element_spec[column_name].shape) > 0:
+
+            def get_label_int(elem):
+                return tf.argmax(elem[column_name])
+
+        else:
+
+            def get_label_int(elem):
+                return elem[column_name]
+
+        def filter_fn(elem):
+            value = get_label_int(elem)
+            keep = tf.reduce_any(tf.equal(value, values))
+            # tf.logical_not keeps the predicate graph-compatible
+            return tf.logical_not(keep) if excluded else keep
+
+        dataset_to_filter = dataset
+        dataset_to_filter = dataset_to_filter.filter(filter_fn)
+        return dataset_to_filter
+
+    @classmethod
+    def prepare(
+        cls,
+        dataset: tf.data.Dataset,
+        batch_size: int,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
+        shuffle_buffer_size: Optional[int] = None,
+        prefetch_buffer_size: Optional[int] = None,
+        drop_remainder: Optional[bool] = False,
+    ) -> tf.data.Dataset:
+        """Prepare a tf.data.Dataset for training.
+
+        Args:
+            dataset (tf.data.Dataset): tf.data.Dataset to prepare
+            batch_size (int): Batch size
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            columns (list, optional): List of column names corresponding to the columns
+                that will be returned. Keeps all columns if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
+            dict_based_fns (bool): Whether the preprocess and augment functions are
+                dict based (if True) or tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
+            shuffle_buffer_size (int, optional): Size of the shuffle buffer. If None,
+                taken as the number of samples in the dataset. Defaults to None.
+            prefetch_buffer_size (Optional[int], optional): Buffer size for prefetch.
+                If None, automatically chosen using tf.data.experimental.AUTOTUNE.
+                Defaults to None.
+            drop_remainder (Optional[bool], optional): To drop the last batch when
+                its size is lower than batch_size. Defaults to False.
+
+        Returns:
+            tf.data.Dataset: Prepared dataset
+        """
+        # dict based to tuple based
+        columns = columns or cls.get_ds_column_names(dataset)
+        if not dict_based_fns:
+            dataset = cls.dict_to_tuple(dataset, columns)
+
+        # preprocess + data augmentation
+        if preprocess_fn is not None:
+            dataset = cls.map_ds(dataset, preprocess_fn)
+        if augment_fn is not None:
+            dataset = cls.map_ds(dataset, augment_fn)
+
+        if dict_based_fns and return_tuple:
+            dataset = cls.dict_to_tuple(dataset, columns)
+
+        dataset = dataset.cache()
+
+        # shuffle
+        if shuffle:
+            num_samples = cls.get_dataset_length(dataset)
+            shuffle_buffer_size = (
+                num_samples if shuffle_buffer_size is None else shuffle_buffer_size
+            )
+            dataset = dataset.shuffle(shuffle_buffer_size)
+        # batch
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+        # prefetch (AUTOTUNE when no explicit buffer size is given)
+        if prefetch_buffer_size is None:
+            prefetch_buffer_size = tf.data.experimental.AUTOTUNE
+        dataset = dataset.prefetch(prefetch_buffer_size)
+        return dataset
+
+    @staticmethod
+    def make_channel_first(input_key: str, dataset: tf.data.Dataset) -> tf.data.Dataset:
+        """Make a tf.data.Dataset channel first. Make sure that the dataset is not
+        already channel first: the transpose is applied unconditionally, so tensors
+        that are already channel first would end up with scrambled dimensions.
+
+        Args:
+            input_key (str): input key of the dict-based tf.data.Dataset
+            dataset (tf.data.Dataset): tf.data.Dataset to make channel first
+
+        Returns:
+            tf.data.Dataset: Channel first dataset
+        """
+
+        def channel_first(x):
+            x[input_key] = tf.transpose(x[input_key], perm=[2, 0, 1])
+            return x
+
+        dataset = dataset.map(channel_first)
+        return dataset
+
+    @staticmethod
+    def get_item_length(dataset: tf.data.Dataset) -> int:
+        """Get the length of a dataset element. If an element is a single tensor, the
+        length is one; if it is a tuple, list or dict, it is len(elem).
+
+        Args:
+            dataset (tf.data.Dataset): Dataset to process
+
+        Returns:
+            int: length of the dataset elems
+        """
+        if isinstance(dataset.element_spec, (tuple, list, dict)):
+            return len(dataset.element_spec)
+        return 1
+
+    @staticmethod
+    def get_dataset_length(dataset: tf.data.Dataset) -> int:
+        """Get the length of a dataset. Try to access it with len(), and if not
+        available, with a reduce op.
+
+        Args:
+            dataset (tf.data.Dataset): Dataset to process
+
+        Returns:
+            int: Number of elements in the dataset
+        """
+        try:
+            return len(dataset)
+        except TypeError:
+            cardinality = dataset.reduce(0, lambda x, _: x + 1)
+            return int(cardinality)
+
+    @staticmethod
+    def get_column_elements_shape(
+        dataset: tf.data.Dataset, column_name: Union[str, int]
+    ) -> tuple:
+        """Get the shape of the elements of a column of dataset identified by
+        column_name.
+
+        Args:
+            dataset (tf.data.Dataset): a tf.data.Dataset
+            column_name (Union[str, int]): The column name to get
+                the element shape from.
+
+        Returns:
+            tuple: the shape of an element from column_name
+        """
+        return tuple(dataset.element_spec[column_name].shape)
+
+    @staticmethod
+    def get_columns_shapes(dataset: tf.data.Dataset) -> dict:
+        """Get the shapes of the elements of all columns of a dataset.
+
+        Args:
+            dataset (Dataset): a Dataset
+
+        Returns:
+            dict: dictionary of column names and their corresponding shape (a tuple
+                of shapes for tuple-based datasets)
+        """
+        if isinstance(dataset.element_spec, tuple):
+            shapes = [None for _ in range(len(dataset.element_spec))]
+            for i in range(len(dataset.element_spec)):
+                try:
+                    shapes[i] = tuple(dataset.element_spec[i].shape)
+                except AttributeError:
+                    pass
+            shapes = tuple(shapes)
+        elif isinstance(dataset.element_spec, dict):
+            shapes = {}
+            for key in dataset.element_spec.keys():
+                try:
+                    shapes[key] = tuple(dataset.element_spec[key].shape)
+                except AttributeError:
+                    pass
+        else:
+            # Single-tensor dataset: return the shape of its unique column
+            shapes = tuple(dataset.element_spec.shape)
+        return shapes
+
+    @staticmethod
+    def get_input_from_dataset_item(elem: ItemType) -> TensorType:
+        """Get the tensor that is to be fed as input to a model from a dataset element.
+
+        Args:
+            elem (ItemType): dataset element to extract input from
+
+        Returns:
+            TensorType: Input tensor
+        """
+        if isinstance(elem, (tuple, list)):
+            tensor = elem[0]
+        elif isinstance(elem, dict):
+            tensor = elem[list(elem.keys())[0]]
+        else:
+            tensor = elem
+        return tensor
+
+    @staticmethod
+    def get_label_from_dataset_item(item: ItemType):
+        """Retrieve the label tensor from an item given as a tuple/list. The label
+        must be at index 1 in the item tuple. If one-hot encoded, labels are converted
+        to single values.
+
+        Args:
+            item (ItemType): dataset element to extract the label from
+
+        Returns:
+            Any: Label tensor
+        """
+        label = item[1]  # labels must be at index 1 in the item tuple
+        # If labels are one-hot encoded, take the argmax
+        if tf.rank(label) > 1 and label.shape[1] > 1:
+            label = tf.reshape(label, shape=[label.shape[0], -1])
+            label = tf.argmax(label, axis=1)
+        # If labels are in two dimensions, squeeze them
+        if len(label.shape) > 1:
+            label = tf.reshape(label, [label.shape[0]])
+        return label