oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of oodeel might be problematic. Click here for more details.

Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
@@ -22,11 +22,11 @@
22
22
  # SOFTWARE.
23
23
  import numpy as np
24
24
 
25
- from ..types import Callable
26
- from ..types import DatasetType
27
- from ..types import Optional
28
- from ..types import Tuple
29
- from ..types import Union
25
+ from ...types import Callable
26
+ from ...types import DatasetType
27
+ from ...types import Optional
28
+ from ...types import Tuple
29
+ from ...types import Union
30
30
 
31
31
 
32
32
  class OODDataset(object):
@@ -61,6 +61,7 @@ class OODDataset(object):
61
61
  load_from_tensorflow_datasets: bool = False,
62
62
  input_key: Optional[str] = None,
63
63
  ):
64
+
64
65
  self.backend = backend
65
66
  self.load_from_tensorflow_datasets = load_from_tensorflow_datasets
66
67
 
@@ -74,19 +75,19 @@ class OODDataset(object):
74
75
  # Set the channel order depending on the backend
75
76
  if self.backend == "torch":
76
77
  if load_from_tensorflow_datasets:
77
- from .tf_data_handler import TFDataHandler
78
+ from .DEPRECATED_tf_data_handler import TFDataHandler
78
79
  import tensorflow as tf
79
80
 
80
81
  tf.config.set_visible_devices([], "GPU")
81
82
  self._data_handler = TFDataHandler()
82
83
  load_kwargs["as_supervised"] = False
83
84
  else:
84
- from .torch_data_handler import TorchDataHandler
85
+ from .DEPRECATED_torch_data_handler import TorchDataHandler
85
86
 
86
87
  self._data_handler = TorchDataHandler()
87
88
  self.channel_order = "channels_first"
88
89
  else:
89
- from .tf_data_handler import TFDataHandler
90
+ from .DEPRECATED_tf_data_handler import TFDataHandler
90
91
 
91
92
  self._data_handler = TFDataHandler()
92
93
  self.channel_order = "channels_last"
@@ -265,7 +266,7 @@ class OODDataset(object):
265
266
  with_ood_labels: bool = False,
266
267
  with_labels: bool = True,
267
268
  shuffle: bool = False,
268
- shuffle_buffer_size: Optional[int] = None,
269
+ **kwargs_prepare,
269
270
  ) -> DatasetType:
270
271
  """Prepare self.data for scoring or training
271
272
 
@@ -282,9 +283,9 @@ class OODDataset(object):
282
283
  Defaults to True.
283
284
  shuffle (bool, optional): To shuffle the returned dataset or not.
284
285
  Defaults to False.
285
- shuffle_buffer_size (int, optional): (TF only) Size of the shuffle buffer.
286
- If None, taken as the number of samples in the dataset.
287
- Defaults to None.
286
+ kwargs_prepare (dict): Additional parameters to be passed to the
287
+ data_handler.prepare_for_training method.
288
+
288
289
 
289
290
  Returns:
290
291
  DatasetType: prepared dataset
@@ -323,7 +324,7 @@ class OODDataset(object):
323
324
  preprocess_fn=preprocess_fn,
324
325
  augment_fn=augment_fn,
325
326
  output_keys=keys,
326
- shuffle_buffer_size=shuffle_buffer_size,
327
+ **kwargs_prepare,
327
328
  )
328
329
 
329
330
  return dataset