oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of oodeel might be problematic. Click here for more details.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -22,11 +22,11 @@
|
|
|
22
22
|
# SOFTWARE.
|
|
23
23
|
import numpy as np
|
|
24
24
|
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
from
|
|
28
|
-
from
|
|
29
|
-
from
|
|
25
|
+
from ...types import Callable
|
|
26
|
+
from ...types import DatasetType
|
|
27
|
+
from ...types import Optional
|
|
28
|
+
from ...types import Tuple
|
|
29
|
+
from ...types import Union
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class OODDataset(object):
|
|
@@ -61,6 +61,7 @@ class OODDataset(object):
|
|
|
61
61
|
load_from_tensorflow_datasets: bool = False,
|
|
62
62
|
input_key: Optional[str] = None,
|
|
63
63
|
):
|
|
64
|
+
|
|
64
65
|
self.backend = backend
|
|
65
66
|
self.load_from_tensorflow_datasets = load_from_tensorflow_datasets
|
|
66
67
|
|
|
@@ -74,19 +75,19 @@ class OODDataset(object):
|
|
|
74
75
|
# Set the channel order depending on the backend
|
|
75
76
|
if self.backend == "torch":
|
|
76
77
|
if load_from_tensorflow_datasets:
|
|
77
|
-
from .
|
|
78
|
+
from .DEPRECATED_tf_data_handler import TFDataHandler
|
|
78
79
|
import tensorflow as tf
|
|
79
80
|
|
|
80
81
|
tf.config.set_visible_devices([], "GPU")
|
|
81
82
|
self._data_handler = TFDataHandler()
|
|
82
83
|
load_kwargs["as_supervised"] = False
|
|
83
84
|
else:
|
|
84
|
-
from .
|
|
85
|
+
from .DEPRECATED_torch_data_handler import TorchDataHandler
|
|
85
86
|
|
|
86
87
|
self._data_handler = TorchDataHandler()
|
|
87
88
|
self.channel_order = "channels_first"
|
|
88
89
|
else:
|
|
89
|
-
from .
|
|
90
|
+
from .DEPRECATED_tf_data_handler import TFDataHandler
|
|
90
91
|
|
|
91
92
|
self._data_handler = TFDataHandler()
|
|
92
93
|
self.channel_order = "channels_last"
|
|
@@ -265,7 +266,7 @@ class OODDataset(object):
|
|
|
265
266
|
with_ood_labels: bool = False,
|
|
266
267
|
with_labels: bool = True,
|
|
267
268
|
shuffle: bool = False,
|
|
268
|
-
|
|
269
|
+
**kwargs_prepare,
|
|
269
270
|
) -> DatasetType:
|
|
270
271
|
"""Prepare self.data for scoring or training
|
|
271
272
|
|
|
@@ -282,9 +283,9 @@ class OODDataset(object):
|
|
|
282
283
|
Defaults to True.
|
|
283
284
|
shuffle (bool, optional): To shuffle the returned dataset or not.
|
|
284
285
|
Defaults to False.
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
286
|
+
kwargs_prepare (dict): Additional parameters to be passed to the
|
|
287
|
+
data_handler.prepare_for_training method.
|
|
288
|
+
|
|
288
289
|
|
|
289
290
|
Returns:
|
|
290
291
|
DatasetType: prepared dataset
|
|
@@ -323,7 +324,7 @@ class OODDataset(object):
|
|
|
323
324
|
preprocess_fn=preprocess_fn,
|
|
324
325
|
augment_fn=augment_fn,
|
|
325
326
|
output_keys=keys,
|
|
326
|
-
|
|
327
|
+
**kwargs_prepare,
|
|
327
328
|
)
|
|
328
329
|
|
|
329
330
|
return dataset
|