oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of oodeel might be problematic. Click here for more details.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -35,6 +35,27 @@ from ..types import Optional
|
|
|
35
35
|
from ..types import Union
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
def get_toy_mlp(input_shape: tuple, num_classes: int) -> tf.keras.Model:
|
|
39
|
+
"""Basic keras MLP classifier for toy datasets.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
input_shape (tuple): Input data shape.
|
|
43
|
+
num_classes (int): Number of classes for the classification task.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
tf.keras.Model: model
|
|
47
|
+
"""
|
|
48
|
+
return tf.keras.models.Sequential(
|
|
49
|
+
[
|
|
50
|
+
tf.keras.layers.Input(shape=input_shape),
|
|
51
|
+
tf.keras.layers.Flatten(),
|
|
52
|
+
tf.keras.layers.Dense(300, activation="relu"),
|
|
53
|
+
tf.keras.layers.Dense(150, activation="relu"),
|
|
54
|
+
tf.keras.layers.Dense(num_classes, activation="softmax"),
|
|
55
|
+
]
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
|
|
38
59
|
def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
|
|
39
60
|
"""Basic keras convolutional classifier for toy datasets.
|
|
40
61
|
|
|
@@ -60,8 +81,8 @@ def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
|
|
|
60
81
|
def train_tf_model(
|
|
61
82
|
train_data: tf.data.Dataset,
|
|
62
83
|
model: Union[tf.keras.Model, str],
|
|
63
|
-
input_shape: tuple,
|
|
64
|
-
num_classes: int,
|
|
84
|
+
input_shape: tuple = None,
|
|
85
|
+
num_classes: int = None,
|
|
65
86
|
batch_size: int = 128,
|
|
66
87
|
epochs: int = 50,
|
|
67
88
|
loss: str = "sparse_categorical_crossentropy",
|
|
@@ -80,7 +101,7 @@ def train_tf_model(
|
|
|
80
101
|
Args:
|
|
81
102
|
train_data (tf.data.Dataset): training dataset.
|
|
82
103
|
model (Union[tf.keras.Model, str]): if a string is provided, must be a model
|
|
83
|
-
from tf.keras.applications or "toy_convnet"
|
|
104
|
+
from tf.keras.applications or "toy_convnet" or "toy_mlp"
|
|
84
105
|
input_shape (tuple): Shape of the input images.
|
|
85
106
|
num_classes (int): Number of output classes.
|
|
86
107
|
batch_size (int, optional): Defaults to 128.
|
|
@@ -120,6 +141,8 @@ def train_tf_model(
|
|
|
120
141
|
elif isinstance(model, str):
|
|
121
142
|
if model == "toy_convnet":
|
|
122
143
|
model = get_toy_keras_convnet(num_classes)
|
|
144
|
+
elif model == "toy_mlp":
|
|
145
|
+
model = get_toy_mlp(input_shape, num_classes)
|
|
123
146
|
else:
|
|
124
147
|
weights = "imagenet" if imagenet_pretrained else None
|
|
125
148
|
backbone = getattr(tf.keras.applications, model)(
|
oodeel/utils/torch_operator.py
CHANGED
|
@@ -69,12 +69,24 @@ class TorchOperator(Operator):
|
|
|
69
69
|
return torch.argmax(tensor, dim=dim)
|
|
70
70
|
|
|
71
71
|
@staticmethod
|
|
72
|
-
def max(
|
|
72
|
+
def max(
|
|
73
|
+
tensor: TensorType, dim: Optional[int] = None, keepdim: Optional[bool] = False
|
|
74
|
+
) -> torch.Tensor:
|
|
73
75
|
"""Max function"""
|
|
74
76
|
if dim is None:
|
|
75
77
|
return torch.max(tensor)
|
|
76
78
|
else:
|
|
77
|
-
return torch.max(tensor, dim)[0]
|
|
79
|
+
return torch.max(tensor, dim, keepdim=keepdim)[0]
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def min(
|
|
83
|
+
tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
|
|
84
|
+
) -> torch.Tensor:
|
|
85
|
+
"""Min function"""
|
|
86
|
+
if dim is None:
|
|
87
|
+
return torch.min(tensor)
|
|
88
|
+
else:
|
|
89
|
+
return torch.min(tensor, dim, keepdim=keepdim)[0]
|
|
78
90
|
|
|
79
91
|
@staticmethod
|
|
80
92
|
def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
|
|
@@ -167,13 +179,18 @@ class TorchOperator(Operator):
|
|
|
167
179
|
def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
|
|
168
180
|
"Convert a NumPy array to a tensor"
|
|
169
181
|
# TODO change dtype
|
|
170
|
-
return torch.
|
|
182
|
+
return torch.tensor(arr).to(self._device)
|
|
171
183
|
|
|
172
184
|
@staticmethod
|
|
173
|
-
def
|
|
185
|
+
def t(tensor: TensorType) -> torch.Tensor:
|
|
174
186
|
"Transpose function for tensor of rank 2"
|
|
175
187
|
return tensor.t()
|
|
176
188
|
|
|
189
|
+
@staticmethod
|
|
190
|
+
def permute(tensor: TensorType, dims) -> torch.Tensor:
|
|
191
|
+
"Transpose function for tensor of rank 2"
|
|
192
|
+
return torch.permute(tensor, dims)
|
|
193
|
+
|
|
177
194
|
@staticmethod
|
|
178
195
|
def diag(tensor: TensorType) -> torch.Tensor:
|
|
179
196
|
"Diagonal function: return the diagonal of a 2D tensor"
|
|
@@ -215,3 +232,57 @@ class TorchOperator(Operator):
|
|
|
215
232
|
def relu(tensor: TensorType) -> torch.Tensor:
|
|
216
233
|
"Apply relu to a tensor"
|
|
217
234
|
return torch.nn.functional.relu(tensor)
|
|
235
|
+
|
|
236
|
+
@staticmethod
|
|
237
|
+
def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
|
|
238
|
+
"Computes the einsum between tensors following equation"
|
|
239
|
+
return torch.einsum(equation, *tensors)
|
|
240
|
+
|
|
241
|
+
@staticmethod
|
|
242
|
+
def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
|
|
243
|
+
"Set the upper triangle of the matrix formed by the last two dimensions of"
|
|
244
|
+
"tensor to zero"
|
|
245
|
+
return torch.tril(tensor, diagonal)
|
|
246
|
+
|
|
247
|
+
@staticmethod
|
|
248
|
+
def sum(tensor: TensorType, dim: Union[tuple, list, int] = None) -> torch.Tensor:
|
|
249
|
+
"sum along dim"
|
|
250
|
+
return torch.sum(tensor, dim)
|
|
251
|
+
|
|
252
|
+
@staticmethod
|
|
253
|
+
def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
|
|
254
|
+
"unsqueeze along dim"
|
|
255
|
+
return torch.unsqueeze(tensor, dim)
|
|
256
|
+
|
|
257
|
+
@staticmethod
|
|
258
|
+
def squeeze(tensor: TensorType, dim: int = None) -> torch.Tensor:
|
|
259
|
+
"squeeze along dim"
|
|
260
|
+
|
|
261
|
+
if dim is None:
|
|
262
|
+
return torch.squeeze(tensor)
|
|
263
|
+
|
|
264
|
+
return torch.squeeze(tensor, dim)
|
|
265
|
+
|
|
266
|
+
@staticmethod
|
|
267
|
+
def abs(tensor: TensorType) -> torch.Tensor:
|
|
268
|
+
"compute absolute value"
|
|
269
|
+
return torch.abs(tensor)
|
|
270
|
+
|
|
271
|
+
@staticmethod
|
|
272
|
+
def where(
|
|
273
|
+
condition: TensorType,
|
|
274
|
+
input: Union[TensorType, float],
|
|
275
|
+
other: Union[TensorType, float],
|
|
276
|
+
) -> torch.Tensor:
|
|
277
|
+
"Applies where function , to condition"
|
|
278
|
+
return torch.where(condition, input, other)
|
|
279
|
+
|
|
280
|
+
@staticmethod
|
|
281
|
+
def avg_pool_2d(tensor: TensorType) -> torch.Tensor:
|
|
282
|
+
"""Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
|
|
283
|
+
return torch.mean(tensor, dim=(-2, -1))
|
|
284
|
+
|
|
285
|
+
@staticmethod
|
|
286
|
+
def log(tensor: TensorType) -> torch.Tensor:
|
|
287
|
+
"""Perform log"""
|
|
288
|
+
return torch.log(tensor)
|
|
@@ -35,6 +35,31 @@ from ..types import Optional
|
|
|
35
35
|
from ..types import Union
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
class ToyTorchMLP(nn.Sequential):
|
|
39
|
+
"""Basic torch MLP classifier for toy datasets.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
input_shape (tuple): Input data shape.
|
|
43
|
+
num_classes (int): Number of classes for the classification task.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, input_shape: tuple, num_classes: int):
|
|
47
|
+
self.input_shape = input_shape
|
|
48
|
+
|
|
49
|
+
# build toy mlp
|
|
50
|
+
mlp_modules = OrderedDict(
|
|
51
|
+
[
|
|
52
|
+
("flatten", nn.Flatten()),
|
|
53
|
+
("dense1", nn.Linear(np.prod(input_shape), 300)),
|
|
54
|
+
("relu1", nn.ReLU()),
|
|
55
|
+
("dense2", nn.Linear(300, 150)),
|
|
56
|
+
("relu2", nn.ReLU()),
|
|
57
|
+
("fc1", nn.Linear(150, num_classes)),
|
|
58
|
+
]
|
|
59
|
+
)
|
|
60
|
+
super().__init__(mlp_modules)
|
|
61
|
+
|
|
62
|
+
|
|
38
63
|
class ToyTorchConvnet(nn.Sequential):
|
|
39
64
|
"""Basic torch convolutional classifier for toy datasets.
|
|
40
65
|
|
|
@@ -106,7 +131,7 @@ def train_torch_model(
|
|
|
106
131
|
Args:
|
|
107
132
|
train_data (DataLoader): train dataloader
|
|
108
133
|
model (Union[nn.Module, str]): if a string is provided, must be a model from
|
|
109
|
-
torchvision.models or "toy_convnet".
|
|
134
|
+
torchvision.models or "toy_convnet" or "toy_mlp".
|
|
110
135
|
num_classes (int): Number of output classes.
|
|
111
136
|
epochs (int, optional): Defaults to 50.
|
|
112
137
|
loss (str, optional): Defaults to "CrossEntropyLoss".
|
|
@@ -136,8 +161,12 @@ def train_torch_model(
|
|
|
136
161
|
elif isinstance(model, str):
|
|
137
162
|
if model == "toy_convnet":
|
|
138
163
|
# toy model
|
|
139
|
-
input_shape = next(iter(train_data))[0].shape[1:]
|
|
164
|
+
input_shape = tuple(next(iter(train_data))[0].shape[1:])
|
|
140
165
|
model = ToyTorchConvnet(input_shape, num_classes).to(device)
|
|
166
|
+
elif model == "toy_mlp":
|
|
167
|
+
# toy model
|
|
168
|
+
input_shape = tuple(next(iter(train_data))[0].shape[1:])
|
|
169
|
+
model = ToyTorchMLP(input_shape, num_classes).to(device)
|
|
141
170
|
else:
|
|
142
171
|
# torchvision model
|
|
143
172
|
model = getattr(torchvision.models, model)(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: oodeel
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.
|
|
5
5
|
Author: DEEL Core Team
|
|
6
6
|
Author-email: paul.novello@irt-saintexupery.com
|
|
@@ -12,104 +12,126 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
|
-
Requires-Dist:
|
|
15
|
+
Requires-Dist: faiss_cpu
|
|
16
16
|
Requires-Dist: numpy
|
|
17
|
-
Requires-Dist:
|
|
17
|
+
Requires-Dist: scikit_learn
|
|
18
18
|
Requires-Dist: scipy
|
|
19
19
|
Requires-Dist: setuptools
|
|
20
20
|
Requires-Dist: matplotlib
|
|
21
21
|
Requires-Dist: pandas
|
|
22
22
|
Requires-Dist: seaborn
|
|
23
23
|
Requires-Dist: plotly
|
|
24
|
+
Requires-Dist: tqdm
|
|
24
25
|
Provides-Extra: dev
|
|
25
|
-
Requires-Dist: mypy
|
|
26
|
-
Requires-Dist: ipywidgets
|
|
27
|
-
Requires-Dist: mkdocs-jupyter
|
|
28
|
-
Requires-Dist: mkdocstrings-python
|
|
29
|
-
Requires-Dist: flake8
|
|
30
|
-
Requires-Dist: setuptools
|
|
31
|
-
Requires-Dist: pre-commit
|
|
32
|
-
Requires-Dist: tox
|
|
33
|
-
Requires-Dist: black
|
|
34
|
-
Requires-Dist:
|
|
35
|
-
Requires-Dist:
|
|
36
|
-
Requires-Dist:
|
|
37
|
-
Requires-Dist:
|
|
38
|
-
Requires-Dist:
|
|
39
|
-
Requires-Dist:
|
|
40
|
-
Requires-Dist:
|
|
41
|
-
Requires-Dist:
|
|
42
|
-
Requires-Dist:
|
|
43
|
-
Requires-Dist:
|
|
44
|
-
Requires-Dist:
|
|
45
|
-
Requires-Dist:
|
|
46
|
-
Requires-Dist:
|
|
47
|
-
Requires-Dist:
|
|
48
|
-
Requires-Dist:
|
|
49
|
-
Requires-Dist:
|
|
50
|
-
Requires-Dist:
|
|
51
|
-
|
|
52
|
-
Requires-Dist:
|
|
53
|
-
Requires-Dist:
|
|
54
|
-
Requires-Dist:
|
|
55
|
-
Requires-Dist: mknotebooks ; extra == 'docs'
|
|
56
|
-
Requires-Dist: ipython ; extra == 'docs'
|
|
57
|
-
Provides-Extra: tensorflow
|
|
58
|
-
Requires-Dist: tensorflow ; extra == 'tensorflow'
|
|
59
|
-
Requires-Dist: tensorflow-datasets ; extra == 'tensorflow'
|
|
60
|
-
Requires-Dist: tensorflow-probability ; extra == 'tensorflow'
|
|
26
|
+
Requires-Dist: mypy; extra == "dev"
|
|
27
|
+
Requires-Dist: ipywidgets; extra == "dev"
|
|
28
|
+
Requires-Dist: mkdocs-jupyter; extra == "dev"
|
|
29
|
+
Requires-Dist: mkdocstrings-python; extra == "dev"
|
|
30
|
+
Requires-Dist: flake8; extra == "dev"
|
|
31
|
+
Requires-Dist: setuptools; extra == "dev"
|
|
32
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
33
|
+
Requires-Dist: tox; extra == "dev"
|
|
34
|
+
Requires-Dist: black; extra == "dev"
|
|
35
|
+
Requires-Dist: ruff; extra == "dev"
|
|
36
|
+
Requires-Dist: ipython; extra == "dev"
|
|
37
|
+
Requires-Dist: ipykernel; extra == "dev"
|
|
38
|
+
Requires-Dist: pytest; extra == "dev"
|
|
39
|
+
Requires-Dist: pylint; extra == "dev"
|
|
40
|
+
Requires-Dist: mypy; extra == "dev"
|
|
41
|
+
Requires-Dist: mkdocs; extra == "dev"
|
|
42
|
+
Requires-Dist: mkdocs-material; extra == "dev"
|
|
43
|
+
Requires-Dist: mkdocstrings; extra == "dev"
|
|
44
|
+
Requires-Dist: mknotebooks; extra == "dev"
|
|
45
|
+
Requires-Dist: mike; extra == "dev"
|
|
46
|
+
Requires-Dist: bump2version; extra == "dev"
|
|
47
|
+
Requires-Dist: docsig; extra == "dev"
|
|
48
|
+
Requires-Dist: no_implicit_optional; extra == "dev"
|
|
49
|
+
Requires-Dist: numpy==1.26.4; extra == "dev"
|
|
50
|
+
Requires-Dist: tensorflow==2.11.0; extra == "dev"
|
|
51
|
+
Requires-Dist: tensorflow_datasets; extra == "dev"
|
|
52
|
+
Requires-Dist: tensorflow_probability==0.19.0; extra == "dev"
|
|
53
|
+
Requires-Dist: timm; extra == "dev"
|
|
54
|
+
Requires-Dist: torch==1.13.1; extra == "dev"
|
|
55
|
+
Requires-Dist: torchvision==0.14.1; extra == "dev"
|
|
61
56
|
Provides-Extra: tensorflow-dev
|
|
62
|
-
Requires-Dist: mypy
|
|
63
|
-
Requires-Dist: ipywidgets
|
|
64
|
-
Requires-Dist: mkdocs-jupyter
|
|
65
|
-
Requires-Dist: mkdocstrings-python
|
|
66
|
-
Requires-Dist: flake8
|
|
67
|
-
Requires-Dist: setuptools
|
|
68
|
-
Requires-Dist: pre-commit
|
|
69
|
-
Requires-Dist: tox
|
|
70
|
-
Requires-Dist: black
|
|
71
|
-
Requires-Dist:
|
|
72
|
-
Requires-Dist:
|
|
73
|
-
Requires-Dist:
|
|
74
|
-
Requires-Dist:
|
|
75
|
-
Requires-Dist:
|
|
76
|
-
Requires-Dist:
|
|
77
|
-
Requires-Dist:
|
|
78
|
-
Requires-Dist:
|
|
79
|
-
Requires-Dist:
|
|
80
|
-
Requires-Dist:
|
|
81
|
-
Requires-Dist:
|
|
82
|
-
Requires-Dist:
|
|
83
|
-
Requires-Dist:
|
|
84
|
-
Requires-Dist:
|
|
85
|
-
|
|
86
|
-
Requires-Dist:
|
|
87
|
-
Requires-Dist:
|
|
88
|
-
Requires-Dist:
|
|
57
|
+
Requires-Dist: mypy; extra == "tensorflow-dev"
|
|
58
|
+
Requires-Dist: ipywidgets; extra == "tensorflow-dev"
|
|
59
|
+
Requires-Dist: mkdocs-jupyter; extra == "tensorflow-dev"
|
|
60
|
+
Requires-Dist: mkdocstrings-python; extra == "tensorflow-dev"
|
|
61
|
+
Requires-Dist: flake8; extra == "tensorflow-dev"
|
|
62
|
+
Requires-Dist: setuptools; extra == "tensorflow-dev"
|
|
63
|
+
Requires-Dist: pre-commit; extra == "tensorflow-dev"
|
|
64
|
+
Requires-Dist: tox; extra == "tensorflow-dev"
|
|
65
|
+
Requires-Dist: black; extra == "tensorflow-dev"
|
|
66
|
+
Requires-Dist: ruff; extra == "tensorflow-dev"
|
|
67
|
+
Requires-Dist: ipython; extra == "tensorflow-dev"
|
|
68
|
+
Requires-Dist: ipykernel; extra == "tensorflow-dev"
|
|
69
|
+
Requires-Dist: pytest; extra == "tensorflow-dev"
|
|
70
|
+
Requires-Dist: pylint; extra == "tensorflow-dev"
|
|
71
|
+
Requires-Dist: mypy; extra == "tensorflow-dev"
|
|
72
|
+
Requires-Dist: mkdocs; extra == "tensorflow-dev"
|
|
73
|
+
Requires-Dist: mkdocs-material; extra == "tensorflow-dev"
|
|
74
|
+
Requires-Dist: mkdocstrings; extra == "tensorflow-dev"
|
|
75
|
+
Requires-Dist: mknotebooks; extra == "tensorflow-dev"
|
|
76
|
+
Requires-Dist: mike; extra == "tensorflow-dev"
|
|
77
|
+
Requires-Dist: bump2version; extra == "tensorflow-dev"
|
|
78
|
+
Requires-Dist: docsig; extra == "tensorflow-dev"
|
|
79
|
+
Requires-Dist: no_implicit_optional; extra == "tensorflow-dev"
|
|
80
|
+
Requires-Dist: numpy==1.26.4; extra == "tensorflow-dev"
|
|
81
|
+
Requires-Dist: tensorflow==2.11.0; extra == "tensorflow-dev"
|
|
82
|
+
Requires-Dist: tensorflow_datasets; extra == "tensorflow-dev"
|
|
83
|
+
Requires-Dist: tensorflow_probability==0.19.0; extra == "tensorflow-dev"
|
|
89
84
|
Provides-Extra: torch-dev
|
|
90
|
-
Requires-Dist: mypy
|
|
91
|
-
Requires-Dist: ipywidgets
|
|
92
|
-
Requires-Dist: mkdocs-jupyter
|
|
93
|
-
Requires-Dist: mkdocstrings-python
|
|
94
|
-
Requires-Dist: flake8
|
|
95
|
-
Requires-Dist: setuptools
|
|
96
|
-
Requires-Dist: pre-commit
|
|
97
|
-
Requires-Dist: tox
|
|
98
|
-
Requires-Dist: black
|
|
99
|
-
Requires-Dist:
|
|
100
|
-
Requires-Dist:
|
|
101
|
-
Requires-Dist:
|
|
102
|
-
Requires-Dist:
|
|
103
|
-
Requires-Dist:
|
|
104
|
-
Requires-Dist:
|
|
105
|
-
Requires-Dist:
|
|
106
|
-
Requires-Dist:
|
|
107
|
-
Requires-Dist:
|
|
108
|
-
Requires-Dist:
|
|
109
|
-
Requires-Dist:
|
|
110
|
-
Requires-Dist:
|
|
111
|
-
Requires-Dist:
|
|
112
|
-
Requires-Dist:
|
|
85
|
+
Requires-Dist: mypy; extra == "torch-dev"
|
|
86
|
+
Requires-Dist: ipywidgets; extra == "torch-dev"
|
|
87
|
+
Requires-Dist: mkdocs-jupyter; extra == "torch-dev"
|
|
88
|
+
Requires-Dist: mkdocstrings-python; extra == "torch-dev"
|
|
89
|
+
Requires-Dist: flake8; extra == "torch-dev"
|
|
90
|
+
Requires-Dist: setuptools; extra == "torch-dev"
|
|
91
|
+
Requires-Dist: pre-commit; extra == "torch-dev"
|
|
92
|
+
Requires-Dist: tox; extra == "torch-dev"
|
|
93
|
+
Requires-Dist: black; extra == "torch-dev"
|
|
94
|
+
Requires-Dist: ruff; extra == "torch-dev"
|
|
95
|
+
Requires-Dist: ipython; extra == "torch-dev"
|
|
96
|
+
Requires-Dist: ipykernel; extra == "torch-dev"
|
|
97
|
+
Requires-Dist: pytest; extra == "torch-dev"
|
|
98
|
+
Requires-Dist: pylint; extra == "torch-dev"
|
|
99
|
+
Requires-Dist: mypy; extra == "torch-dev"
|
|
100
|
+
Requires-Dist: mkdocs; extra == "torch-dev"
|
|
101
|
+
Requires-Dist: mkdocs-material; extra == "torch-dev"
|
|
102
|
+
Requires-Dist: mkdocstrings; extra == "torch-dev"
|
|
103
|
+
Requires-Dist: mknotebooks; extra == "torch-dev"
|
|
104
|
+
Requires-Dist: mike; extra == "torch-dev"
|
|
105
|
+
Requires-Dist: bump2version; extra == "torch-dev"
|
|
106
|
+
Requires-Dist: docsig; extra == "torch-dev"
|
|
107
|
+
Requires-Dist: no_implicit_optional; extra == "torch-dev"
|
|
108
|
+
Requires-Dist: numpy==1.26.4; extra == "torch-dev"
|
|
109
|
+
Requires-Dist: timm; extra == "torch-dev"
|
|
110
|
+
Requires-Dist: torch==1.13.1; extra == "torch-dev"
|
|
111
|
+
Requires-Dist: torchvision==0.14.1; extra == "torch-dev"
|
|
112
|
+
Provides-Extra: tensorflow
|
|
113
|
+
Requires-Dist: tensorflow==2.11.0; extra == "tensorflow"
|
|
114
|
+
Requires-Dist: tensorflow_datasets; extra == "tensorflow"
|
|
115
|
+
Requires-Dist: tensorflow_probability==0.19.0; extra == "tensorflow"
|
|
116
|
+
Provides-Extra: torch
|
|
117
|
+
Requires-Dist: timm; extra == "torch"
|
|
118
|
+
Requires-Dist: torch==1.13.1; extra == "torch"
|
|
119
|
+
Requires-Dist: torchvision==0.14.1; extra == "torch"
|
|
120
|
+
Provides-Extra: docs
|
|
121
|
+
Requires-Dist: mkdocs; extra == "docs"
|
|
122
|
+
Requires-Dist: mkdocs-material; extra == "docs"
|
|
123
|
+
Requires-Dist: mkdocstrings; extra == "docs"
|
|
124
|
+
Requires-Dist: mknotebooks; extra == "docs"
|
|
125
|
+
Requires-Dist: ipython; extra == "docs"
|
|
126
|
+
Dynamic: author
|
|
127
|
+
Dynamic: author-email
|
|
128
|
+
Dynamic: classifier
|
|
129
|
+
Dynamic: description
|
|
130
|
+
Dynamic: description-content-type
|
|
131
|
+
Dynamic: license-file
|
|
132
|
+
Dynamic: provides-extra
|
|
133
|
+
Dynamic: requires-dist
|
|
134
|
+
Dynamic: summary
|
|
113
135
|
|
|
114
136
|
|
|
115
137
|
<!-- Banner section -->
|
|
@@ -125,29 +147,23 @@ Requires-Dist: torchvision ; extra == 'torch-dev'
|
|
|
125
147
|
<!-- Badge section -->
|
|
126
148
|
<div align="center">
|
|
127
149
|
<a href="#">
|
|
128
|
-
<img src="https://img.shields.io/badge/python-3.8%2B-blue">
|
|
129
|
-
</a>
|
|
150
|
+
<img src="https://img.shields.io/badge/python-3.8%2B-blue"></a>
|
|
130
151
|
<a href="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml">
|
|
131
|
-
<img alt="Flake8" src="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml/badge.svg">
|
|
132
|
-
</a>
|
|
152
|
+
<img alt="Flake8" src="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml/badge.svg"></a>
|
|
133
153
|
<a href="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml">
|
|
134
|
-
<img alt="Tests tf" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml/badge.svg">
|
|
135
|
-
</a>
|
|
154
|
+
<img alt="Tests tf" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml/badge.svg"></a>
|
|
136
155
|
<a href="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml">
|
|
137
|
-
<img alt="Tests torch" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml/badge.svg">
|
|
138
|
-
</a>
|
|
156
|
+
<img alt="Tests torch" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml/badge.svg"></a>
|
|
139
157
|
<a href="https://github.com/deel-ai/oodeel/actions/workflows/python-coverage-shield.yml">
|
|
140
|
-
<img alt="Coverage" src="https://github.com/deel-ai/oodeel/raw/gh-shields/coverage.svg">
|
|
141
|
-
</a>
|
|
158
|
+
<img alt="Coverage" src="https://github.com/deel-ai/oodeel/raw/gh-shields/coverage.svg"></a>
|
|
142
159
|
<a href="https://github.com/deel-ai/oodeel/blob/master/LICENSE">
|
|
143
|
-
<img alt="License MIT" src="https://img.shields.io/badge/License-MIT-efefef">
|
|
144
|
-
</a>
|
|
160
|
+
<img alt="License MIT" src="https://img.shields.io/badge/License-MIT-efefef"></a>
|
|
145
161
|
</div>
|
|
146
162
|
<br>
|
|
147
163
|
|
|
148
164
|
<!-- Short description of your library -->
|
|
149
165
|
|
|
150
|
-
<b>Oodeel</b> is a library that performs post-hoc deep OOD detection on already trained neural network image classifiers. The philosophy of the library is to favor quality over quantity and to foster easy adoption. As a result, we provide a simple, compact and easily customizable API and carefully integrate and test each proposed baseline into a coherent framework that is designed to enable their use in tensorflow **and** pytorch. You can find the documentation [here](https://deel-ai.github.io/oodeel/).
|
|
166
|
+
<b>Oodeel</b> is a library that performs post-hoc deep OOD (Out-of-Distribution) detection on already trained neural network image classifiers. The philosophy of the library is to favor quality over quantity and to foster easy adoption. As a result, we provide a simple, compact and easily customizable API and carefully integrate and test each proposed baseline into a coherent framework that is designed to enable their use in tensorflow **and** pytorch. You can find the documentation [here](https://deel-ai.github.io/oodeel/).
|
|
151
167
|
|
|
152
168
|
```python
|
|
153
169
|
from oodeel.methods import MLS
|
|
@@ -167,7 +183,8 @@ scores, info = mls.score(ds) # ds is a tf.data.Dataset or a torch.DataLoader
|
|
|
167
183
|
- [Contributing](#contributing)
|
|
168
184
|
- [See Also](#see-also)
|
|
169
185
|
- [Acknowledgments](#acknowledgments)
|
|
170
|
-
- [
|
|
186
|
+
- [Creators](#creators)
|
|
187
|
+
- [Citation](#citation)
|
|
171
188
|
- [License](#license)
|
|
172
189
|
|
|
173
190
|
# Installation
|
|
@@ -286,15 +303,19 @@ Currently, **oodeel** includes the following baselines:
|
|
|
286
303
|
| MSP | [A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks](http://arxiv.org/abs/1610.02136) | ICLR 2017 | avail [tensorflow & torch](docs/pages/getting_started.ipynb)|
|
|
287
304
|
| Mahalanobis | [A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks](http://arxiv.org/abs/1807.03888) | NeurIPS 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_mahalanobis_tf.ipynb) or [torch](docs/notebooks/torch/demo_mahalanobis_torch.ipynb)|
|
|
288
305
|
| Energy | [Energy-based Out-of-distribution Detection](http://arxiv.org/abs/2010.03759) | NeurIPS 2020 |avail [tensorflow](docs/notebooks/tensorflow/demo_energy_tf.ipynb) or [torch](docs/notebooks/torch/demo_energy_torch.ipynb) |
|
|
289
|
-
| Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 |avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
|
|
306
|
+
| Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
|
|
290
307
|
| DKNN | [Out-of-Distribution Detection with Deep Nearest Neighbors](http://arxiv.org/abs/2204.06507) | ICML 2022 | avail [tensorflow](docs/notebooks/tensorflow/demo_dknn_tf.ipynb) or [torch](docs/notebooks/torch/demo_dknn_torch.ipynb) |
|
|
291
308
|
| VIM | [ViM: Out-Of-Distribution with Virtual-logit Matching](http://arxiv.org/abs/2203.10807) | CVPR 2022 |avail [tensorflow](docs/notebooks/tensorflow/demo_vim_tf.ipynb) or [torch](docs/notebooks/torch/demo_vim_torch.ipynb) |
|
|
292
309
|
| Entropy | [Likelihood Ratios for Out-of-Distribution Detection](https://proceedings.neurips.cc/paper/2019/hash/1e79596878b2320cac26dd792a6c51c9-Abstract.html) | NeurIPS 2019 |avail [tensorflow](docs/notebooks/tensorflow/demo_entropy_tf.ipynb) or [torch](docs/notebooks/torch/demo_entropy_torch.ipynb) |
|
|
293
310
|
| GODIN | [Generalized ODIN: Detecting Out-of-Distribution Image Without Learning From Out-of-Distribution Data](https://ieeexplore.ieee.org/document/9156473/) | CVPR 2020 | planned |
|
|
294
311
|
| ReAct | [ReAct: Out-of-distribution Detection With Rectified Activations](http://arxiv.org/abs/2111.12797) | NeurIPS 2021 | avail [tensorflow](docs/notebooks/tensorflow/demo_react_tf.ipynb) or [torch](docs/notebooks/torch/demo_react_torch.ipynb) |
|
|
295
312
|
| NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
|
|
296
|
-
| Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 |
|
|
297
|
-
|
|
313
|
+
| Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
|
|
314
|
+
| GEN | [GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.html) | CVPR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_gen_tf.ipynb) or [torch](docs/notebooks/torch/demo_gen_torch.ipynb) |
|
|
315
|
+
| RMDS | [A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection](https://arxiv.org/abs/2106.09022) | preprint | avail [tensorflow](docs/notebooks/tensorflow/demo_rmds_tf.ipynb) or [torch](docs/notebooks/torch/demo_rmds_torch.ipynb) |
|
|
316
|
+
| SHE | [Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with Modern Hopfield Energy](https://openreview.net/forum?id=KkazG4lgKL) | ICLR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_she_tf.ipynb) or [torch](docs/notebooks/torch/demo_she_torch.ipynb) |
|
|
317
|
+
| ASH | [Extremely Simple Activation Shaping for Out-of-Distribution Detection](http://arxiv.org/abs/2209.09858) | ICLR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_ash_tf.ipynb) or [torch](docs/notebooks/torch/demo_ash_torch.ipynb) |
|
|
318
|
+
| SCALE | [Scaling for Training Time and Post-hoc Out-of-distribution Detection Enhancement](https://arxiv.org/abs/2310.00227) | ICLR 2024 | avail [tensorflow](docs/notebooks/tensorflow/demo_scale_tf.ipynb) or [torch](docs/notebooks/torch/demo_scale_torch.ipynb) |
|
|
298
319
|
|
|
299
320
|
|
|
300
321
|
|
|
@@ -344,6 +365,19 @@ This project received funding from the French ”Investing for the Future – PI
|
|
|
344
365
|
|
|
345
366
|
The library was created by Paul Novello to streamline DEEL research on post-hoc deep OOD methods and foster their adoption by DEEL industrial partners. He was soon joined by Yann Pequignot, Yannick Prudent, Corentin Friedrich and Matthieu Le Goff.
|
|
346
367
|
|
|
368
|
+
# Citation
|
|
369
|
+
|
|
370
|
+
If you use OODEEL for your research project, please consider citing:
|
|
371
|
+
```
|
|
372
|
+
@misc{oodeel,
|
|
373
|
+
author = {Novello, Paul and Prudent, Yannick and Friedrich, Corentin and Pequignot, Yann and Le Goff, Matthieu},
|
|
374
|
+
title = {OODEEL, a simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.},
|
|
375
|
+
year = {2023},
|
|
376
|
+
publisher = {GitHub},
|
|
377
|
+
journal = {GitHub repository},
|
|
378
|
+
howpublished = {\url{https://github.com/deel-ai/oodeel}},
|
|
379
|
+
}
|
|
380
|
+
```
|
|
347
381
|
# License
|
|
348
382
|
|
|
349
383
|
The package is released under [MIT license](LICENSE).
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
oodeel/__init__.py,sha256=-3tTNyrjeIvLCDPR0kCXup2TQX0KHn7R0emJDOKUYxI,1343
|
|
2
|
+
oodeel/datasets/__init__.py,sha256=COhGN25e3TsYtIz5rJgc0aNr_k3JjGXE3cCYd5RajTQ,1398
|
|
3
|
+
oodeel/datasets/data_handler.py,sha256=mWg1_Jv19rpz7p8o2DOE859J8evTRrWTiW5U_prw1sQ,10532
|
|
4
|
+
oodeel/datasets/tf_data_handler.py,sha256=H3BN_N9fzT19vtJtjhe0XJG4aA-7BauGrEY4OASxc90,23102
|
|
5
|
+
oodeel/datasets/torch_data_handler.py,sha256=jzQ1DvHC3xM0DOLHWGtV5OX-C63eqhabDyVDKmg11cc,24792
|
|
6
|
+
oodeel/datasets/deprecated/DEPRECATED_data_handler.py,sha256=fGK3_YSbNdHuljBtrjt7xbi6ESfNpnQV_pUSn9Uot2k,7910
|
|
7
|
+
oodeel/datasets/deprecated/DEPRECATED_ooddataset.py,sha256=Ad2otHV_vClK7ZY2D3-gW4QO_B3ir6DzJM1kYeNCJpw,13294
|
|
8
|
+
oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py,sha256=heK9c1g829LZof8CkAmmxpTvE0EpoLZOLXuYKMrL8b4,24811
|
|
9
|
+
oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py,sha256=eWCOlSPbblS8KZZAlu-M10QgY2Hv7mOq1uawLcGs0FE,26986
|
|
10
|
+
oodeel/datasets/deprecated/__init__.py,sha256=VYVsYZ-_TgGyd5ls5rudytQGPocm9I7N5_M9rHcV91w,1506
|
|
11
|
+
oodeel/eval/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
|
|
12
|
+
oodeel/eval/metrics.py,sha256=9fg7fVTT10_XIKrcsan6HQOlhUdFJhZ3f5VAprTKsjM,8839
|
|
13
|
+
oodeel/eval/plots/__init__.py,sha256=YmcFh8RUGvljc-vCew6EiIFMDn0YA_UOfDj4eAv5_Yk,1487
|
|
14
|
+
oodeel/eval/plots/features.py,sha256=fP4soXDg3BAtkBySvfUIH3Z_E3ofuU3QenH6SHRnzlM,11821
|
|
15
|
+
oodeel/eval/plots/metrics.py,sha256=3QvLqEB1pAggNHeQJzUpLqL_Ro0MOJ_bPrLcLmf_qpk,4189
|
|
16
|
+
oodeel/eval/plots/plotly.py,sha256=XHPFuXARrRIGVrRB7d0Od36iCn3JBP387tDrwRvPhU0,6040
|
|
17
|
+
oodeel/extractor/__init__.py,sha256=Ew8gLh-xanZsWJe_gKvTa_7HciZ5yTZ-cADKuj7-PCg,1538
|
|
18
|
+
oodeel/extractor/feature_extractor.py,sha256=WFNOrZjmCPt2mPr3G4ctcSqiSuvFjMZkozrwazv1zwg,6789
|
|
19
|
+
oodeel/extractor/keras_feature_extractor.py,sha256=bjJfh_wgt2Y0NOGBesZ8uBjtcyiXtHYJzsWCS_tsIFg,12412
|
|
20
|
+
oodeel/extractor/torch_feature_extractor.py,sha256=D_c_AV29Ciu-d_bWXk1ABad66qdX8wdddxrCgLQ-5G4,15779
|
|
21
|
+
oodeel/methods/__init__.py,sha256=BkpwY0-N6SAvJ8iUxe1g0J8On-tbmLCCYTJIFI3M7i0,1721
|
|
22
|
+
oodeel/methods/base.py,sha256=zeizvn6pMvdaP4FnI5TcvgrCdFARmxrwCfyU3D9LS1U,14831
|
|
23
|
+
oodeel/methods/dknn.py,sha256=6eyoHAKzb3s3pkukr6r5v9_UyK8_Rj269iHQxv1PoFI,4982
|
|
24
|
+
oodeel/methods/energy.py,sha256=u3TSoGDll1myCiA0FkuYZZNmGEXTTFPRSpT7-J8Q6Ec,4505
|
|
25
|
+
oodeel/methods/entropy.py,sha256=dcG7oC6velSN1sky9DJAD9eD32EEPqgM4R45_iU-R18,4172
|
|
26
|
+
oodeel/methods/gen.py,sha256=JafCnih09929I2KgOVaDfn8T1O4mbKdZgz3lByKvOJQ,4505
|
|
27
|
+
oodeel/methods/gram.py,sha256=35yH3BYl5CgsC2fzw4GLY21MWfyLNvY0mp3tAfmJkJ4,12220
|
|
28
|
+
oodeel/methods/mahalanobis.py,sha256=C6t2Gj7Ckp5KIhGFiOpk9uWY1RWIkM6h_qfLQkVGETc,7393
|
|
29
|
+
oodeel/methods/mls.py,sha256=FxN1XYlTkOR2uZs07krsD8qCAfJr9UIpdKH6FYEn37Q,4399
|
|
30
|
+
oodeel/methods/odin.py,sha256=Ty1YItWIsvplad72JmGuNZ_MDIBxg9iJVbjAcn33K3U,5548
|
|
31
|
+
oodeel/methods/rmds.py,sha256=ji6r8_AGhIc_75-u174Mj6a8YpHC6EHDm4mFBSOsiFM,4970
|
|
32
|
+
oodeel/methods/she.py,sha256=qJftwTM1UD_dtBvGH66RtejWqpkxR69ttpBy5KbNvHQ,7254
|
|
33
|
+
oodeel/methods/vim.py,sha256=PMy04hMZ08THlbi66ZWg15JL9NNGBG6vndz40sez5Ps,9475
|
|
34
|
+
oodeel/preprocess/__init__.py,sha256=65M9hKYHYzZ6lACA6xo7ODKofaAoqfFZ4aC5ZSJeN2I,1484
|
|
35
|
+
oodeel/preprocess/tf_preprocess.py,sha256=TRaEA7KrVjWFB81vlnNk9hN-G0tGhDa97VvrBEaP9vM,4048
|
|
36
|
+
oodeel/preprocess/torch_preprocess.py,sha256=BTjOyEHPfqx_CSv6Lw3zBi2wKEjAIdPN1DuBptcebA0,3976
|
|
37
|
+
oodeel/types/__init__.py,sha256=9TTXjSBfbaDIVMRnclInHI-CBr4L6VZTi61rCJKcTw8,2484
|
|
38
|
+
oodeel/utils/__init__.py,sha256=Lue9BysqeJf5Ej0szafgC9na8IZZR7hWGInJxoEiHUg,1696
|
|
39
|
+
oodeel/utils/general_utils.py,sha256=xc6e7q19ALgMxdCgS7TIyDiMUIGF4ih-aTK1kSlqWoQ,3292
|
|
40
|
+
oodeel/utils/operator.py,sha256=ETAFJ_oYhiD1Rawjooueq5KDl4SNzJR5fQDUU05uMz8,8262
|
|
41
|
+
oodeel/utils/tf_operator.py,sha256=gHJZlD6SKOBxDtwv0oy6u93oXYiS5gylKv5lbVWlzx4,9385
|
|
42
|
+
oodeel/utils/tf_training_tools.py,sha256=31cPCpXKmO7lMRrefmY93m6cSn0DRaHxVFeYgUiE6kI,8090
|
|
43
|
+
oodeel/utils/torch_operator.py,sha256=lO_Onb4km6UNyqYjRkFPi3OcoT2MUwE1ktzP_JMDmFY,10199
|
|
44
|
+
oodeel/utils/torch_training_tools.py,sha256=ggL_iDwyquTw9CtQAg--IODVdS1BFsBw4U5IOVlAsK8,11192
|
|
45
|
+
oodeel-0.3.0.dist-info/licenses/LICENSE,sha256=XrlZ0uYNVeUAF-iEVX21J3CTJjYPgIZUagYSy3Hf0jk,1265
|
|
46
|
+
tests/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
|
|
47
|
+
tests/tools_operator.py,sha256=YTU_FppXZPUyi3nldQsDkCMQ3Bvz1Hn_V7L45EEwftc,5328
|
|
48
|
+
tests/tests_tensorflow/__init__.py,sha256=VuiDSdOBB2jUeobzAW-XTtHJQrcyZpp8i2DyFu0A2RI,1710
|
|
49
|
+
tests/tests_tensorflow/tf_methods_utils.py,sha256=wHjsvKJj1EhcmAQhDgSWfRbVZR_jJwknT8vyez0YByw,5277
|
|
50
|
+
tests/tests_tensorflow/tools_tf.py,sha256=Z_MEzhwCOF7E7uT-tnfRS3Po9hg0MFzbESP7sMAFBkY,3499
|
|
51
|
+
tests/tests_torch/__init__.py,sha256=3mVxix2Ecn2wUo9DxvzyJBmyOAkv4fVWHHp6uLQHoic,1738
|
|
52
|
+
tests/tests_torch/tools_torch.py,sha256=lS5_kctnyiB6eP-NTfVUsbj5Mn-AiRO0Jo75p8djmwg,4975
|
|
53
|
+
tests/tests_torch/torch_methods_utils.py,sha256=PAxBs2smjsDtQjgwrwdbDzvVrboczc0pQOKtNfl82J4,5432
|
|
54
|
+
oodeel-0.3.0.dist-info/METADATA,sha256=YNvMUC4osqrw2FWlqqNMJiapvjBGFwxp834EOkNIYtg,20299
|
|
55
|
+
oodeel-0.3.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
56
|
+
oodeel-0.3.0.dist-info/top_level.txt,sha256=zkYRty1FGJ1dkpk-5MU_4uFfBFmcxXoqSwej73xELDs,13
|
|
57
|
+
oodeel-0.3.0.dist-info/RECORD,,
|
|
@@ -66,7 +66,8 @@ def load_blobs_data(batch_size=128, num_samples=10000, train_ratio=0.8):
|
|
|
66
66
|
def load_blob_mlp():
|
|
67
67
|
model_path_blobs = tf.keras.utils.get_file(
|
|
68
68
|
"blobs_mlp.h5",
|
|
69
|
-
origin="https://
|
|
69
|
+
origin="https://github.com/deel-ai/oodeel/blob/assets/test_models/"
|
|
70
|
+
+ "blobs_mlp.h5?raw=True",
|
|
70
71
|
cache_dir=model_path,
|
|
71
72
|
cache_subdir="",
|
|
72
73
|
)
|