oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flagged this version of oodeel as possibly problematic.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/__init__.py
CHANGED
oodeel/datasets/__init__.py
CHANGED
@@ -20,4 +20,5 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from .
+from .data_handler import load_data_handler
+from .deprecated.DEPRECATED_ooddataset import OODDataset
oodeel/datasets/data_handler.py
CHANGED
@@ -20,6 +20,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
+import importlib.util
 from abc import ABC
 from abc import abstractmethod

@@ -29,10 +30,45 @@ from ..types import Callable
 from ..types import DatasetType
 from ..types import ItemType
 from ..types import Optional
+from ..types import TensorType
 from ..types import Tuple
 from ..types import Union


+def get_backend():
+    """Detects whether TensorFlow or PyTorch is available and returns
+    the preferred backend."""
+    available_backends = []
+    if importlib.util.find_spec("tensorflow"):
+        available_backends.append("tensorflow")
+    if importlib.util.find_spec("torch"):
+        available_backends.append("torch")
+
+    if len(available_backends) == 1:
+        return available_backends[0]
+    elif len(available_backends) == 0:
+        raise ImportError("Neither TensorFlow nor PyTorch is installed.")
+    else:
+        raise ImportError(
+            "Both TensorFlow and PyTorch are installed. Please specify the backend."
+        )
+
+
+def load_data_handler(backend: str = None):
+    if backend is None:
+        backend = get_backend()
+
+    if backend == "tensorflow":
+        from .tf_data_handler import TFDataHandler
+
+        return TFDataHandler()
+
+    elif backend == "torch":
+        from .torch_data_handler import TorchDataHandler
+
+        return TorchDataHandler()
+
+
 class DataHandler(ABC):
     """
     Class to manage Datasets. The aim is to provide a simple interface
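The exported load_data_handler factory is the entry point for the new API. A minimal usage sketch (not part of the diff; it relies only on the signatures above and assumes at most one of the two frameworks is installed when backend is omitted):

    from oodeel.datasets import load_data_handler

    # Auto-detection: get_backend() returns the single installed framework,
    # and raises ImportError if none or both are found.
    handler = load_data_handler()

    # When both TensorFlow and PyTorch are installed, pin the backend explicitly:
    handler = load_data_handler(backend="torch")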
@@ -40,82 +76,122 @@ class DataHandler(ABC):
     having to use library-specific syntax.
     """

-    @classmethod
-    @abstractmethod
-    def load_dataset(
-        cls,
-        dataset_id: Union[ItemType, DatasetType, str],
-        keys: Optional[list] = None,
-        load_kwargs: dict = {},
-    ) -> DatasetType:
-        """Load dataset from different manners
+    def __init__(self):
+        self.backend = None
+        self.channel_order = None
+
+    def split_by_class(
+        self,
+        dataset: DatasetType,
+        in_labels: Optional[Union[np.ndarray, list]] = None,
+        out_labels: Optional[Union[np.ndarray, list]] = None,
+    ) -> Optional[Tuple[DatasetType]]:
+        """Filter the dataset by assigning ood labels depending on labels
+        value (typically, class id).

         Args:
-            dataset_id (Union[ItemType, DatasetType, str]): dataset identification
-            keys (list, optional): Features keys. If None, assigned as "input_i"
-                for i-th feature. Defaults to None.
-            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+            in_labels (Optional[Union[np.ndarray, list]], optional): set of labels
+                to be considered as in-distribution. Defaults to None.
+            out_labels (Optional[Union[np.ndarray, list]], optional): set of labels
+                to be considered as out-of-distribution. Defaults to None.

         Returns:
-            DatasetType: dataset
+            Optional[Tuple[OODDataset]]: Tuple of in-distribution and
+                out-of-distribution OODDatasets
         """
-        raise NotImplementedError()
+        # Make sure the dataset has labels
+        assert (in_labels is not None) or (
+            out_labels is not None
+        ), "specify labels to filter with"
+        assert self.get_item_length(dataset) >= 2, "the dataset has no labels"

-    @staticmethod
-    @abstractmethod
-    def assign_feature_value(
-        dataset: DatasetType, feature_key: str, value: int
-    ) -> DatasetType:
-        """Assign a value to a feature for every sample in a Dataset
+        # Filter the dataset depending on in_labels and out_labels given
+        if (out_labels is not None) and (in_labels is not None):
+            in_data = self.filter_by_value(dataset, "label", in_labels)
+            out_data = self.filter_by_value(dataset, "label", out_labels)

-        Args:
-            dataset (DatasetType): Dataset to assign the value to
-            feature_key (str): Feature to assign the value to
-            value (int): Value to assign
+        if out_labels is None:
+            in_data = self.filter_by_value(dataset, "label", in_labels)
+            out_data = self.filter_by_value(dataset, "label", in_labels, excluded=True)

-        Returns:
-            DatasetType: updated dataset
-        """
-        raise NotImplementedError()
+        elif in_labels is None:
+            in_data = self.filter_by_value(dataset, "label", out_labels, excluded=True)
+            out_data = self.filter_by_value(dataset, "label", out_labels)

-    @staticmethod
+        # Return the filtered OODDatasets
+        return in_data, out_data
+
+    @classmethod
     @abstractmethod
-    def get_feature_from_ds(dataset: DatasetType, feature_key: str) -> np.ndarray:
-        """Get a feature from a Dataset
+    def prepare(
+        cls,
+        dataset: DatasetType,
+        batch_size: int,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
+        **kwargs_prepare,
+    ) -> DatasetType:
+        """Prepare dataset for scoring or training

         Args:
-            dataset (DatasetType): Dataset to get the feature from
-            feature_key (str): Feature value to get
+            batch_size (int): Batch size
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            columns (list, optional): List of columns
+                that will be returned. Keep all columns if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
+            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
+                based (if True) or as tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
+            kwargs_prepare (dict): Additional parameters to be passed to the
+                data_handler for backend specific preparation.
+

         Returns:
-            np.ndarray: Feature values for dataset
+            DatasetType: prepared dataset
         """
         raise NotImplementedError()

     @staticmethod
     @abstractmethod
-    def get_ds_feature_keys(dataset: DatasetType) -> list:
-        """Get the feature keys of a Dataset
+    def load_dataset_from_arrays(
+        dataset_id: ItemType, columns: Optional[list] = None
+    ) -> DatasetType:
+        """Load a DatasetType from a np.ndarray / Tensor

         Args:
-            dataset (Dataset): Dataset to get the feature keys from
+            dataset_id (ItemType): numpy array(s) to load.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for i-th column. Defaults to None.

         Returns:
-            list: List of feature keys
+            DatasetType
         """
         raise NotImplementedError()

     @staticmethod
     @abstractmethod
-    def has_feature_key(dataset: DatasetType, key: str) -> bool:
-        """Check if a Dataset has a feature denoted by key
+    def load_custom_dataset(
+        dataset_id: DatasetType, columns: Optional[list] = None
+    ) -> DatasetType:
+        """Load a custom dataset by ensuring it is properly formatted.

         Args:
-            dataset (DatasetType): Dataset to check
-            key (str): Key to check
+            dataset_id (DatasetType): dataset
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for i-th column. Defaults to None.

         Returns:
-            bool: If the dataset has a feature denoted by key
+            A properly formatted dataset.
         """
         raise NotImplementedError()

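Together, load_dataset_from_arrays, split_by_class, and prepare form the new scoring workflow. A hedged end-to-end sketch using only the signatures declared above (the arrays, the explicit column names, and the label split are illustrative placeholders, not taken from the diff):

    import numpy as np
    from oodeel.datasets import load_data_handler

    handler = load_data_handler()

    # Placeholder data standing in for a real image classification set.
    x = np.random.rand(100, 32, 32, 3).astype("float32")
    y = np.random.randint(0, 10, size=100)

    # Column names passed explicitly; split_by_class filters on "label".
    ds = handler.load_dataset_from_arrays((x, y), columns=["input", "label"])
    in_ds, out_ds = handler.split_by_class(ds, in_labels=[0, 1, 2, 3, 4])
    loader = handler.prepare(in_ds, batch_size=32, shuffle=False)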
@@ -135,21 +211,21 @@ class DataHandler(ABC):

     @staticmethod
     @abstractmethod
-    def filter_by_feature_value(
+    def filter_by_value(
         dataset: DatasetType,
-        feature_key: str,
+        column_name: str,
         values: list,
         excluded: bool = False,
     ) -> DatasetType:
-        """Filter the dataset by checking the value of a feature is in `values`
+        """Filter the dataset by checking the value of a column is in `values`

         Args:
             dataset (Dataset): Dataset to filter
-            feature_key (str): Feature name to check the value
-            values (list): Feature_key values to keep (if excluded is False)
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep (if excluded is False)
                 or to exclude
             excluded (bool, optional): To keep (False) or exclude (True) the samples
-                with Feature_key value included in Values. Defaults to False.
+                with column value included in Values. Defaults to False.

         Returns:
             DatasetType: Filtered dataset
@@ -158,79 +234,71 @@

     @staticmethod
     @abstractmethod
-    def merge(
-        id_dataset: DatasetType,
-        ood_dataset: DatasetType,
-        resize: Optional[bool] = False,
-        shape: Optional[Tuple[int]] = None,
-    ) -> DatasetType:
-        """Merge two datasets
+    def get_item_length(dataset: DatasetType) -> int:
+        """Number of elements in a dataset item

         Args:
-            id_dataset (Dataset): dataset of in-distribution data
-            ood_dataset (DictDataset): dataset of out-of-distribution data
-            resize (Optional[bool], optional): toggles if input tensors of the
-                datasets have to be resized to have the same shape. Defaults to True.
-            shape (Optional[Tuple[int]], optional): shape to use for resizing input
-                tensors. If None, the tensors are resized with the shape of the
-                id_dataset input tensors. Defaults to None.
+            dataset (DatasetType): Dataset

         Returns:
-            DatasetType: merged dataset
+            int: Item length
         """
         raise NotImplementedError()

-    @classmethod
+    @staticmethod
     @abstractmethod
-    def prepare_for_training(
-        cls,
-        dataset: DatasetType,
-        batch_size: int,
-        shuffle: bool = False,
-        preprocess_fn: Optional[Callable] = None,
-        augment_fn: Optional[Callable] = None,
-        output_keys: list = ["input", "label"],
-    ) -> DatasetType:
-        """Prepare a dataset for training
+    def get_dataset_length(dataset: DatasetType) -> int:
+        """Number of items in a dataset

         Args:
-            dataset (DictDataset): Dataset to prepare
-            batch_size (int): Batch size
-            shuffle (bool): Wether to shuffle the dataloader or not
-            preprocess_fn (Callable, optional): Preprocessing function to apply to
-                the dataset. Defaults to None.
-            augment_fn (Callable, optional): Augment function to be used (when the
-                returned dataset is to be used for training). Defaults to None.
-            output_keys (list): List of keys corresponding to the features that will be
-                returned. Keep all features if None. Defaults to None.
+            dataset (DatasetType): Dataset

         Returns:
-            DatasetType: prepared dataset / dataloader
+            int: Dataset length
         """
         raise NotImplementedError()

     @staticmethod
     @abstractmethod
-    def get_item_length(dataset: DatasetType) -> int:
-        """Number of elements in a dataset item
+    def get_column_elements_shape(
+        dataset: DatasetType, column_name: Union[str, int]
+    ) -> tuple:
+        """Get the shape of the elements of a column of dataset identified by
+        column_name

         Args:
-            dataset (DatasetType): Dataset
+            dataset (Dataset): a Dataset
+            column_name (Union[str, int]): The column name to get
+                the element shape from.

         Returns:
-            int: Item length
+            tuple: the shape of an element from column_name
         """
         raise NotImplementedError()

     @staticmethod
     @abstractmethod
-    def get_dataset_length(dataset: DatasetType) -> int:
-        """Number of items in a dataset
+    def get_input_from_dataset_item(elem: ItemType) -> TensorType:
+        """Get the tensor that is to be feed as input to a model from a dataset element.

         Args:
-            dataset (DatasetType): Dataset
+            elem (ItemType): dataset element to extract input from

         Returns:
-            int: Dataset length
+            TensorType: Input tensor
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_label_from_dataset_item(item: ItemType):
+        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
+        in the item tuple. If one-hot encoded, labels are converted to single value.
+
+        Args:
+            elem (ItemType): dataset element to extract label from
+
+        Returns:
+            Any: Label tensor
         """
         raise NotImplementedError()
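The second half of this diff renames the feature-key based filter to a column based one and adds item-level introspection helpers. A short hedged sketch continuing the example above (the column name "label" and the index-based lookup are assumptions permitted by the declared signatures):

    # Keep everything except classes 0-4, mirroring what split_by_class does internally:
    ood_ds = handler.filter_by_value(ds, "label", [0, 1, 2, 3, 4], excluded=True)

    n_columns = handler.get_item_length(ds)       # e.g. 2 for (input, label) items
    n_items = handler.get_dataset_length(ds)
    input_shape = handler.get_column_elements_shape(ds, 0)  # column by index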
oodeel/datasets/deprecated/DEPRECATED_data_handler.py
ADDED
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from abc import ABC
+from abc import abstractmethod
+
+import numpy as np
+
+from ...types import Callable
+from ...types import DatasetType
+from ...types import ItemType
+from ...types import Optional
+from ...types import Tuple
+from ...types import Union
+
+
+class DataHandler(ABC):
+    """
+    Class to manage Datasets. The aim is to provide a simple interface
+    for working with datasets (torch, tensorflow or other...) and manage them without
+    having to use library-specific syntax.
+    """
+
+    @classmethod
+    @abstractmethod
+    def load_dataset(
+        cls,
+        dataset_id: Union[ItemType, DatasetType, str],
+        keys: Optional[list] = None,
+        load_kwargs: dict = {},
+    ) -> DatasetType:
+        """Load dataset from different manners
+
+        Args:
+            dataset_id (Union[ItemType, DatasetType, str]): dataset identification
+            keys (list, optional): Features keys. If None, assigned as "input_i"
+                for i-th feature. Defaults to None.
+            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+
+        Returns:
+            DatasetType: dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def assign_feature_value(
+        dataset: DatasetType, feature_key: str, value: int
+    ) -> DatasetType:
+        """Assign a value to a feature for every sample in a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to assign the value to
+            feature_key (str): Feature to assign the value to
+            value (int): Value to assign
+
+        Returns:
+            DatasetType: updated dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_feature_from_ds(dataset: DatasetType, feature_key: str) -> np.ndarray:
+        """Get a feature from a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to get the feature from
+            feature_key (str): Feature value to get
+
+        Returns:
+            np.ndarray: Feature values for dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_ds_feature_keys(dataset: DatasetType) -> list:
+        """Get the feature keys of a Dataset
+
+        Args:
+            dataset (Dataset): Dataset to get the feature keys from
+
+        Returns:
+            list: List of feature keys
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def has_feature_key(dataset: DatasetType, key: str) -> bool:
+        """Check if a Dataset has a feature denoted by key
+
+        Args:
+            dataset (DatasetType): Dataset to check
+            key (str): Key to check
+
+        Returns:
+            bool: If the dataset has a feature denoted by key
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def map_ds(dataset: DatasetType, map_fn: Callable) -> DatasetType:
+        """Map a function to a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to map the function to
+            map_fn (Callable): Function to map
+
+        Returns:
+            DatasetType: Mapped dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def filter_by_feature_value(
+        dataset: DatasetType,
+        feature_key: str,
+        values: list,
+        excluded: bool = False,
+    ) -> DatasetType:
+        """Filter the dataset by checking the value of a feature is in `values`
+
+        Args:
+            dataset (Dataset): Dataset to filter
+            feature_key (str): Feature name to check the value
+            values (list): Feature_key values to keep (if excluded is False)
+                or to exclude
+            excluded (bool, optional): To keep (False) or exclude (True) the samples
+                with Feature_key value included in Values. Defaults to False.
+
+        Returns:
+            DatasetType: Filtered dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def merge(
+        id_dataset: DatasetType,
+        ood_dataset: DatasetType,
+        resize: Optional[bool] = False,
+        shape: Optional[Tuple[int]] = None,
+    ) -> DatasetType:
+        """Merge two datasets
+
+        Args:
+            id_dataset (Dataset): dataset of in-distribution data
+            ood_dataset (DictDataset): dataset of out-of-distribution data
+            resize (Optional[bool], optional): toggles if input tensors of the
+                datasets have to be resized to have the same shape. Defaults to True.
+            shape (Optional[Tuple[int]], optional): shape to use for resizing input
+                tensors. If None, the tensors are resized with the shape of the
+                id_dataset input tensors. Defaults to None.
+
+        Returns:
+            DatasetType: merged dataset
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def prepare_for_training(
+        cls,
+        dataset: DatasetType,
+        batch_size: int,
+        shuffle: bool = False,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        output_keys: list = ["input", "label"],
+    ) -> DatasetType:
+        """Prepare a dataset for training
+
+        Args:
+            dataset (DictDataset): Dataset to prepare
+            batch_size (int): Batch size
+            shuffle (bool): Wether to shuffle the dataloader or not
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            output_keys (list): List of keys corresponding to the features that will be
+                returned. Keep all features if None. Defaults to None.
+
+        Returns:
+            DatasetType: prepared dataset / dataloader
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_item_length(dataset: DatasetType) -> int:
+        """Number of elements in a dataset item
+
+        Args:
+            dataset (DatasetType): Dataset
+
+        Returns:
+            int: Item length
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_dataset_length(dataset: DatasetType) -> int:
+        """Number of items in a dataset
+
+        Args:
+            dataset (DatasetType): Dataset
+
+        Returns:
+            int: Dataset length
+        """
+        raise NotImplementedError()
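The deprecated module preserves the 0.1.x surface verbatim, so old and new names can be lined up directly from this diff. A hedged migration cheat-sheet (names only, taken from the hunks above; behavioral equivalence is assumed, not verified here):

    # 0.1.x (now under oodeel.datasets.deprecated)        0.3.0
    # DataHandler.filter_by_feature_value(ds, key, vals)  -> DataHandler.filter_by_value(ds, column_name, vals)
    # DataHandler.prepare_for_training(ds, batch_size)    -> DataHandler.prepare(ds, batch_size, ...)
    # DataHandler.load_dataset(dataset_id, keys)          -> load_dataset_from_arrays / load_custom_dataset
    # from oodeel.datasets import OODDataset              # still importable, re-exported from
    #                                                     # deprecated.DEPRECATED_ooddataset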