oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the oodeel package, as they appear in their respective public registries. It is provided for informational purposes only.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
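The bulk of this diff is a naming overhaul of the dataset-handling layer: `output_keys`/`feature_key` become `columns`/`column_name`, the 0.1.x handlers move under `oodeel/datasets/deprecated/`, and new detection methods (`gen`, `gram`, `rmds`, `she`) plus a `preprocess` package appear. As a hedged sketch of what the rename means for calling code, with argument names taken from the hunks below rather than from oodeel's documentation:

```python
# Illustrative only: argument names inferred from the 0.3.0 hunks below.
import numpy as np
from oodeel.datasets.torch_data_handler import TorchDataHandler

x = np.random.rand(100, 3, 32, 32).astype("float32")
y = np.random.randint(0, 10, size=100)

# 0.1.1 called these names "output_keys"; 0.3.0 calls them "columns".
ds = TorchDataHandler.load_dataset((x, y), columns=["input", "label"])
print(ds.columns)  # ['input', 'label']
```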
oodeel/datasets/torch_data_handler.py
CHANGED

@@ -46,9 +46,10 @@ from .data_handler import DataHandler
 
 
 def dict_only_ds(ds_handling_method: Callable) -> Callable:
-    """Decorator to ensure that the dataset is a dict dataset and that the feature_key
-    given as argument matches one of the feature keys. The signature of decorated functions
-    must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+    """Decorator to ensure that the dataset is a dict dataset and that the column_name
+    given as argument matches one of the column names.
+    matches one of the column names. The signature of decorated functions
+    must be function(dataset, *args, **kwargs) with column_name either in kwargs or
     args[0] when relevant.
@@ -64,19 +65,19 @@ def dict_only_ds(ds_handling_method: Callable) -> Callable:
             dataset, DictDataset
         ), "Dataset must be an instance of DictDataset"
 
-        if "feature_key" in kwargs:
-            feature_key = kwargs["feature_key"]
+        if "column_name" in kwargs:
+            column_name = kwargs["column_name"]
         elif len(args) > 0:
-            feature_key = args[0]
+            column_name = args[0]
 
-        # If feature_key is provided, check that it is in the dataset feature keys
-        if (len(args) > 0) or ("feature_key" in kwargs):
-            if isinstance(feature_key, str):
-                feature_key = [feature_key]
-            for key in feature_key:
+        # If column_name is provided, check that it is in the dataset column names
+        if (len(args) > 0) or ("column_name" in kwargs):
+            if isinstance(column_name, str):
+                column_name = [column_name]
+            for name in column_name:
                 assert (
-                    key in dataset.output_keys
-                ), f"The input dataset has no feature named {key}"
+                    name in dataset.columns
+                ), f"The input dataset has no column named {name}"
         return ds_handling_method(dataset, *args, **kwargs)
 
     return wrapper
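The two hunks above are a pure rename inside the guard decorator; the checking logic is unchanged. A minimal standalone sketch of the same pattern, with toy types standing in for oodeel's classes:

```python
# Standalone sketch of the dict_only_ds guard (toy types, not oodeel code).
from typing import Callable


class ToyDictDataset:
    def __init__(self, items: list):
        self._items = items  # e.g. [{"input": ..., "label": ...}, ...]

    @property
    def columns(self) -> list:
        return list(self._items[0].keys())


def dict_only_ds(fn: Callable) -> Callable:
    def wrapper(dataset, *args, **kwargs):
        assert isinstance(dataset, ToyDictDataset), "Dataset must be dict based"
        column_name = kwargs.get("column_name", args[0] if args else None)
        if column_name is not None:
            names = [column_name] if isinstance(column_name, str) else column_name
            for name in names:
                assert name in dataset.columns, f"No column named {name}"
        return fn(dataset, *args, **kwargs)

    return wrapper


@dict_only_ds
def get_values(dataset, column_name: str) -> list:
    return [item[column_name] for item in dataset._items]


ds = ToyDictDataset([{"input": 1, "label": 0}, {"input": 2, "label": 1}])
print(get_values(ds, "label"))  # [0, 1]; "score" would raise an AssertionError
```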
@@ -108,23 +109,23 @@ class DictDataset(Dataset):
 
     Args:
         dataset (Dataset): Dataset to wrap.
-        output_keys (List[str]): Keys describing the output tensors.
+        columns (columns[str]): Column names describing the output tensors.
     """
 
     def __init__(
-        self, dataset: Dataset, output_keys: List[str] = ["input", "label"]
+        self, dataset: Dataset, columns: List[str] = ["input", "label"]
     ) -> None:
         self._dataset = dataset
-        self._raw_output_keys = output_keys
+        self._raw_columns = columns
         self.map_fns = []
         self._check_init_args()
 
     @property
-    def output_keys(self) -> list:
-        """Get the list of output keys in a dict-based item from the dataset.
+    def columns(self) -> list:
+        """Get the list of columns in a dict-based item from the dataset.
 
         Returns:
-            list: output keys of the dataset.
+            list: column names of the dataset.
         """
         dummy_item = self[0]
         return list(dummy_item.keys())
@@ -137,10 +138,10 @@ class DictDataset(Dataset):
             list: tensor shapes of an dataset item.
         """
         dummy_item = self[0]
-        return [dummy_item[key].shape for key in self.output_keys]
+        return [dummy_item[key].shape for key in self.columns]
 
     def _check_init_args(self) -> None:
-        """Check validity of dataset and output keys provided at init"""
+        """Check validity of dataset and column names provided at init"""
         dummy_item = self._dataset[0]
         assert isinstance(
             dummy_item, (tuple, dict, list, torch.Tensor)
@@ -148,8 +149,8 @@ class DictDataset(Dataset):
         if isinstance(dummy_item, torch.Tensor):
             dummy_item = [dummy_item]
         assert len(dummy_item) == len(
-            self._raw_output_keys
-        ), "Length mismatch between dataset item and provided output keys"
+            self._raw_columns
+        ), "Length mismatch between dataset item and provided column names"
 
     def __getitem__(self, index: int) -> dict:
         """Return a dictionary of tensors corresponding to a specfic index.
@@ -171,9 +172,7 @@ class DictDataset(Dataset):
             tensors = item
 
         # build output dictionary
-        output_dict = {
-            key: tensor for (key, tensor) in zip(self._raw_output_keys, tensors)
-        }
+        output_dict = {key: tensor for (key, tensor) in zip(self._raw_columns, tensors)}
 
         # apply map functions
         for map_fn in self.map_fns:
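`__getitem__` now builds the output dict in a single comprehension over the renamed `_raw_columns`. The construction itself is just a zip of column names against the wrapped dataset's tuple item, as in this self-contained sketch:

```python
# What the comprehension above does, on a plain TensorDataset (assumed shapes).
import torch
from torch.utils.data import TensorDataset

columns = ["input", "label"]
base = TensorDataset(torch.randn(10, 4), torch.arange(10))

tensors = base[3]  # a tuple: (input tensor, label tensor)
output_dict = {key: tensor for (key, tensor) in zip(columns, tensors)}
print(output_dict["label"])  # tensor(3)
```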
@@ -228,18 +227,16 @@ class DictDataset(Dataset):
             other_dataset, DictDataset
         ), "Second dataset should be an instance of DictDataset"
         assert (
-            self.output_keys == other_dataset.output_keys
-        ), "Incompatible dataset elements (different output keys)"
+            self.columns == other_dataset.columns
+        ), "Incompatible dataset elements (different column names)"
         if inplace:
             dataset_copy = copy.deepcopy(self)
-            self._raw_output_keys = self.output_keys
+            self._raw_columns = self.columns
             self.map_fns = []
             self._dataset = ConcatDataset([dataset_copy, other_dataset])
             dataset = self
         else:
-            dataset = DictDataset(
-                ConcatDataset([self, other_dataset]), self.output_keys
-            )
+            dataset = DictDataset(ConcatDataset([self, other_dataset]), self.columns)
         return dataset
 
     def __len__(self) -> int:
@@ -258,6 +255,11 @@ class TorchDataHandler(DataHandler):
     torch syntax.
     """
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.backend = "torch"
+        self.channel_order = "channels_first"
+
     @staticmethod
     def _default_target_transform(y: Any) -> torch.Tensor:
         """Format int or float item target as a torch tensor
@@ -277,14 +279,16 @@ class TorchDataHandler(DataHandler):
     def load_dataset(
         cls,
         dataset_id: Union[Dataset, ItemType, str],
-        output_keys: Optional[list] = None,
+        columns: Optional[list] = None,
         load_kwargs: dict = {},
     ) -> DictDataset:
         """Load dataset from different manners
 
         Args:
-            dataset_id (Union[Dataset, ItemType, str]): dataset identification
-            output_keys (list, optional): Output keys. If None, assigned as "input_i"
+            dataset_id (Union[Dataset, ItemType, str]): dataset identification.
+                Can be the name of a dataset from torchvision, a torch Dataset,
+                or a tuple/dict of np.ndarrays/torch tensors.
+            columns (list, optional): Column names. If None, assigned as "input_i"
                 for i-th feature. Defaults to None.
             load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
 
@@ -295,23 +299,23 @@ class TorchDataHandler(DataHandler):
             assert "root" in load_kwargs.keys()
             dataset = cls.load_from_torchvision(dataset_id, **load_kwargs)
         elif isinstance(dataset_id, Dataset):
-            dataset = cls.load_custom_dataset(dataset_id, output_keys)
+            dataset = cls.load_custom_dataset(dataset_id, columns)
         elif isinstance(dataset_id, get_args(ItemType)):
-            dataset = cls.load_dataset_from_arrays(dataset_id, output_keys)
+            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
         return dataset
 
     @staticmethod
     def load_dataset_from_arrays(
         dataset_id: ItemType,
-        output_keys: Optional[list] = None,
+        columns: Optional[list] = None,
     ) -> DictDataset:
         """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.
 
         Args:
             dataset_id (ItemType):
                 numpy / torch array(s) to load.
-            output_keys (list, optional): Keys to assign. If None,
-                assigned as "input_i" for i-th feature. Defaults to None.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for i-th feature. Defaults to None.
 
         Returns:
             DictDataset: dataset
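`load_dataset` dispatches on the type of `dataset_id`: a string goes through torchvision (and `load_kwargs` must then contain `"root"`, per the assert above), a `Dataset` goes through `load_custom_dataset`, and raw arrays go through `load_dataset_from_arrays`. A hedged sketch of the string branch, assuming the remaining kwargs are forwarded to the torchvision dataset constructor:

```python
# Sketch of the string branch: "root" is required by the assert above; the
# other load_kwargs are an assumption (forwarded to torchvision).
from oodeel.datasets.torch_data_handler import TorchDataHandler

mnist = TorchDataHandler.load_dataset(
    "MNIST", load_kwargs={"root": "./data", "download": True}
)
```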
@@ -319,47 +323,45 @@ class TorchDataHandler(DataHandler):
         # If dataset_id is an array
         if isinstance(dataset_id, get_args(TensorType)):
             tensors = tuple(to_torch(dataset_id))
-            output_keys = output_keys or ["input"]
+            columns = columns or ["input"]
 
         # If dataset_id is a tuple of arrays
         elif isinstance(dataset_id, tuple):
             len_elem = len(dataset_id)
-
-            if output_keys is None:
+            if columns is None:
                 if len_elem == 2:
-                    output_keys = ["input", "label"]
+                    columns = ["input", "label"]
                 else:
-                    output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
-                        "label"
-                    ]
+                    columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                     print(
                         "Loading torch.utils.data.Dataset with elems as dicts, "
                         'assigning "input_i" key to the i-th tuple dimension and'
                         ' "label" key to the last tuple dimension.'
                     )
-            assert len(output_keys) == len(dataset_id)
+            assert len(columns) == len(dataset_id)
             tensors = tuple(to_torch(array) for array in dataset_id)
 
         # If dataset_id is a dictionary of arrays
         elif isinstance(dataset_id, dict):
-            output_keys = output_keys or list(dataset_id.keys())
-            assert len(output_keys) == len(dataset_id)
+            columns = columns or list(dataset_id.keys())
+            assert len(columns) == len(dataset_id)
             tensors = tuple(to_torch(array) for array in dataset_id.values())
 
-        # create torch dictionary dataset from tensors tuple and output keys
-        dataset = DictDataset(TensorDataset(*tensors), output_keys)
+        # create torch dictionary dataset from tensors tuple and columns
+        dataset = DictDataset(TensorDataset(*tensors), columns)
         return dataset
 
     @staticmethod
     def load_custom_dataset(
-        dataset_id: Dataset, output_keys: Optional[list] = None
+        dataset_id: Dataset, columns: Optional[list] = None
     ) -> DictDataset:
         """Load a custom Dataset by ensuring it has the correct format (dict-based)
 
         Args:
             dataset_id (Dataset): Dataset
-            output_keys (list, optional): Keys to use for elements if dataset_id is
-                tuple based. Defaults to None.
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for i-th column. Defaults to None.
 
         Returns:
             DictDataset
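When `columns` is omitted, a 2-tuple is named `["input", "label"]` and a longer tuple gets `input_0 … input_{n-2}` plus `label`, as the branch above shows. The naming rule in isolation:

```python
# The default column naming rule from the branch above, as a plain helper.
def default_columns(len_elem: int) -> list:
    if len_elem == 2:
        return ["input", "label"]
    return [f"input_{i}" for i in range(len_elem - 1)] + ["label"]


print(default_columns(2))  # ['input', 'label']
print(default_columns(4))  # ['input_0', 'input_1', 'input_2', 'label']
```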
@@ -370,20 +372,17 @@ class TorchDataHandler(DataHandler):
         assert isinstance(
             dummy_item, (Tuple, torch.Tensor)
         ), "Custom dataset should be either dictionary based or tuple based"
-
-        if output_keys is None:
+        if columns is None:
             len_elem = len(dummy_item)
             if len_elem == 2:
-                output_keys = ["input", "label"]
+                columns = ["input", "label"]
             else:
-                output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
-                    "label"
-                ]
+                columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                 print(
                     "Feature name not found, assigning 'input_i' "
                     "key to the i-th tensor and 'label' key to the last"
                 )
-        dataset_id = DictDataset(dataset_id, output_keys)
+        dataset_id = DictDataset(dataset_id, columns)
 
         dataset = dataset_id
         return dataset
@@ -428,81 +427,18 @@ class TorchDataHandler(DataHandler):
         )
         return cls.load_custom_dataset(dataset)
 
-    @staticmethod
-    def assign_feature_value(
-        dataset: DictDataset, feature_key: str, value: int
-    ) -> DictDataset:
-        """Assign a value to a feature for every sample in a DictDataset
-
-        Args:
-            dataset (DictDataset): DictDataset to assign the value to
-            feature_key (str): Feature to assign the value to
-            value (int): Value to assign
-
-        Returns:
-            DictDataset
-        """
-        assert isinstance(
-            dataset, DictDataset
-        ), "Dataset must be an instance of DictDataset"
-
-        def assign_value_to_feature(x):
-            x[feature_key] = torch.tensor(value)
-            return x
-
-        dataset = dataset.map(assign_value_to_feature)
-        return dataset
-
     @staticmethod
     @dict_only_ds
-    def get_feature_from_ds(dataset: DictDataset, feature_key: str) -> np.ndarray:
-        """Get a feature from a DictDataset
-
-        !!! note
-            This function can be a bit time consuming since it needs to iterate
-            over the whole dataset.
+    def get_ds_column_names(dataset: DictDataset) -> list:
+        """Get the column names of a DictDataset
 
         Args:
-            dataset (DictDataset): Dataset to get the feature from
-            feature_key (str): Feature value to get
+            dataset (DictDataset): Dataset to get the column names from
 
         Returns:
-            np.ndarray: Feature values
+            list: List of column names
         """
-
-        features = dataset.map(lambda x: x[feature_key])
-        features = np.stack([f.numpy() for f in features])
-        return features
-
-    @staticmethod
-    @dict_only_ds
-    def get_ds_feature_keys(dataset: DictDataset) -> list:
-        """Get the feature keys of a DictDataset
-
-        Args:
-            dataset (DictDataset): Dataset to get the feature keys from
-
-        Returns:
-            list: List of feature keys
-        """
-        return dataset.output_keys
-
-    @staticmethod
-    def has_feature_key(dataset: DictDataset, key: str) -> bool:
-        """Check if a DictDataset has a feature denoted by key
-
-        Args:
-            dataset (DictDataset): Dataset to check
-            key (str): Key to check
-
-        Returns:
-            bool: If the dataset has a feature denoted by key
-        """
-        assert isinstance(
-            dataset, DictDataset
-        ), "Dataset must be an instance of DictDataset"
-
-        return key in dataset.output_keys
+        return dataset.columns
 
     @staticmethod
     def map_ds(
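0.3.0 drops `assign_feature_value`, `get_ds_feature_keys`, `has_feature_key`, and the feature-extraction helper; only the renamed `get_ds_column_names` survives. Equivalent behaviour can still be recovered from the primitives that remain, `DictDataset.map` and the `columns` property, as in this hedged sketch (it assumes `map` keeps the semantics shown earlier in this diff; the helper names here are hypothetical):

```python
# Hedged stand-ins for the removed helpers, built on DictDataset.map / .columns.
import torch


def assign_column_value(dataset, column_name: str, value: int):
    # replaces TorchDataHandler.assign_feature_value(dataset, feature_key, value)
    def _assign(item: dict) -> dict:
        item[column_name] = torch.tensor(value)
        return item

    return dataset.map(_assign)


def has_column(dataset, name: str) -> bool:
    # replaces TorchDataHandler.has_feature_key(dataset, key)
    return name in dataset.columns
```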
@@ -522,13 +458,13 @@ class TorchDataHandler(DataHandler):
 
     @staticmethod
     @dict_only_ds
-    def filter_by_feature_value(
+    def filter_by_value(
         dataset: DictDataset,
-        feature_key: str,
+        column_name: str,
         values: list,
         excluded: bool = False,
     ) -> DictDataset:
-        """Filter the dataset by checking the value of a feature is in `values`
+        """Filter the dataset by checking if the value of a column is in `values`
 
         !!! note
             This function can be a bit of time consuming since it needs to iterate
@@ -536,62 +472,64 @@ class TorchDataHandler(DataHandler):
 
         Args:
             dataset (DictDataset): Dataset to filter
-            feature_key (str): Feature to filter the dataset with
-            values (list): Feature values to keep
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep
             excluded (bool, optional): To keep (False) or exclude (True) the samples
-                with feature values included in Values. Defaults to False.
+                with column values included in Values. Defaults to False.
 
         Returns:
             DictDataset: Filtered dataset
         """
 
-        if len(dataset[0][feature_key].shape) > 0:
-            value_dim = dataset[0][feature_key].shape[-1]
+        if len(dataset[0][column_name].shape) > 0:
+            value_dim = dataset[0][column_name].shape[-1]
             values = [
                 F.one_hot(torch.tensor(value).long(), value_dim) for value in values
             ]
 
         def filter_fn(x):
-            keep = any([torch.all(x[feature_key] == v) for v in values])
+            keep = any([torch.all(x[column_name] == v) for v in values])
             return keep if not excluded else not keep
 
         filtered_dataset = dataset.filter(filter_fn)
         return filtered_dataset
 
     @classmethod
-    def prepare_for_training(
+    def prepare(
         cls,
         dataset: DictDataset,
         batch_size: int,
-        shuffle: bool = False,
         preprocess_fn: Optional[Callable] = None,
         augment_fn: Optional[Callable] = None,
-        output_keys: Optional[list] = None,
-        dict_based_fns: bool = False,
-        …
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
+        num_workers: int = 8,
     ) -> DataLoader:
         """Prepare a DataLoader for training
 
         Args:
             dataset (DictDataset): Dataset to prepare
             batch_size (int): Batch size
-            shuffle (bool): Wether to shuffle the dataloader or not
             preprocess_fn (Callable, optional): Preprocessing function to apply to
                 the dataset. Defaults to None.
             augment_fn (Callable, optional): Augment function to be used (when the
                 returned dataset is to be used for training). Defaults to None.
-            output_keys (list, optional): List of keys corresponding to the features
-                returned. Keep all features if None. Defaults to None.
+            columns (list, optional): List of column names corresponding to the columns
+                that will be returned. Keep all features if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
             dict_based_fns (bool): Whether to use preprocess and DA functions as dict
-                based (if True) or as tuple based (if False). Defaults to
-                …
-                …
-                …
+                based (if True) or as tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
+            num_workers (int, optional): Number of workers to use for the dataloader.
 
         Returns:
             DataLoader: dataloader
         """
-        output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+        columns = columns or cls.get_ds_column_names(dataset)
 
         def collate_fn(batch: List[dict]):
            if dict_based_fns:
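`filter_by_value` and `prepare` are the renamed entry points: `column_name` replaces the old feature-key argument, `shuffle` moves down the signature, and `dict_based_fns`, `return_tuple`, and `num_workers` are new. A hedged end-to-end sketch against the signatures in this hunk (values illustrative):

```python
# Usage sketch for filter_by_value + prepare (a sketch, not documented API).
import numpy as np
from oodeel.datasets.torch_data_handler import TorchDataHandler

x = np.random.rand(64, 3, 8, 8).astype("float32")
y = np.random.randint(0, 10, size=64)
ds = TorchDataHandler.load_dataset((x, y), columns=["input", "label"])

# keep only samples whose label is in {0, ..., 4}
ds = TorchDataHandler.filter_by_value(ds, "label", [0, 1, 2, 3, 4])

# dict-based preprocess_fn receives and returns the item dict
loader = TorchDataHandler.prepare(
    ds,
    batch_size=16,
    preprocess_fn=lambda item: {**item, "input": item["input"] * 2.0},
    shuffle=True,
    num_workers=0,  # the default is 8; 0 avoids worker processes in small scripts
)
xb, yb = next(iter(loader))  # return_tuple=True, so batches come out as tuples
```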
@@ -599,18 +537,20 @@ class TorchDataHandler(DataHandler):
                 preprocess_func = preprocess_fn or (lambda x: x)
                 augment_func = augment_fn or (lambda x: x)
                 batch = [augment_func(preprocess_func(d)) for d in batch]
-                # to tuple of batchs
-                return tuple(
-                    default_collate([d[key] for d in batch]) for key in output_keys
-                )
+                # to dict of batchs
+                if return_tuple:
+                    return tuple(
+                        default_collate([d[key] for d in batch]) for key in columns
+                    )
+                return {
+                    key: default_collate([d[key] for d in batch]) for key in columns
+                }
             else:
                 # preprocess + DA: List[dict] -> List[tuple]
                 preprocess_func = preprocess_fn or (lambda *x: x)
                 augment_func = augment_fn or (lambda *x: x)
                 batch = [
-                    augment_func(
-                        *preprocess_func(*tuple(d[key] for key in output_keys))
-                    )
+                    augment_func(*preprocess_func(*tuple(d[key] for key in columns)))
                     for d in batch
                 ]
                 # to tuple of batchs
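The dict-based branch of `collate_fn` can now emit either a tuple of batched tensors (the default) or a dict keyed by column names. The two output shapes, isolated with torch's `default_collate`:

```python
# The two collate outputs from the branch above, on a toy batch.
import torch
from torch.utils.data import default_collate

columns = ["input", "label"]
batch = [{"input": torch.randn(4), "label": torch.tensor(i)} for i in range(3)]

as_tuple = tuple(default_collate([d[key] for d in batch]) for key in columns)
as_dict = {key: default_collate([d[key] for d in batch]) for key in columns}

print(as_tuple[0].shape)  # torch.Size([3, 4])
print(as_dict["label"])   # tensor([0, 1, 2])
```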
@@ -621,6 +561,7 @@ class TorchDataHandler(DataHandler):
             batch_size=batch_size,
             shuffle=shuffle,
             collate_fn=collate_fn,
+            num_workers=num_workers,
         )
         return loader
 
@@ -697,17 +638,21 @@ class TorchDataHandler(DataHandler):
         return len(dataset)
 
     @staticmethod
-    def get_feature_shape(dataset: Dataset, feature_key: Union[str, int]) -> tuple:
-        """Get the shape of the feature of dataset identified by feature_key
+    def get_column_elements_shape(
+        dataset: Dataset, column_name: Union[str, int]
+    ) -> tuple:
+        """Get the shape of the elements of a column of dataset identified by
+        column_name
 
         Args:
             dataset (Dataset): a Dataset
-            feature_key (Union[str, int]): The feature to get the shape from.
+            column_name (Union[str, int]): The column name to get
+                the element shape from.
 
         Returns:
-            tuple: the shape of feature_key elements
+            tuple: the shape of an element from column_name
         """
-        return tuple(dataset[0][feature_key].shape)
+        return tuple(dataset[0][column_name].shape)
 
     @staticmethod
     def get_input_from_dataset_item(elem: ItemType) -> Any:
@@ -747,20 +692,3 @@ class TorchDataHandler(DataHandler):
         if len(label.shape) > 1:
             label = label.view([label.shape[0]])
         return label
-
-    @staticmethod
-    def get_feature(dataset: DictDataset, feature_key: Union[str, int]) -> DictDataset:
-        """Extract a feature from a dataset
-
-        Args:
-            dataset (tf.data.Dataset): Dataset to extract the feature from
-            feature_key (Union[str, int]): feature to extract
-
-        Returns:
-            tf.data.Dataset: dataset built with the extracted feature only
-        """
-
-        def _get_feature_item(item):
-            return item[feature_key]
-
-        return dataset.map(_get_feature_item)
oodeel/eval/metrics.py
CHANGED
@@ -54,9 +54,11 @@ def bench_metrics(
             out_value if different from their default values.
             Defaults to None.
         in_value (Optional[int], optional): ood label value for in-distribution data.
+            Automatically assigned 0 if it is not the case.
             Defaults to 0.
         out_value (Optional[int], optional): ood label value for out-of-distribution
-            data. Defaults to 1.
+            data. Automatically assigned 1 if it is not the case.
+            Defaults to 1.
         metrics (Optional[List[str]], optional): list of metrics to compute. Can pass
             any metric name from sklearn.metric or among "detect_acc" and
             "<aaa><XX><bbb>" where <aaa> and <bbb> are in ["fpr", "tpr", "fnr", "tnr"]

@@ -89,7 +91,10 @@ def bench_metrics(
     for metric in metrics:
         if isinstance(metric, str):
             if metric == "auroc":
-                auroc = -np.trapz(1.0 - fpr, tpr)
+                if np.__version__ >= "2.0.0":
+                    auroc = -np.trapezoid(1.0 - fpr, tpr)
+                else:
+                    auroc = -np.trapz(1.0 - fpr, tpr)
                 metrics_dict["auroc"] = auroc
 
             elif metric == "detect_acc":
oodeel/eval/plots/features.py
CHANGED
@@ -168,7 +168,7 @@ def _plot_features(
     # === extract id features ===
     # features
     in_features, _ = feature_extractor.predict(in_dataset)
-    in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
+    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
 
     # labels
     in_labels = []

@@ -181,7 +181,7 @@ def _plot_features(
     if out_dataset is not None:
         # features
         out_features, _ = feature_extractor.predict(out_dataset)
-        out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
+        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
 
         # labels
         out_labels_str = np.array(["unknown"] * len(out_features))
oodeel/eval/plots/plotly.py
CHANGED
@@ -82,7 +82,7 @@ def plotly_3D_features(
     # === extract id features ===
     # features
     in_features, _ = feature_extractor.predict(in_dataset)
-    in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
+    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
 
     # labels
     in_labels = []

@@ -95,7 +95,7 @@ def plotly_3D_features(
     if out_dataset is not None:
         # features
         out_features, _ = feature_extractor.predict(out_dataset)
-        out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
+        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
 
         # labels
         out_labels = np.array(["unknown"] * len(out_features))
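The four plotting hunks all make the same change: `op.flatten(features)` becomes `op.flatten(features[0])`. This suggests `feature_extractor.predict` now returns the features as a list with one entry per extracted layer, alongside secondary outputs, so the plots select the first layer explicitly (compare the `+30/-9` change to `oodeel/extractor/feature_extractor.py` in the file list). A toy stand-in for the shape of that return value, which is an inference from this diff rather than the documented API:

```python
# Toy predict() illustrating why the plots index features[0] (shapes assumed).
import numpy as np


def predict(n_samples: int):
    features_per_layer = [
        np.random.rand(n_samples, 64),   # first hooked layer
        np.random.rand(n_samples, 128),  # second hooked layer
    ]
    logits = np.random.rand(n_samples, 10)  # assumption: second return value
    return features_per_layer, logits


features, logits = predict(32)
in_features = features[0].reshape(len(features[0]), -1)  # what op.flatten does here
print(in_features.shape)  # (32, 64)
```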