oodeel-0.2.0-py3-none-any.whl → oodeel-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (42)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/extractor/feature_extractor.py +11 -0
  13. oodeel/extractor/keras_feature_extractor.py +51 -1
  14. oodeel/extractor/torch_feature_extractor.py +103 -21
  15. oodeel/methods/__init__.py +16 -1
  16. oodeel/methods/base.py +72 -15
  17. oodeel/methods/dknn.py +20 -7
  18. oodeel/methods/energy.py +8 -0
  19. oodeel/methods/entropy.py +8 -0
  20. oodeel/methods/gen.py +118 -0
  21. oodeel/methods/gram.py +15 -4
  22. oodeel/methods/mahalanobis.py +9 -7
  23. oodeel/methods/mls.py +8 -0
  24. oodeel/methods/odin.py +8 -0
  25. oodeel/methods/rmds.py +122 -0
  26. oodeel/methods/she.py +197 -0
  27. oodeel/methods/vim.py +1 -1
  28. oodeel/preprocess/__init__.py +31 -0
  29. oodeel/preprocess/tf_preprocess.py +95 -0
  30. oodeel/preprocess/torch_preprocess.py +97 -0
  31. oodeel/utils/operator.py +17 -0
  32. oodeel/utils/tf_operator.py +15 -0
  33. oodeel/utils/tf_training_tools.py +2 -2
  34. oodeel/utils/torch_operator.py +19 -0
  35. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/METADATA +139 -105
  36. oodeel-0.3.0.dist-info/RECORD +57 -0
  37. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  38. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  39. tests/tests_torch/torch_methods_utils.py +34 -27
  40. oodeel-0.2.0.dist-info/RECORD +0 -47
  41. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  42. {oodeel-0.2.0.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/__init__.py CHANGED
@@ -25,4 +25,4 @@ oodeel
 -------
 """
 
-__version__ = "0.2.0"
+__version__ = "0.3.0"
oodeel/datasets/__init__.py CHANGED
@@ -20,4 +20,5 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from .ooddataset import OODDataset
+from .data_handler import load_data_handler
+from .deprecated.DEPRECATED_ooddataset import OODDataset
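
With this change the historical entry point stays importable while the new factory is exposed next to it. A minimal sketch of the resulting import surface, inferred from the hunk above (the deprecated class is simply re-exported from the new deprecated subpackage):

from oodeel.datasets import load_data_handler  # new in 0.3.0
from oodeel.datasets import OODDataset  # deprecated path, kept working via re-export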
oodeel/datasets/data_handler.py CHANGED
@@ -20,6 +20,7 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
+import importlib.util
 from abc import ABC
 from abc import abstractmethod
 
@@ -29,10 +30,45 @@ from ..types import Callable
 from ..types import DatasetType
 from ..types import ItemType
 from ..types import Optional
+from ..types import TensorType
 from ..types import Tuple
 from ..types import Union
 
 
+def get_backend():
+    """Detects whether TensorFlow or PyTorch is available and returns
+    the preferred backend."""
+    available_backends = []
+    if importlib.util.find_spec("tensorflow"):
+        available_backends.append("tensorflow")
+    if importlib.util.find_spec("torch"):
+        available_backends.append("torch")
+
+    if len(available_backends) == 1:
+        return available_backends[0]
+    elif len(available_backends) == 0:
+        raise ImportError("Neither TensorFlow nor PyTorch is installed.")
+    else:
+        raise ImportError(
+            "Both TensorFlow and PyTorch are installed. Please specify the backend."
+        )
+
+
+def load_data_handler(backend: str = None):
+    if backend is None:
+        backend = get_backend()
+
+    if backend == "tensorflow":
+        from .tf_data_handler import TFDataHandler
+
+        return TFDataHandler()
+
+    elif backend == "torch":
+        from .torch_data_handler import TorchDataHandler
+
+        return TorchDataHandler()
+
+
 class DataHandler(ABC):
     """
     Class to manage Datasets. The aim is to provide a simple interface
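
Together, get_backend and load_data_handler give a backend-agnostic entry point for building a data handler. A usage sketch, assuming the concrete handlers keep the abstract load_dataset_from_arrays signature shown further down (the random arrays are illustrative only, not from this diff):

import numpy as np

from oodeel.datasets import load_data_handler

# With both TensorFlow and PyTorch installed, get_backend() raises,
# so the backend is passed explicitly here.
handler = load_data_handler(backend="torch")

# Illustrative (inputs, labels) arrays; any such pair works.
x = np.random.rand(100, 3, 32, 32).astype("float32")
y = np.random.randint(0, 10, size=100)
ds = handler.load_dataset_from_arrays((x, y))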
@@ -40,82 +76,122 @@ class DataHandler(ABC):
     having to use library-specific syntax.
     """
 
-    @classmethod
-    @abstractmethod
-    def load_dataset(
-        cls,
-        dataset_id: Union[ItemType, DatasetType, str],
-        keys: Optional[list] = None,
-        load_kwargs: dict = {},
-    ) -> DatasetType:
-        """Load dataset from different manners
+    def __init__(self):
+        self.backend = None
+        self.channel_order = None
+
+    def split_by_class(
+        self,
+        dataset: DatasetType,
+        in_labels: Optional[Union[np.ndarray, list]] = None,
+        out_labels: Optional[Union[np.ndarray, list]] = None,
+    ) -> Optional[Tuple[DatasetType]]:
+        """Filter the dataset by assigning ood labels depending on labels
+        value (typically, class id).
 
         Args:
-            dataset_id (Union[ItemType, DatasetType, str]): dataset identification
-            keys (list, optional): Features keys. If None, assigned as "input_i"
-                for i-th feature. Defaults to None.
-            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+            in_labels (Optional[Union[np.ndarray, list]], optional): set of labels
+                to be considered as in-distribution. Defaults to None.
+            out_labels (Optional[Union[np.ndarray, list]], optional): set of labels
+                to be considered as out-of-distribution. Defaults to None.
 
         Returns:
-            DatasetType: dataset
+            Optional[Tuple[OODDataset]]: Tuple of in-distribution and
+                out-of-distribution OODDatasets
         """
-        raise NotImplementedError()
+        # Make sure the dataset has labels
+        assert (in_labels is not None) or (
+            out_labels is not None
+        ), "specify labels to filter with"
+        assert self.get_item_length(dataset) >= 2, "the dataset has no labels"
 
-    @staticmethod
-    @abstractmethod
-    def assign_feature_value(
-        dataset: DatasetType, feature_key: str, value: int
-    ) -> DatasetType:
-        """Assign a value to a feature for every sample in a Dataset
+        # Filter the dataset depending on in_labels and out_labels given
+        if (out_labels is not None) and (in_labels is not None):
+            in_data = self.filter_by_value(dataset, "label", in_labels)
+            out_data = self.filter_by_value(dataset, "label", out_labels)
 
-        Args:
-            dataset (DatasetType): Dataset to assign the value to
-            feature_key (str): Feature to assign the value to
-            value (int): Value to assign
+        if out_labels is None:
+            in_data = self.filter_by_value(dataset, "label", in_labels)
+            out_data = self.filter_by_value(dataset, "label", in_labels, excluded=True)
 
-        Returns:
-            DatasetType: updated dataset
-        """
-        raise NotImplementedError()
+        elif in_labels is None:
+            in_data = self.filter_by_value(dataset, "label", out_labels, excluded=True)
+            out_data = self.filter_by_value(dataset, "label", out_labels)
 
-    @staticmethod
+        # Return the filtered OODDatasets
+        return in_data, out_data
+
+    @classmethod
     @abstractmethod
-    def get_feature_from_ds(dataset: DatasetType, feature_key: str) -> np.ndarray:
-        """Get a feature from a Dataset
+    def prepare(
+        cls,
+        dataset: DatasetType,
+        batch_size: int,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
+        **kwargs_prepare,
+    ) -> DatasetType:
+        """Prepare dataset for scoring or training
 
         Args:
-            dataset (DatasetType): Dataset to get the feature from
-            feature_key (str): Feature value to get
+            batch_size (int): Batch size
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            columns (list, optional): List of columns
+                that will be returned. Keep all columns if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
+            dict_based_fns (bool): Whether to use preprocess and DA functions as dict
+                based (if True) or as tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
+            kwargs_prepare (dict): Additional parameters to be passed to the
+                data_handler for backend specific preparation.
+
 
         Returns:
-            np.ndarray: Feature values for dataset
+            DatasetType: prepared dataset
         """
         raise NotImplementedError()
 
     @staticmethod
     @abstractmethod
-    def get_ds_feature_keys(dataset: DatasetType) -> list:
-        """Get the feature keys of a Dataset
+    def load_dataset_from_arrays(
+        dataset_id: ItemType, columns: Optional[list] = None
+    ) -> DatasetType:
+        """Load a DatasetType from a np.ndarray / Tensor
 
         Args:
-            dataset (Dataset): Dataset to get the feature keys from
+            dataset_id (ItemType): numpy array(s) to load.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for i-th column. Defaults to None.
 
         Returns:
-            list: List of feature keys
+            DatasetType
         """
         raise NotImplementedError()
 
     @staticmethod
     @abstractmethod
-    def has_feature_key(dataset: DatasetType, key: str) -> bool:
-        """Check if a Dataset has a feature denoted by key
+    def load_custom_dataset(
+        dataset_id: DatasetType, columns: Optional[list] = None
+    ) -> DatasetType:
+        """Load a custom dataset by ensuring it is properly formatted.
 
         Args:
-            dataset (DatasetType): Dataset to check
-            key (str): Key to check
+            dataset_id (DatasetType): dataset
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for i-th column. Defaults to None.
 
         Returns:
-            bool: If the dataset has a feature denoted by key
+            A properly formatted dataset.
         """
         raise NotImplementedError()
 
@@ -135,21 +211,21 @@ class DataHandler(ABC):
 
     @staticmethod
     @abstractmethod
-    def filter_by_feature_value(
+    def filter_by_value(
         dataset: DatasetType,
-        feature_key: str,
+        column_name: str,
         values: list,
         excluded: bool = False,
     ) -> DatasetType:
-        """Filter the dataset by checking the value of a feature is in `values`
+        """Filter the dataset by checking the value of a column is in `values`
 
         Args:
             dataset (Dataset): Dataset to filter
-            feature_key (str): Feature name to check the value
-            values (list): Feature_key values to keep (if excluded is False)
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep (if excluded is False)
                 or to exclude
             excluded (bool, optional): To keep (False) or exclude (True) the samples
-                with Feature_key value included in Values. Defaults to False.
+                with column value included in Values. Defaults to False.
 
         Returns:
             DatasetType: Filtered dataset
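
split_by_class is now concrete on the base class and reduces to the filter_by_value calls shown in its body. Continuing the sketch above, the two equivalent spellings of an in/out split, assuming the dataset's label column carries the name "label" that split_by_class hardcodes:

# Keep classes 0-4 as in-distribution; the complement becomes OOD.
in_ds, out_ds = handler.split_by_class(ds, in_labels=[0, 1, 2, 3, 4])

# The same split written with filter_by_value directly:
in_ds = handler.filter_by_value(ds, "label", [0, 1, 2, 3, 4])
out_ds = handler.filter_by_value(ds, "label", [0, 1, 2, 3, 4], excluded=True)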
@@ -158,79 +234,71 @@
 
     @staticmethod
     @abstractmethod
-    def merge(
-        id_dataset: DatasetType,
-        ood_dataset: DatasetType,
-        resize: Optional[bool] = False,
-        shape: Optional[Tuple[int]] = None,
-    ) -> DatasetType:
-        """Merge two datasets
+    def get_item_length(dataset: DatasetType) -> int:
+        """Number of elements in a dataset item
 
         Args:
-            id_dataset (Dataset): dataset of in-distribution data
-            ood_dataset (DictDataset): dataset of out-of-distribution data
-            resize (Optional[bool], optional): toggles if input tensors of the
-                datasets have to be resized to have the same shape. Defaults to True.
-            shape (Optional[Tuple[int]], optional): shape to use for resizing input
-                tensors. If None, the tensors are resized with the shape of the
-                id_dataset input tensors. Defaults to None.
+            dataset (DatasetType): Dataset
 
         Returns:
-            DatasetType: merged dataset
+            int: Item length
         """
         raise NotImplementedError()
 
-    @classmethod
+    @staticmethod
     @abstractmethod
-    def prepare_for_training(
-        cls,
-        dataset: DatasetType,
-        batch_size: int,
-        shuffle: bool = False,
-        preprocess_fn: Optional[Callable] = None,
-        augment_fn: Optional[Callable] = None,
-        output_keys: list = ["input", "label"],
-    ) -> DatasetType:
-        """Prepare a dataset for training
+    def get_dataset_length(dataset: DatasetType) -> int:
+        """Number of items in a dataset
 
         Args:
-            dataset (DictDataset): Dataset to prepare
-            batch_size (int): Batch size
-            shuffle (bool): Wether to shuffle the dataloader or not
-            preprocess_fn (Callable, optional): Preprocessing function to apply to
-                the dataset. Defaults to None.
-            augment_fn (Callable, optional): Augment function to be used (when the
-                returned dataset is to be used for training). Defaults to None.
-            output_keys (list): List of keys corresponding to the features that will be
-                returned. Keep all features if None. Defaults to None.
+            dataset (DatasetType): Dataset
 
         Returns:
-            DatasetType: prepared dataset / dataloader
+            int: Dataset length
         """
         raise NotImplementedError()
 
     @staticmethod
     @abstractmethod
-    def get_item_length(dataset: DatasetType) -> int:
-        """Number of elements in a dataset item
+    def get_column_elements_shape(
+        dataset: DatasetType, column_name: Union[str, int]
+    ) -> tuple:
+        """Get the shape of the elements of a column of dataset identified by
+        column_name
 
         Args:
-            dataset (DatasetType): Dataset
+            dataset (Dataset): a Dataset
+            column_name (Union[str, int]): The column name to get
+                the element shape from.
 
         Returns:
-            int: Item length
+            tuple: the shape of an element from column_name
         """
         raise NotImplementedError()
 
     @staticmethod
     @abstractmethod
-    def get_dataset_length(dataset: DatasetType) -> int:
-        """Number of items in a dataset
+    def get_input_from_dataset_item(elem: ItemType) -> TensorType:
+        """Get the tensor that is to be feed as input to a model from a dataset element.
 
         Args:
-            dataset (DatasetType): Dataset
+            elem (ItemType): dataset element to extract input from
 
         Returns:
-            int: Dataset length
+            TensorType: Input tensor
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_label_from_dataset_item(item: ItemType):
+        """Retrieve label tensor from item as a tuple/list. Label must be at index 1
+        in the item tuple. If one-hot encoded, labels are converted to single value.
+
+        Args:
+            elem (ItemType): dataset element to extract label from
+
+        Returns:
+            Any: Label tensor
         """
         raise NotImplementedError()
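
The new prepare folds the removed prepare_for_training (and the scoring path) into a single entry point. A hedged sketch of a scoring loader, continuing the split above and assuming the default column names "input" and "label" plus the dict-based preprocessing implied by dict_based_fns=True:

def preprocess_fn(item):
    # With dict_based_fns=True, items are assumed to arrive as dicts
    # keyed by column name, e.g. {"input": ..., "label": ...}.
    item["input"] = item["input"] / 255.0
    return item

loader = handler.prepare(
    in_ds,
    batch_size=64,
    preprocess_fn=preprocess_fn,
    shuffle=False,       # deterministic order for scoring
    return_tuple=True,   # items come back as (input, label) tuples
)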
oodeel/datasets/deprecated/DEPRECATED_data_handler.py ADDED
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from abc import ABC
+from abc import abstractmethod
+
+import numpy as np
+
+from ...types import Callable
+from ...types import DatasetType
+from ...types import ItemType
+from ...types import Optional
+from ...types import Tuple
+from ...types import Union
+
+
+class DataHandler(ABC):
+    """
+    Class to manage Datasets. The aim is to provide a simple interface
+    for working with datasets (torch, tensorflow or other...) and manage them without
+    having to use library-specific syntax.
+    """
+
+    @classmethod
+    @abstractmethod
+    def load_dataset(
+        cls,
+        dataset_id: Union[ItemType, DatasetType, str],
+        keys: Optional[list] = None,
+        load_kwargs: dict = {},
+    ) -> DatasetType:
+        """Load dataset from different manners
+
+        Args:
+            dataset_id (Union[ItemType, DatasetType, str]): dataset identification
+            keys (list, optional): Features keys. If None, assigned as "input_i"
+                for i-th feature. Defaults to None.
+            load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
+
+        Returns:
+            DatasetType: dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def assign_feature_value(
+        dataset: DatasetType, feature_key: str, value: int
+    ) -> DatasetType:
+        """Assign a value to a feature for every sample in a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to assign the value to
+            feature_key (str): Feature to assign the value to
+            value (int): Value to assign
+
+        Returns:
+            DatasetType: updated dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_feature_from_ds(dataset: DatasetType, feature_key: str) -> np.ndarray:
+        """Get a feature from a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to get the feature from
+            feature_key (str): Feature value to get
+
+        Returns:
+            np.ndarray: Feature values for dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_ds_feature_keys(dataset: DatasetType) -> list:
+        """Get the feature keys of a Dataset
+
+        Args:
+            dataset (Dataset): Dataset to get the feature keys from
+
+        Returns:
+            list: List of feature keys
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def has_feature_key(dataset: DatasetType, key: str) -> bool:
+        """Check if a Dataset has a feature denoted by key
+
+        Args:
+            dataset (DatasetType): Dataset to check
+            key (str): Key to check
+
+        Returns:
+            bool: If the dataset has a feature denoted by key
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def map_ds(dataset: DatasetType, map_fn: Callable) -> DatasetType:
+        """Map a function to a Dataset
+
+        Args:
+            dataset (DatasetType): Dataset to map the function to
+            map_fn (Callable): Function to map
+
+        Returns:
+            DatasetType: Mapped dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def filter_by_feature_value(
+        dataset: DatasetType,
+        feature_key: str,
+        values: list,
+        excluded: bool = False,
+    ) -> DatasetType:
+        """Filter the dataset by checking the value of a feature is in `values`
+
+        Args:
+            dataset (Dataset): Dataset to filter
+            feature_key (str): Feature name to check the value
+            values (list): Feature_key values to keep (if excluded is False)
+                or to exclude
+            excluded (bool, optional): To keep (False) or exclude (True) the samples
+                with Feature_key value included in Values. Defaults to False.
+
+        Returns:
+            DatasetType: Filtered dataset
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def merge(
+        id_dataset: DatasetType,
+        ood_dataset: DatasetType,
+        resize: Optional[bool] = False,
+        shape: Optional[Tuple[int]] = None,
+    ) -> DatasetType:
+        """Merge two datasets
+
+        Args:
+            id_dataset (Dataset): dataset of in-distribution data
+            ood_dataset (DictDataset): dataset of out-of-distribution data
+            resize (Optional[bool], optional): toggles if input tensors of the
+                datasets have to be resized to have the same shape. Defaults to True.
+            shape (Optional[Tuple[int]], optional): shape to use for resizing input
+                tensors. If None, the tensors are resized with the shape of the
+                id_dataset input tensors. Defaults to None.
+
+        Returns:
+            DatasetType: merged dataset
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    def prepare_for_training(
+        cls,
+        dataset: DatasetType,
+        batch_size: int,
+        shuffle: bool = False,
+        preprocess_fn: Optional[Callable] = None,
+        augment_fn: Optional[Callable] = None,
+        output_keys: list = ["input", "label"],
+    ) -> DatasetType:
+        """Prepare a dataset for training
+
+        Args:
+            dataset (DictDataset): Dataset to prepare
+            batch_size (int): Batch size
+            shuffle (bool): Wether to shuffle the dataloader or not
+            preprocess_fn (Callable, optional): Preprocessing function to apply to
+                the dataset. Defaults to None.
+            augment_fn (Callable, optional): Augment function to be used (when the
+                returned dataset is to be used for training). Defaults to None.
+            output_keys (list): List of keys corresponding to the features that will be
+                returned. Keep all features if None. Defaults to None.
+
+        Returns:
+            DatasetType: prepared dataset / dataloader
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_item_length(dataset: DatasetType) -> int:
+        """Number of elements in a dataset item
+
+        Args:
+            dataset (DatasetType): Dataset
+
+        Returns:
+            int: Item length
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    @abstractmethod
+    def get_dataset_length(dataset: DatasetType) -> int:
+        """Number of items in a dataset
+
+        Args:
+            dataset (DatasetType): Dataset
+
+        Returns:
+            int: Dataset length
+        """
+        raise NotImplementedError()