oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of oodeel might be problematic.

Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
oodeel/datasets/torch_data_handler.py CHANGED
@@ -46,9 +46,10 @@ from .data_handler import DataHandler
 
 
 def dict_only_ds(ds_handling_method: Callable) -> Callable:
-    """Decorator to ensure that the dataset is a dict dataset and that the input key
-    matches one of the feature keys. The signature of decorated functions
-    must be function(dataset, *args, **kwargs) with feature_key either in kwargs or
+    """Decorator to ensure that the dataset is a dict dataset and that the column_name
+    given as argument matches one of the column names.
+    matches one of the column names. The signature of decorated functions
+    must be function(dataset, *args, **kwargs) with column_name either in kwargs or
     args[0] when relevant.
 
 
@@ -64,19 +65,19 @@ def dict_only_ds(ds_handling_method: Callable) -> Callable:
             dataset, DictDataset
         ), "Dataset must be an instance of DictDataset"
 
-        if "feature_key" in kwargs:
-            feature_key = kwargs["feature_key"]
+        if "column_name" in kwargs:
+            column_name = kwargs["column_name"]
         elif len(args) > 0:
-            feature_key = args[0]
+            column_name = args[0]
 
-        # If feature_key is provided, check that it is in the dataset feature keys
-        if (len(args) > 0) or ("feature_key" in kwargs):
-            if isinstance(feature_key, str):
-                feature_key = [feature_key]
-            for key in feature_key:
+        # If column_name is provided, check that it is in the dataset column names
+        if (len(args) > 0) or ("column_name" in kwargs):
+            if isinstance(column_name, str):
+                column_name = [column_name]
+            for name in column_name:
                 assert (
-                    key in dataset.output_keys
-                ), f"The input dataset has no feature names {key}"
+                    name in dataset.columns
+                ), f"The input dataset has no column named {name}"
         return ds_handling_method(dataset, *args, **kwargs)
 
     return wrapper
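For context, any helper decorated with dict_only_ds must take the dataset first and pass column_name either as the first positional argument or as a keyword. A minimal sketch (the helper column_to_numpy is hypothetical, not part of oodeel):

```python
import numpy as np

from oodeel.datasets.torch_data_handler import DictDataset, dict_only_ds


@dict_only_ds
def column_to_numpy(dataset: DictDataset, column_name: str) -> np.ndarray:
    # dict_only_ds has already checked that column_name is in dataset.columns,
    # so every item dict can be indexed safely.
    return np.stack([item[column_name].numpy() for item in dataset])
```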
@@ -108,23 +109,23 @@ class DictDataset(Dataset):
 
     Args:
         dataset (Dataset): Dataset to wrap.
-        output_keys (output_keys[str]): Keys describing the output tensors.
+        columns (columns[str]): Column names describing the output tensors.
     """
 
     def __init__(
-        self, dataset: Dataset, output_keys: List[str] = ["input", "label"]
+        self, dataset: Dataset, columns: List[str] = ["input", "label"]
     ) -> None:
         self._dataset = dataset
-        self._raw_output_keys = output_keys
+        self._raw_columns = columns
         self.map_fns = []
         self._check_init_args()
 
     @property
-    def output_keys(self) -> list:
-        """Get the list of keys in a dict-based item from the dataset.
+    def columns(self) -> list:
+        """Get the list of columns in a dict-based item from the dataset.
 
         Returns:
-            list: feature keys of the dataset.
+            list: column names of the dataset.
         """
         dummy_item = self[0]
         return list(dummy_item.keys())
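A minimal sketch of the renamed API, assuming the import path shown in this diff: wrapping a tuple-based TensorDataset yields dict items keyed by the column names, and the former output_keys property is now columns.

```python
import torch
from torch.utils.data import TensorDataset

from oodeel.datasets.torch_data_handler import DictDataset

x = torch.randn(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))

# Each item becomes {"input": ..., "label": ...}
ds = DictDataset(TensorDataset(x, y), columns=["input", "label"])
print(ds.columns)            # ['input', 'label']
print(ds[0]["input"].shape)  # torch.Size([3, 32, 32])
```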
@@ -137,10 +138,10 @@ class DictDataset(Dataset):
             list: tensor shapes of an dataset item.
         """
         dummy_item = self[0]
-        return [dummy_item[key].shape for key in self.output_keys]
+        return [dummy_item[key].shape for key in self.columns]
 
     def _check_init_args(self) -> None:
-        """Check validity of dataset and output keys provided at init"""
+        """Check validity of dataset and column names provided at init"""
         dummy_item = self._dataset[0]
         assert isinstance(
             dummy_item, (tuple, dict, list, torch.Tensor)
@@ -148,8 +149,8 @@ class DictDataset(Dataset):
         if isinstance(dummy_item, torch.Tensor):
             dummy_item = [dummy_item]
         assert len(dummy_item) == len(
-            self._raw_output_keys
-        ), "Length mismatch between dataset item and provided keys"
+            self._raw_columns
+        ), "Length mismatch between dataset item and provided column names"
 
     def __getitem__(self, index: int) -> dict:
         """Return a dictionary of tensors corresponding to a specfic index.
@@ -171,9 +172,7 @@ class DictDataset(Dataset):
             tensors = item
 
         # build output dictionary
-        output_dict = {
-            key: tensor for (key, tensor) in zip(self._raw_output_keys, tensors)
-        }
+        output_dict = {key: tensor for (key, tensor) in zip(self._raw_columns, tensors)}
 
         # apply map functions
         for map_fn in self.map_fns:
@@ -228,18 +227,16 @@ class DictDataset(Dataset):
             other_dataset, DictDataset
         ), "Second dataset should be an instance of DictDataset"
         assert (
-            self.output_keys == other_dataset.output_keys
-        ), "Incompatible dataset elements (different dict keys)"
+            self.columns == other_dataset.columns
+        ), "Incompatible dataset elements (different column names)"
         if inplace:
             dataset_copy = copy.deepcopy(self)
-            self._raw_output_keys = self.output_keys
+            self._raw_columns = self.columns
             self.map_fns = []
             self._dataset = ConcatDataset([dataset_copy, other_dataset])
             dataset = self
         else:
-            dataset = DictDataset(
-                ConcatDataset([self, other_dataset]), self.output_keys
-            )
+            dataset = DictDataset(ConcatDataset([self, other_dataset]), self.columns)
         return dataset
 
     def __len__(self) -> int:
@@ -258,6 +255,11 @@ class TorchDataHandler(DataHandler):
     torch syntax.
     """
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.backend = "torch"
+        self.channel_order = "channels_first"
+
     @staticmethod
     def _default_target_transform(y: Any) -> torch.Tensor:
         """Format int or float item target as a torch tensor
@@ -277,14 +279,16 @@ class TorchDataHandler(DataHandler):
     def load_dataset(
         cls,
         dataset_id: Union[Dataset, ItemType, str],
-        keys: Optional[list] = None,
+        columns: Optional[list] = None,
         load_kwargs: dict = {},
     ) -> DictDataset:
         """Load dataset from different manners
 
         Args:
-            dataset_id (Union[Dataset, ItemType, str]): dataset identification
-            keys (list, optional): Features keys. If None, assigned as "input_i"
+            dataset_id (Union[Dataset, ItemType, str]): dataset identification.
+                Can be the name of a dataset from torchvision, a torch Dataset,
+                or a tuple/dict of np.ndarrays/torch tensors.
+            columns (list, optional): Column names. If None, assigned as "input_i"
                 for i-th feature. Defaults to None.
             load_kwargs (dict, optional): Additional loading kwargs. Defaults to {}.
 
@@ -295,23 +299,23 @@ class TorchDataHandler(DataHandler):
             assert "root" in load_kwargs.keys()
             dataset = cls.load_from_torchvision(dataset_id, **load_kwargs)
         elif isinstance(dataset_id, Dataset):
-            dataset = cls.load_custom_dataset(dataset_id, keys)
+            dataset = cls.load_custom_dataset(dataset_id, columns)
         elif isinstance(dataset_id, get_args(ItemType)):
-            dataset = cls.load_dataset_from_arrays(dataset_id, keys)
+            dataset = cls.load_dataset_from_arrays(dataset_id, columns)
         return dataset
 
     @staticmethod
     def load_dataset_from_arrays(
         dataset_id: ItemType,
-        keys: Optional[list] = None,
+        columns: Optional[list] = None,
     ) -> DictDataset:
         """Load a torch.utils.data.Dataset from an array or a tuple/dict of arrays.
 
         Args:
             dataset_id (ItemType):
                 numpy / torch array(s) to load.
-            keys (list, optional): Features keys. If None, assigned as "input_i"
-                for i-th feature. Defaults to None.
+            columns (list, optional): Column names to assign. If None,
+                assigned as "input_i" for i-th feature. Defaults to None.
 
         Returns:
             DictDataset: dataset
@@ -319,47 +323,45 @@ class TorchDataHandler(DataHandler):
         # If dataset_id is an array
         if isinstance(dataset_id, get_args(TensorType)):
             tensors = tuple(to_torch(dataset_id))
-            output_keys = keys or ["input"]
+            columns = columns or ["input"]
 
         # If dataset_id is a tuple of arrays
         elif isinstance(dataset_id, tuple):
             len_elem = len(dataset_id)
-            output_keys = keys
-            if output_keys is None:
+            if columns is None:
                 if len_elem == 2:
-                    output_keys = ["input", "label"]
+                    columns = ["input", "label"]
                 else:
-                    output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
-                        "label"
-                    ]
+                    columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
                 print(
                     "Loading torch.utils.data.Dataset with elems as dicts, "
                     'assigning "input_i" key to the i-th tuple dimension and'
                     ' "label" key to the last tuple dimension.'
                 )
-            assert len(output_keys) == len(dataset_id)
+            assert len(columns) == len(dataset_id)
             tensors = tuple(to_torch(array) for array in dataset_id)
 
         # If dataset_id is a dictionary of arrays
         elif isinstance(dataset_id, dict):
-            output_keys = keys or list(dataset_id.keys())
-            assert len(output_keys) == len(dataset_id)
+            columns = columns or list(dataset_id.keys())
+            assert len(columns) == len(dataset_id)
             tensors = tuple(to_torch(array) for array in dataset_id.values())
 
-        # create torch dictionary dataset from tensors tuple and keys
-        dataset = DictDataset(TensorDataset(*tensors), output_keys)
+        # create torch dictionary dataset from tensors tuple and columns
+        dataset = DictDataset(TensorDataset(*tensors), columns)
         return dataset
 
     @staticmethod
     def load_custom_dataset(
-        dataset_id: Dataset, keys: Optional[list] = None
+        dataset_id: Dataset, columns: Optional[list] = None
     ) -> DictDataset:
         """Load a custom Dataset by ensuring it has the correct format (dict-based)
 
         Args:
             dataset_id (Dataset): Dataset
-            keys (list, optional): Keys to use for features if dataset_id is
-                tuple based. Defaults to None.
+            columns (list, optional): Column names to use for elements if dataset_id is
+                tuple based. If None, assigned as "input_i"
+                for i-th column. Defaults to None.
 
         Returns:
             DictDataset
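A short usage sketch of the column-naming defaults above (behavior as described in this hunk; the arrays are illustrative):

```python
import numpy as np

from oodeel.datasets.torch_data_handler import TorchDataHandler

x = np.random.rand(100, 3, 32, 32).astype("float32")
y = np.random.randint(0, 10, size=(100,))

# Tuple of arrays: with columns=None, a 2-tuple defaults to ["input", "label"].
ds = TorchDataHandler.load_dataset_from_arrays((x, y))

# Dict of arrays: the dict keys are reused as column names unless overridden.
ds_dict = TorchDataHandler.load_dataset_from_arrays({"image": x, "label": y})
print(ds_dict.columns)  # ['image', 'label']
```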
@@ -370,20 +372,17 @@ class TorchDataHandler(DataHandler):
         assert isinstance(
             dummy_item, (Tuple, torch.Tensor)
         ), "Custom dataset should be either dictionary based or tuple based"
-        output_keys = keys
-        if output_keys is None:
+        if columns is None:
             len_elem = len(dummy_item)
             if len_elem == 2:
-                output_keys = ["input", "label"]
+                columns = ["input", "label"]
             else:
-                output_keys = [f"input_{i}" for i in range(len_elem - 1)] + [
-                    "label"
-                ]
+                columns = [f"input_{i}" for i in range(len_elem - 1)] + ["label"]
             print(
                 "Feature name not found, assigning 'input_i' "
                 "key to the i-th tensor and 'label' key to the last"
             )
-        dataset_id = DictDataset(dataset_id, output_keys)
+        dataset_id = DictDataset(dataset_id, columns)
 
         dataset = dataset_id
         return dataset
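The same auto-naming applies to custom datasets; a sketch with a tuple-based torchvision dataset (MNIST here is only an example):

```python
import torchvision
from torchvision import transforms

from oodeel.datasets.torch_data_handler import TorchDataHandler

mnist = torchvision.datasets.MNIST(
    root="data", download=True, transform=transforms.ToTensor()
)
# 2-tuple items are wrapped into a DictDataset with columns ["input", "label"].
ds = TorchDataHandler.load_custom_dataset(mnist)
print(ds.columns)  # ['input', 'label']
```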
@@ -428,81 +427,18 @@ class TorchDataHandler(DataHandler):
         )
         return cls.load_custom_dataset(dataset)
 
-    @staticmethod
-    def assign_feature_value(
-        dataset: DictDataset, feature_key: str, value: int
-    ) -> DictDataset:
-        """Assign a value to a feature for every sample in a DictDataset
-
-        Args:
-            dataset (DictDataset): DictDataset to assign the value to
-            feature_key (str): Feature to assign the value to
-            value (int): Value to assign
-
-        Returns:
-            DictDataset
-        """
-        assert isinstance(
-            dataset, DictDataset
-        ), "Dataset must be an instance of DictDataset"
-
-        def assign_value_to_feature(x):
-            x[feature_key] = torch.tensor(value)
-            return x
-
-        dataset = dataset.map(assign_value_to_feature)
-        return dataset
-
     @staticmethod
     @dict_only_ds
-    def get_feature_from_ds(dataset: DictDataset, feature_key: str) -> np.ndarray:
-        """Get a feature from a DictDataset
-
-        !!! note
-            This function can be a bit time consuming since it needs to iterate
-            over the whole dataset.
+    def get_ds_column_names(dataset: DictDataset) -> list:
+        """Get the column names of a DictDataset
 
         Args:
-            dataset (DictDataset): Dataset to get the feature from
-            feature_key (str): Feature value to get
+            dataset (DictDataset): Dataset to get the column names from
 
         Returns:
-            np.ndarray: Feature values for dataset
+            list: List of column names
         """
-
-        features = dataset.map(lambda x: x[feature_key])
-        features = np.stack([f.numpy() for f in features])
-        return features
-
-    @staticmethod
-    @dict_only_ds
-    def get_ds_feature_keys(dataset: DictDataset) -> list:
-        """Get the feature keys of a DictDataset
-
-        Args:
-            dataset (DictDataset): Dataset to get the feature keys from
-
-        Returns:
-            list: List of feature keys
-        """
-        return dataset.output_keys
-
-    @staticmethod
-    def has_feature_key(dataset: DictDataset, key: str) -> bool:
-        """Check if a DictDataset has a feature denoted by key
-
-        Args:
-            dataset (DictDataset): Dataset to check
-            key (str): Key to check
-
-        Returns:
-            bool: If the dataset has a feature denoted by key
-        """
-        assert isinstance(
-            dataset, DictDataset
-        ), "Dataset must be an instance of DictDataset"
-
-        return key in dataset.output_keys
+        return dataset.columns
 
     @staticmethod
     def map_ds(
@@ -522,13 +458,13 @@ class TorchDataHandler(DataHandler):
 
     @staticmethod
     @dict_only_ds
-    def filter_by_feature_value(
+    def filter_by_value(
         dataset: DictDataset,
-        feature_key: str,
+        column_name: str,
         values: list,
         excluded: bool = False,
     ) -> DictDataset:
-        """Filter the dataset by checking the value of a feature is in `values`
+        """Filter the dataset by checking if the value of a column is in `values`
 
         !!! note
             This function can be a bit of time consuming since it needs to iterate
@@ -536,62 +472,64 @@ class TorchDataHandler(DataHandler):
 
         Args:
             dataset (DictDataset): Dataset to filter
-            feature_key (str): Feature name to check the value
-            values (list): Feature_key values to keep
+            column_name (str): Column to filter the dataset with
+            values (list): Column values to keep
             excluded (bool, optional): To keep (False) or exclude (True) the samples
-                with Feature_key value included in Values. Defaults to False.
+                with column values included in Values. Defaults to False.
 
         Returns:
             DictDataset: Filtered dataset
         """
 
-        if len(dataset[0][feature_key].shape) > 0:
-            value_dim = dataset[0][feature_key].shape[-1]
+        if len(dataset[0][column_name].shape) > 0:
+            value_dim = dataset[0][column_name].shape[-1]
             values = [
                 F.one_hot(torch.tensor(value).long(), value_dim) for value in values
             ]
 
         def filter_fn(x):
-            keep = any([torch.all(x[feature_key] == v) for v in values])
+            keep = any([torch.all(x[column_name] == v) for v in values])
             return keep if not excluded else not keep
 
         filtered_dataset = dataset.filter(filter_fn)
         return filtered_dataset
 
     @classmethod
-    def prepare_for_training(
+    def prepare(
         cls,
         dataset: DictDataset,
         batch_size: int,
-        shuffle: bool = False,
         preprocess_fn: Optional[Callable] = None,
         augment_fn: Optional[Callable] = None,
-        output_keys: Optional[list] = None,
-        dict_based_fns: bool = False,
-        shuffle_buffer_size: Optional[int] = None,
+        columns: Optional[list] = None,
+        shuffle: bool = False,
+        dict_based_fns: bool = True,
+        return_tuple: bool = True,
+        num_workers: int = 8,
     ) -> DataLoader:
         """Prepare a DataLoader for training
 
         Args:
             dataset (DictDataset): Dataset to prepare
             batch_size (int): Batch size
-            shuffle (bool): Wether to shuffle the dataloader or not
             preprocess_fn (Callable, optional): Preprocessing function to apply to
                 the dataset. Defaults to None.
             augment_fn (Callable, optional): Augment function to be used (when the
                 returned dataset is to be used for training). Defaults to None.
-            output_keys (list): List of keys corresponding to the features that will be
-                returned. Keep all features if None. Defaults to None.
+            columns (list, optional): List of column names corresponding to the columns
+                that will be returned. Keep all features if None. Defaults to None.
+            shuffle (bool, optional): To shuffle the returned dataset or not.
+                Defaults to False.
             dict_based_fns (bool): Whether to use preprocess and DA functions as dict
-                based (if True) or as tuple based (if False). Defaults to False.
-            shuffle_buffer_size (int, optional): Size of the shuffle buffer. Not used
-                in torch because we only rely on Map-Style datasets. Still as argument
-                for API consistency. Defaults to None.
+                based (if True) or as tuple based (if False). Defaults to True.
+            return_tuple (bool, optional): Whether to return each dataset item
+                as a tuple. Defaults to True.
+            num_workers (int, optional): Number of workers to use for the dataloader.
 
         Returns:
             DataLoader: dataloader
         """
-        output_keys = output_keys or cls.get_ds_feature_keys(dataset)
+        columns = columns or cls.get_ds_column_names(dataset)
 
         def collate_fn(batch: List[dict]):
             if dict_based_fns:
@@ -599,18 +537,20 @@ class TorchDataHandler(DataHandler):
                 preprocess_func = preprocess_fn or (lambda x: x)
                 augment_func = augment_fn or (lambda x: x)
                 batch = [augment_func(preprocess_func(d)) for d in batch]
-                # to tuple of batchs
-                return tuple(
-                    default_collate([d[key] for d in batch]) for key in output_keys
-                )
+                # to dict of batchs
+                if return_tuple:
+                    return tuple(
+                        default_collate([d[key] for d in batch]) for key in columns
+                    )
+                return {
+                    key: default_collate([d[key] for d in batch]) for key in columns
+                }
             else:
                 # preprocess + DA: List[dict] -> List[tuple]
                 preprocess_func = preprocess_fn or (lambda *x: x)
                 augment_func = augment_fn or (lambda *x: x)
                 batch = [
-                    augment_func(
-                        *preprocess_func(*tuple(d[key] for key in output_keys))
-                    )
+                    augment_func(*preprocess_func(*tuple(d[key] for key in columns)))
                     for d in batch
                 ]
                 # to tuple of batchs
@@ -621,6 +561,7 @@ class TorchDataHandler(DataHandler):
             batch_size=batch_size,
             shuffle=shuffle,
             collate_fn=collate_fn,
+            num_workers=num_workers,
         )
         return loader
 
@@ -697,17 +638,21 @@ class TorchDataHandler(DataHandler):
         return len(dataset)
 
     @staticmethod
-    def get_feature_shape(dataset: Dataset, feature_key: Union[str, int]) -> tuple:
-        """Get the shape of a feature of dataset identified by feature_key
+    def get_column_elements_shape(
+        dataset: Dataset, column_name: Union[str, int]
+    ) -> tuple:
+        """Get the shape of the elements of a column of dataset identified by
+        column_name
 
         Args:
             dataset (Dataset): a Dataset
-            feature_key (Union[str, int]): The identifier of the feature
+            column_name (Union[str, int]): The column name to get
+                the element shape from.
 
         Returns:
-            tuple: the shape of feature_id
+            tuple: the shape of an element from column_name
         """
-        return tuple(dataset[0][feature_key].shape)
+        return tuple(dataset[0][column_name].shape)
 
     @staticmethod
     def get_input_from_dataset_item(elem: ItemType) -> Any:
@@ -747,20 +692,3 @@ class TorchDataHandler(DataHandler):
         if len(label.shape) > 1:
             label = label.view([label.shape[0]])
         return label
-
-    @staticmethod
-    def get_feature(dataset: DictDataset, feature_key: Union[str, int]) -> DictDataset:
-        """Extract a feature from a dataset
-
-        Args:
-            dataset (tf.data.Dataset): Dataset to extract the feature from
-            feature_key (Union[str, int]): feature to extract
-
-        Returns:
-            tf.data.Dataset: dataset built with the extracted feature only
-        """
-
-        def _get_feature_item(item):
-            return item[feature_key]
-
-        return dataset.map(_get_feature_item)
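To round off the torch_data_handler changes, a sketch that chains the renamed filter_by_value and the reworked prepare (signatures as in the hunks above; ds continues the loading sketches, and the label split is illustrative):

```python
# Keep classes 0-4 as in-distribution, the rest as out-of-distribution.
in_ds = TorchDataHandler.filter_by_value(ds, "label", [0, 1, 2, 3, 4])
out_ds = TorchDataHandler.filter_by_value(ds, "label", [0, 1, 2, 3, 4], excluded=True)

# prepare() now returns tuple items by default and exposes num_workers.
loader = TorchDataHandler.prepare(
    in_ds,
    batch_size=64,
    shuffle=True,
    columns=["input", "label"],
    num_workers=2,
)
inputs, labels = next(iter(loader))
```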
oodeel/eval/metrics.py CHANGED
@@ -54,9 +54,11 @@ def bench_metrics(
             out_value if different from their default values.
             Defaults to None.
         in_value (Optional[int], optional): ood label value for in-distribution data.
+            Automatically assigned 0 if it is not the case.
             Defaults to 0.
         out_value (Optional[int], optional): ood label value for out-of-distribution
-            data. Defaults to 1.
+            data. Automatically assigned 1 if it is not the case.
+            Defaults to 1.
         metrics (Optional[List[str]], optional): list of metrics to compute. Can pass
             any metric name from sklearn.metric or among "detect_acc" and
             "<aaa><XX><bbb>" where <aaa> and <bbb> are in ["fpr", "tpr", "fnr", "tnr"]
@@ -89,7 +91,10 @@ def bench_metrics(
     for metric in metrics:
         if isinstance(metric, str):
             if metric == "auroc":
-                auroc = -np.trapz(1.0 - fpr, tpr)
+                if np.__version__ >= "2.0.0":
+                    auroc = -np.trapezoid(1.0 - fpr, tpr)
+                else:
+                    auroc = -np.trapz(1.0 - fpr, tpr)
                 metrics_dict["auroc"] = auroc
 
             elif metric == "detect_acc":
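The version guard above restores AUROC computation on NumPy 2.x, where np.trapz was removed. Note that comparing np.__version__ as a string is fragile (lexicographically, "10.0.0" < "2.0.0"); a sketch of an alternative that avoids version parsing altogether:

```python
import numpy as np

# np.trapezoid exists on NumPy >= 2.0, np.trapz on older releases; resolving
# the function once avoids string comparisons on np.__version__ entirely.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def auroc_from_roc(fpr: np.ndarray, tpr: np.ndarray) -> float:
    # Same integral as bench_metrics computes: area under the ROC curve.
    return float(-_trapezoid(1.0 - fpr, tpr))
```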
oodeel/eval/plots/features.py CHANGED
@@ -168,7 +168,7 @@ def _plot_features(
     # === extract id features ===
     # features
     in_features, _ = feature_extractor.predict(in_dataset)
-    in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
+    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
 
     # labels
     in_labels = []
@@ -181,7 +181,7 @@ def _plot_features(
     if out_dataset is not None:
         # features
         out_features, _ = feature_extractor.predict(out_dataset)
-        out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
+        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
 
         # labels
         out_labels_str = np.array(["unknown"] * len(out_features))
oodeel/eval/plots/plotly.py CHANGED
@@ -82,7 +82,7 @@ def plotly_3D_features(
     # === extract id features ===
     # features
     in_features, _ = feature_extractor.predict(in_dataset)
-    in_features = op.convert_to_numpy(op.flatten(in_features))[:max_samples]
+    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]
 
     # labels
     in_labels = []
@@ -95,7 +95,7 @@ def plotly_3D_features(
     if out_dataset is not None:
         # features
         out_features, _ = feature_extractor.predict(out_dataset)
-        out_features = op.convert_to_numpy(op.flatten(out_features))[:max_samples]
+        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]
 
         # labels
         out_labels = np.array(["unknown"] * len(out_features))
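The added [0] indexing suggests that in 0.3.0 feature_extractor.predict returns the watched features as a list (one entry per monitored layer) alongside the model outputs. A toy stand-in to make the indexing concrete (the predict below is hypothetical, not oodeel's implementation):

```python
import numpy as np


def predict(dataset):
    # Hypothetical stand-in: a list of per-layer feature maps plus the logits.
    n = len(dataset)
    return [np.random.rand(n, 64, 4, 4)], np.random.rand(n, 10)


features, _ = predict(np.zeros((16, 3, 32, 32)))
# Mirror op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]:
flat = features[0].reshape(features[0].shape[0], -1)[:8]
print(flat.shape)  # (8, 1024)
```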