oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of oodeel might be problematic. Click here for more details.

Files changed (47) hide show
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,27 @@ from ..types import Optional
35
35
  from ..types import Union
36
36
 
37
37
 
38
+ def get_toy_mlp(input_shape: tuple, num_classes: int) -> tf.keras.Model:
39
+ """Basic keras MLP classifier for toy datasets.
40
+
41
+ Args:
42
+ input_shape (tuple): Input data shape.
43
+ num_classes (int): Number of classes for the classification task.
44
+
45
+ Returns:
46
+ tf.keras.Model: model
47
+ """
48
+ return tf.keras.models.Sequential(
49
+ [
50
+ tf.keras.layers.Input(shape=input_shape),
51
+ tf.keras.layers.Flatten(),
52
+ tf.keras.layers.Dense(300, activation="relu"),
53
+ tf.keras.layers.Dense(150, activation="relu"),
54
+ tf.keras.layers.Dense(num_classes, activation="softmax"),
55
+ ]
56
+ )
57
+
58
+
38
59
  def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
39
60
  """Basic keras convolutional classifier for toy datasets.
40
61
 
@@ -60,8 +81,8 @@ def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
60
81
  def train_tf_model(
61
82
  train_data: tf.data.Dataset,
62
83
  model: Union[tf.keras.Model, str],
63
- input_shape: tuple,
64
- num_classes: int,
84
+ input_shape: tuple = None,
85
+ num_classes: int = None,
65
86
  batch_size: int = 128,
66
87
  epochs: int = 50,
67
88
  loss: str = "sparse_categorical_crossentropy",
@@ -80,7 +101,7 @@ def train_tf_model(
80
101
  Args:
81
102
  train_data (tf.data.Dataset): training dataset.
82
103
  model (Union[tf.keras.Model, str]): if a string is provided, must be a model
83
- from tf.keras.applications or "toy_convnet"
104
+ from tf.keras.applications or "toy_convnet" or "toy_mlp"
84
105
  input_shape (tuple): Shape of the input images.
85
106
  num_classes (int): Number of output classes.
86
107
  batch_size (int, optional): Defaults to 128.
@@ -120,6 +141,8 @@ def train_tf_model(
120
141
  elif isinstance(model, str):
121
142
  if model == "toy_convnet":
122
143
  model = get_toy_keras_convnet(num_classes)
144
+ elif model == "toy_mlp":
145
+ model = get_toy_mlp(input_shape, num_classes)
123
146
  else:
124
147
  weights = "imagenet" if imagenet_pretrained else None
125
148
  backbone = getattr(tf.keras.applications, model)(
@@ -69,12 +69,24 @@ class TorchOperator(Operator):
69
69
  return torch.argmax(tensor, dim=dim)
70
70
 
71
71
  @staticmethod
72
- def max(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
72
+ def max(
73
+ tensor: TensorType, dim: Optional[int] = None, keepdim: Optional[bool] = False
74
+ ) -> torch.Tensor:
73
75
  """Max function"""
74
76
  if dim is None:
75
77
  return torch.max(tensor)
76
78
  else:
77
- return torch.max(tensor, dim)[0]
79
+ return torch.max(tensor, dim, keepdim=keepdim)[0]
80
+
81
+ @staticmethod
82
+ def min(
83
+ tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
84
+ ) -> torch.Tensor:
85
+ """Min function"""
86
+ if dim is None:
87
+ return torch.min(tensor)
88
+ else:
89
+ return torch.min(tensor, dim, keepdim=keepdim)[0]
78
90
 
79
91
  @staticmethod
80
92
  def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
@@ -167,13 +179,18 @@ class TorchOperator(Operator):
167
179
  def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
168
180
  "Convert a NumPy array to a tensor"
169
181
  # TODO change dtype
170
- return torch.from_numpy(arr).double().to(self._device)
182
+ return torch.tensor(arr).to(self._device)
171
183
 
172
184
  @staticmethod
173
- def transpose(tensor: TensorType) -> torch.Tensor:
185
+ def t(tensor: TensorType) -> torch.Tensor:
174
186
  "Transpose function for tensor of rank 2"
175
187
  return tensor.t()
176
188
 
189
+ @staticmethod
190
+ def permute(tensor: TensorType, dims) -> torch.Tensor:
191
+ "Permute the dimensions of tensor according to dims"
192
+ return torch.permute(tensor, dims)
193
+
177
194
  @staticmethod
178
195
  def diag(tensor: TensorType) -> torch.Tensor:
179
196
  "Diagonal function: return the diagonal of a 2D tensor"
@@ -215,3 +232,57 @@ class TorchOperator(Operator):
215
232
  def relu(tensor: TensorType) -> torch.Tensor:
216
233
  "Apply relu to a tensor"
217
234
  return torch.nn.functional.relu(tensor)
235
+
236
+ @staticmethod
237
+ def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
238
+ "Computes the einsum between tensors following equation"
239
+ return torch.einsum(equation, *tensors)
240
+
241
+ @staticmethod
242
+ def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
243
+ """Set the upper triangle of the matrix formed by the last two dimensions of
244
+ tensor to zero"""
245
+ return torch.tril(tensor, diagonal)
246
+
247
+ @staticmethod
248
+ def sum(tensor: TensorType, dim: Union[tuple, list, int] = None) -> torch.Tensor:
249
+ "sum along dim"
250
+ return torch.sum(tensor, dim)
251
+
252
+ @staticmethod
253
+ def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
254
+ "unsqueeze along dim"
255
+ return torch.unsqueeze(tensor, dim)
256
+
257
+ @staticmethod
258
+ def squeeze(tensor: TensorType, dim: int = None) -> torch.Tensor:
259
+ "squeeze along dim"
260
+
261
+ if dim is None:
262
+ return torch.squeeze(tensor)
263
+
264
+ return torch.squeeze(tensor, dim)
265
+
266
+ @staticmethod
267
+ def abs(tensor: TensorType) -> torch.Tensor:
268
+ "compute absolute value"
269
+ return torch.abs(tensor)
270
+
271
+ @staticmethod
272
+ def where(
273
+ condition: TensorType,
274
+ input: Union[TensorType, float],
275
+ other: Union[TensorType, float],
276
+ ) -> torch.Tensor:
277
+ "Select elements from input where condition is True, else from other"
278
+ return torch.where(condition, input, other)
279
+
280
+ @staticmethod
281
+ def avg_pool_2d(tensor: TensorType) -> torch.Tensor:
282
+ """Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
283
+ return torch.mean(tensor, dim=(-2, -1))
284
+
285
+ @staticmethod
286
+ def log(tensor: TensorType) -> torch.Tensor:
287
+ """Perform log"""
288
+ return torch.log(tensor)
@@ -35,6 +35,31 @@ from ..types import Optional
35
35
  from ..types import Union
36
36
 
37
37
 
38
+ class ToyTorchMLP(nn.Sequential):
39
+ """Basic torch MLP classifier for toy datasets.
40
+
41
+ Args:
42
+ input_shape (tuple): Input data shape.
43
+ num_classes (int): Number of classes for the classification task.
44
+ """
45
+
46
+ def __init__(self, input_shape: tuple, num_classes: int):
47
+ self.input_shape = input_shape
48
+
49
+ # build toy mlp
50
+ mlp_modules = OrderedDict(
51
+ [
52
+ ("flatten", nn.Flatten()),
53
+ ("dense1", nn.Linear(np.prod(input_shape), 300)),
54
+ ("relu1", nn.ReLU()),
55
+ ("dense2", nn.Linear(300, 150)),
56
+ ("relu2", nn.ReLU()),
57
+ ("fc1", nn.Linear(150, num_classes)),
58
+ ]
59
+ )
60
+ super().__init__(mlp_modules)
61
+
62
+
38
63
  class ToyTorchConvnet(nn.Sequential):
39
64
  """Basic torch convolutional classifier for toy datasets.
40
65
 
@@ -106,7 +131,7 @@ def train_torch_model(
106
131
  Args:
107
132
  train_data (DataLoader): train dataloader
108
133
  model (Union[nn.Module, str]): if a string is provided, must be a model from
109
- torchvision.models or "toy_convnet".
134
+ torchvision.models or "toy_convnet" or "toy_mlp".
110
135
  num_classes (int): Number of output classes.
111
136
  epochs (int, optional): Defaults to 50.
112
137
  loss (str, optional): Defaults to "CrossEntropyLoss".
@@ -136,8 +161,12 @@ def train_torch_model(
136
161
  elif isinstance(model, str):
137
162
  if model == "toy_convnet":
138
163
  # toy model
139
- input_shape = next(iter(train_data))[0].shape[1:]
164
+ input_shape = tuple(next(iter(train_data))[0].shape[1:])
140
165
  model = ToyTorchConvnet(input_shape, num_classes).to(device)
166
+ elif model == "toy_mlp":
167
+ # toy model
168
+ input_shape = tuple(next(iter(train_data))[0].shape[1:])
169
+ model = ToyTorchMLP(input_shape, num_classes).to(device)
141
170
  else:
142
171
  # torchvision model
143
172
  model = getattr(torchvision.models, model)(
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: oodeel
3
- Version: 0.1.1
3
+ Version: 0.3.0
4
4
  Summary: Simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.
5
5
  Author: DEEL Core Team
6
6
  Author-email: paul.novello@irt-saintexupery.com
@@ -12,104 +12,126 @@ Classifier: Programming Language :: Python :: 3.9
12
12
  Classifier: Programming Language :: Python :: 3.10
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: faiss-cpu
15
+ Requires-Dist: faiss_cpu
16
16
  Requires-Dist: numpy
17
- Requires-Dist: scikit-learn
17
+ Requires-Dist: scikit_learn
18
18
  Requires-Dist: scipy
19
19
  Requires-Dist: setuptools
20
20
  Requires-Dist: matplotlib
21
21
  Requires-Dist: pandas
22
22
  Requires-Dist: seaborn
23
23
  Requires-Dist: plotly
24
+ Requires-Dist: tqdm
24
25
  Provides-Extra: dev
25
- Requires-Dist: mypy ; extra == 'dev'
26
- Requires-Dist: ipywidgets ; extra == 'dev'
27
- Requires-Dist: mkdocs-jupyter ; extra == 'dev'
28
- Requires-Dist: mkdocstrings-python ; extra == 'dev'
29
- Requires-Dist: flake8 ; extra == 'dev'
30
- Requires-Dist: setuptools ; extra == 'dev'
31
- Requires-Dist: pre-commit ; extra == 'dev'
32
- Requires-Dist: tox ; extra == 'dev'
33
- Requires-Dist: black ; extra == 'dev'
34
- Requires-Dist: ipython ; extra == 'dev'
35
- Requires-Dist: ipykernel ; extra == 'dev'
36
- Requires-Dist: pytest ; extra == 'dev'
37
- Requires-Dist: pylint ; extra == 'dev'
38
- Requires-Dist: mkdocs ; extra == 'dev'
39
- Requires-Dist: mkdocs-material ; extra == 'dev'
40
- Requires-Dist: mkdocstrings ; extra == 'dev'
41
- Requires-Dist: mknotebooks ; extra == 'dev'
42
- Requires-Dist: bump2version ; extra == 'dev'
43
- Requires-Dist: docsig ; extra == 'dev'
44
- Requires-Dist: no-implicit-optional ; extra == 'dev'
45
- Requires-Dist: tensorflow ; extra == 'dev'
46
- Requires-Dist: tensorflow-datasets ; extra == 'dev'
47
- Requires-Dist: tensorflow-probability ; extra == 'dev'
48
- Requires-Dist: timm ; extra == 'dev'
49
- Requires-Dist: torch ; extra == 'dev'
50
- Requires-Dist: torchvision ; extra == 'dev'
51
- Provides-Extra: docs
52
- Requires-Dist: mkdocs ; extra == 'docs'
53
- Requires-Dist: mkdocs-material ; extra == 'docs'
54
- Requires-Dist: mkdocstrings ; extra == 'docs'
55
- Requires-Dist: mknotebooks ; extra == 'docs'
56
- Requires-Dist: ipython ; extra == 'docs'
57
- Provides-Extra: tensorflow
58
- Requires-Dist: tensorflow ; extra == 'tensorflow'
59
- Requires-Dist: tensorflow-datasets ; extra == 'tensorflow'
60
- Requires-Dist: tensorflow-probability ; extra == 'tensorflow'
26
+ Requires-Dist: mypy; extra == "dev"
27
+ Requires-Dist: ipywidgets; extra == "dev"
28
+ Requires-Dist: mkdocs-jupyter; extra == "dev"
29
+ Requires-Dist: mkdocstrings-python; extra == "dev"
30
+ Requires-Dist: flake8; extra == "dev"
31
+ Requires-Dist: setuptools; extra == "dev"
32
+ Requires-Dist: pre-commit; extra == "dev"
33
+ Requires-Dist: tox; extra == "dev"
34
+ Requires-Dist: black; extra == "dev"
35
+ Requires-Dist: ruff; extra == "dev"
36
+ Requires-Dist: ipython; extra == "dev"
37
+ Requires-Dist: ipykernel; extra == "dev"
38
+ Requires-Dist: pytest; extra == "dev"
39
+ Requires-Dist: pylint; extra == "dev"
40
+ Requires-Dist: mypy; extra == "dev"
41
+ Requires-Dist: mkdocs; extra == "dev"
42
+ Requires-Dist: mkdocs-material; extra == "dev"
43
+ Requires-Dist: mkdocstrings; extra == "dev"
44
+ Requires-Dist: mknotebooks; extra == "dev"
45
+ Requires-Dist: mike; extra == "dev"
46
+ Requires-Dist: bump2version; extra == "dev"
47
+ Requires-Dist: docsig; extra == "dev"
48
+ Requires-Dist: no_implicit_optional; extra == "dev"
49
+ Requires-Dist: numpy==1.26.4; extra == "dev"
50
+ Requires-Dist: tensorflow==2.11.0; extra == "dev"
51
+ Requires-Dist: tensorflow_datasets; extra == "dev"
52
+ Requires-Dist: tensorflow_probability==0.19.0; extra == "dev"
53
+ Requires-Dist: timm; extra == "dev"
54
+ Requires-Dist: torch==1.13.1; extra == "dev"
55
+ Requires-Dist: torchvision==0.14.1; extra == "dev"
61
56
  Provides-Extra: tensorflow-dev
62
- Requires-Dist: mypy ; extra == 'tensorflow-dev'
63
- Requires-Dist: ipywidgets ; extra == 'tensorflow-dev'
64
- Requires-Dist: mkdocs-jupyter ; extra == 'tensorflow-dev'
65
- Requires-Dist: mkdocstrings-python ; extra == 'tensorflow-dev'
66
- Requires-Dist: flake8 ; extra == 'tensorflow-dev'
67
- Requires-Dist: setuptools ; extra == 'tensorflow-dev'
68
- Requires-Dist: pre-commit ; extra == 'tensorflow-dev'
69
- Requires-Dist: tox ; extra == 'tensorflow-dev'
70
- Requires-Dist: black ; extra == 'tensorflow-dev'
71
- Requires-Dist: ipython ; extra == 'tensorflow-dev'
72
- Requires-Dist: ipykernel ; extra == 'tensorflow-dev'
73
- Requires-Dist: pytest ; extra == 'tensorflow-dev'
74
- Requires-Dist: pylint ; extra == 'tensorflow-dev'
75
- Requires-Dist: mkdocs ; extra == 'tensorflow-dev'
76
- Requires-Dist: mkdocs-material ; extra == 'tensorflow-dev'
77
- Requires-Dist: mkdocstrings ; extra == 'tensorflow-dev'
78
- Requires-Dist: mknotebooks ; extra == 'tensorflow-dev'
79
- Requires-Dist: bump2version ; extra == 'tensorflow-dev'
80
- Requires-Dist: docsig ; extra == 'tensorflow-dev'
81
- Requires-Dist: no-implicit-optional ; extra == 'tensorflow-dev'
82
- Requires-Dist: tensorflow ; extra == 'tensorflow-dev'
83
- Requires-Dist: tensorflow-datasets ; extra == 'tensorflow-dev'
84
- Requires-Dist: tensorflow-probability ; extra == 'tensorflow-dev'
85
- Provides-Extra: torch
86
- Requires-Dist: timm ; extra == 'torch'
87
- Requires-Dist: torch ; extra == 'torch'
88
- Requires-Dist: torchvision ; extra == 'torch'
57
+ Requires-Dist: mypy; extra == "tensorflow-dev"
58
+ Requires-Dist: ipywidgets; extra == "tensorflow-dev"
59
+ Requires-Dist: mkdocs-jupyter; extra == "tensorflow-dev"
60
+ Requires-Dist: mkdocstrings-python; extra == "tensorflow-dev"
61
+ Requires-Dist: flake8; extra == "tensorflow-dev"
62
+ Requires-Dist: setuptools; extra == "tensorflow-dev"
63
+ Requires-Dist: pre-commit; extra == "tensorflow-dev"
64
+ Requires-Dist: tox; extra == "tensorflow-dev"
65
+ Requires-Dist: black; extra == "tensorflow-dev"
66
+ Requires-Dist: ruff; extra == "tensorflow-dev"
67
+ Requires-Dist: ipython; extra == "tensorflow-dev"
68
+ Requires-Dist: ipykernel; extra == "tensorflow-dev"
69
+ Requires-Dist: pytest; extra == "tensorflow-dev"
70
+ Requires-Dist: pylint; extra == "tensorflow-dev"
71
+ Requires-Dist: mypy; extra == "tensorflow-dev"
72
+ Requires-Dist: mkdocs; extra == "tensorflow-dev"
73
+ Requires-Dist: mkdocs-material; extra == "tensorflow-dev"
74
+ Requires-Dist: mkdocstrings; extra == "tensorflow-dev"
75
+ Requires-Dist: mknotebooks; extra == "tensorflow-dev"
76
+ Requires-Dist: mike; extra == "tensorflow-dev"
77
+ Requires-Dist: bump2version; extra == "tensorflow-dev"
78
+ Requires-Dist: docsig; extra == "tensorflow-dev"
79
+ Requires-Dist: no_implicit_optional; extra == "tensorflow-dev"
80
+ Requires-Dist: numpy==1.26.4; extra == "tensorflow-dev"
81
+ Requires-Dist: tensorflow==2.11.0; extra == "tensorflow-dev"
82
+ Requires-Dist: tensorflow_datasets; extra == "tensorflow-dev"
83
+ Requires-Dist: tensorflow_probability==0.19.0; extra == "tensorflow-dev"
89
84
  Provides-Extra: torch-dev
90
- Requires-Dist: mypy ; extra == 'torch-dev'
91
- Requires-Dist: ipywidgets ; extra == 'torch-dev'
92
- Requires-Dist: mkdocs-jupyter ; extra == 'torch-dev'
93
- Requires-Dist: mkdocstrings-python ; extra == 'torch-dev'
94
- Requires-Dist: flake8 ; extra == 'torch-dev'
95
- Requires-Dist: setuptools ; extra == 'torch-dev'
96
- Requires-Dist: pre-commit ; extra == 'torch-dev'
97
- Requires-Dist: tox ; extra == 'torch-dev'
98
- Requires-Dist: black ; extra == 'torch-dev'
99
- Requires-Dist: ipython ; extra == 'torch-dev'
100
- Requires-Dist: ipykernel ; extra == 'torch-dev'
101
- Requires-Dist: pytest ; extra == 'torch-dev'
102
- Requires-Dist: pylint ; extra == 'torch-dev'
103
- Requires-Dist: mkdocs ; extra == 'torch-dev'
104
- Requires-Dist: mkdocs-material ; extra == 'torch-dev'
105
- Requires-Dist: mkdocstrings ; extra == 'torch-dev'
106
- Requires-Dist: mknotebooks ; extra == 'torch-dev'
107
- Requires-Dist: bump2version ; extra == 'torch-dev'
108
- Requires-Dist: docsig ; extra == 'torch-dev'
109
- Requires-Dist: no-implicit-optional ; extra == 'torch-dev'
110
- Requires-Dist: timm ; extra == 'torch-dev'
111
- Requires-Dist: torch ; extra == 'torch-dev'
112
- Requires-Dist: torchvision ; extra == 'torch-dev'
85
+ Requires-Dist: mypy; extra == "torch-dev"
86
+ Requires-Dist: ipywidgets; extra == "torch-dev"
87
+ Requires-Dist: mkdocs-jupyter; extra == "torch-dev"
88
+ Requires-Dist: mkdocstrings-python; extra == "torch-dev"
89
+ Requires-Dist: flake8; extra == "torch-dev"
90
+ Requires-Dist: setuptools; extra == "torch-dev"
91
+ Requires-Dist: pre-commit; extra == "torch-dev"
92
+ Requires-Dist: tox; extra == "torch-dev"
93
+ Requires-Dist: black; extra == "torch-dev"
94
+ Requires-Dist: ruff; extra == "torch-dev"
95
+ Requires-Dist: ipython; extra == "torch-dev"
96
+ Requires-Dist: ipykernel; extra == "torch-dev"
97
+ Requires-Dist: pytest; extra == "torch-dev"
98
+ Requires-Dist: pylint; extra == "torch-dev"
99
+ Requires-Dist: mypy; extra == "torch-dev"
100
+ Requires-Dist: mkdocs; extra == "torch-dev"
101
+ Requires-Dist: mkdocs-material; extra == "torch-dev"
102
+ Requires-Dist: mkdocstrings; extra == "torch-dev"
103
+ Requires-Dist: mknotebooks; extra == "torch-dev"
104
+ Requires-Dist: mike; extra == "torch-dev"
105
+ Requires-Dist: bump2version; extra == "torch-dev"
106
+ Requires-Dist: docsig; extra == "torch-dev"
107
+ Requires-Dist: no_implicit_optional; extra == "torch-dev"
108
+ Requires-Dist: numpy==1.26.4; extra == "torch-dev"
109
+ Requires-Dist: timm; extra == "torch-dev"
110
+ Requires-Dist: torch==1.13.1; extra == "torch-dev"
111
+ Requires-Dist: torchvision==0.14.1; extra == "torch-dev"
112
+ Provides-Extra: tensorflow
113
+ Requires-Dist: tensorflow==2.11.0; extra == "tensorflow"
114
+ Requires-Dist: tensorflow_datasets; extra == "tensorflow"
115
+ Requires-Dist: tensorflow_probability==0.19.0; extra == "tensorflow"
116
+ Provides-Extra: torch
117
+ Requires-Dist: timm; extra == "torch"
118
+ Requires-Dist: torch==1.13.1; extra == "torch"
119
+ Requires-Dist: torchvision==0.14.1; extra == "torch"
120
+ Provides-Extra: docs
121
+ Requires-Dist: mkdocs; extra == "docs"
122
+ Requires-Dist: mkdocs-material; extra == "docs"
123
+ Requires-Dist: mkdocstrings; extra == "docs"
124
+ Requires-Dist: mknotebooks; extra == "docs"
125
+ Requires-Dist: ipython; extra == "docs"
126
+ Dynamic: author
127
+ Dynamic: author-email
128
+ Dynamic: classifier
129
+ Dynamic: description
130
+ Dynamic: description-content-type
131
+ Dynamic: license-file
132
+ Dynamic: provides-extra
133
+ Dynamic: requires-dist
134
+ Dynamic: summary
113
135
 
114
136
 
115
137
  <!-- Banner section -->
@@ -125,29 +147,23 @@ Requires-Dist: torchvision ; extra == 'torch-dev'
125
147
  <!-- Badge section -->
126
148
  <div align="center">
127
149
  <a href="#">
128
- <img src="https://img.shields.io/badge/python-3.8%2B-blue">
129
- </a>
150
+ <img src="https://img.shields.io/badge/python-3.8%2B-blue"></a>
130
151
  <a href="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml">
131
- <img alt="Flake8" src="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml/badge.svg">
132
- </a>
152
+ <img alt="Flake8" src="https://github.com/deel-ai/oodeel/actions/workflows/python-linters.yml/badge.svg"></a>
133
153
  <a href="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml">
134
- <img alt="Tests tf" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml/badge.svg">
135
- </a>
154
+ <img alt="Tests tf" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-tf.yml/badge.svg"></a>
136
155
  <a href="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml">
137
- <img alt="Tests torch" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml/badge.svg">
138
- </a>
156
+ <img alt="Tests torch" src="https://github.com/deel-ai/oodeel/actions/workflows/python-tests-torch.yml/badge.svg"></a>
139
157
  <a href="https://github.com/deel-ai/oodeel/actions/workflows/python-coverage-shield.yml">
140
- <img alt="Coverage" src="https://github.com/deel-ai/oodeel/raw/gh-shields/coverage.svg">
141
- </a>
158
+ <img alt="Coverage" src="https://github.com/deel-ai/oodeel/raw/gh-shields/coverage.svg"></a>
142
159
  <a href="https://github.com/deel-ai/oodeel/blob/master/LICENSE">
143
- <img alt="License MIT" src="https://img.shields.io/badge/License-MIT-efefef">
144
- </a>
160
+ <img alt="License MIT" src="https://img.shields.io/badge/License-MIT-efefef"></a>
145
161
  </div>
146
162
  <br>
147
163
 
148
164
  <!-- Short description of your library -->
149
165
 
150
- <b>Oodeel</b> is a library that performs post-hoc deep OOD detection on already trained neural network image classifiers. The philosophy of the library is to favor quality over quantity and to foster easy adoption. As a result, we provide a simple, compact and easily customizable API and carefully integrate and test each proposed baseline into a coherent framework that is designed to enable their use in tensorflow **and** pytorch. You can find the documentation [here](https://deel-ai.github.io/oodeel/).
166
+ <b>Oodeel</b> is a library that performs post-hoc deep OOD (Out-of-Distribution) detection on already trained neural network image classifiers. The philosophy of the library is to favor quality over quantity and to foster easy adoption. As a result, we provide a simple, compact and easily customizable API and carefully integrate and test each proposed baseline into a coherent framework that is designed to enable their use in tensorflow **and** pytorch. You can find the documentation [here](https://deel-ai.github.io/oodeel/).
151
167
 
152
168
  ```python
153
169
  from oodeel.methods import MLS
@@ -167,7 +183,8 @@ scores, info = mls.score(ds) # ds is a tf.data.Dataset or a torch.DataLoader
167
183
  - [Contributing](#contributing)
168
184
  - [See Also](#see-also)
169
185
  - [Acknowledgments](#acknowledgments)
170
- - [Creator](#creator)
186
+ - [Creators](#creators)
187
+ - [Citation](#citation)
171
188
  - [License](#license)
172
189
 
173
190
  # Installation
@@ -286,15 +303,19 @@ Currently, **oodeel** includes the following baselines:
286
303
  | MSP | [A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks](http://arxiv.org/abs/1610.02136) | ICLR 2017 | avail [tensorflow & torch](docs/pages/getting_started.ipynb)|
287
304
  | Mahalanobis | [A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks](http://arxiv.org/abs/1807.03888) | NeurIPS 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_mahalanobis_tf.ipynb) or [torch](docs/notebooks/torch/demo_mahalanobis_torch.ipynb)|
288
305
  | Energy | [Energy-based Out-of-distribution Detection](http://arxiv.org/abs/2010.03759) | NeurIPS 2020 |avail [tensorflow](docs/notebooks/tensorflow/demo_energy_tf.ipynb) or [torch](docs/notebooks/torch/demo_energy_torch.ipynb) |
289
- | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 |avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
306
+ | Odin | [Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks](http://arxiv.org/abs/1706.02690) | ICLR 2018 | avail [tensorflow](docs/notebooks/tensorflow/demo_odin_tf.ipynb) or [torch](docs/notebooks/torch/demo_odin_torch.ipynb) |
290
307
  | DKNN | [Out-of-Distribution Detection with Deep Nearest Neighbors](http://arxiv.org/abs/2204.06507) | ICML 2022 | avail [tensorflow](docs/notebooks/tensorflow/demo_dknn_tf.ipynb) or [torch](docs/notebooks/torch/demo_dknn_torch.ipynb) |
291
308
  | VIM | [ViM: Out-Of-Distribution with Virtual-logit Matching](http://arxiv.org/abs/2203.10807) | CVPR 2022 |avail [tensorflow](docs/notebooks/tensorflow/demo_vim_tf.ipynb) or [torch](docs/notebooks/torch/demo_vim_torch.ipynb) |
292
309
  | Entropy | [Likelihood Ratios for Out-of-Distribution Detection](https://proceedings.neurips.cc/paper/2019/hash/1e79596878b2320cac26dd792a6c51c9-Abstract.html) | NeurIPS 2019 |avail [tensorflow](docs/notebooks/tensorflow/demo_entropy_tf.ipynb) or [torch](docs/notebooks/torch/demo_entropy_torch.ipynb) |
293
310
  | GODIN | [Generalized ODIN: Detecting Out-of-Distribution Image Without Learning From Out-of-Distribution Data](https://ieeexplore.ieee.org/document/9156473/) | CVPR 2020 | planned |
294
311
  | ReAct | [ReAct: Out-of-distribution Detection With Rectified Activations](http://arxiv.org/abs/2111.12797) | NeurIPS 2021 | avail [tensorflow](docs/notebooks/tensorflow/demo_react_tf.ipynb) or [torch](docs/notebooks/torch/demo_react_torch.ipynb) |
295
312
  | NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
296
- | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | planned |
297
-
313
+ | Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
314
+ | GEN | [GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.html) | CVPR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_gen_tf.ipynb) or [torch](docs/notebooks/torch/demo_gen_torch.ipynb) |
315
+ | RMDS | [A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection](https://arxiv.org/abs/2106.09022) | preprint | avail [tensorflow](docs/notebooks/tensorflow/demo_rmds_tf.ipynb) or [torch](docs/notebooks/torch/demo_rmds_torch.ipynb) |
316
+ | SHE | [Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with Modern Hopfield Energy](https://openreview.net/forum?id=KkazG4lgKL) | ICLR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_she_tf.ipynb) or [torch](docs/notebooks/torch/demo_she_torch.ipynb) |
317
+ | ASH | [Extremely Simple Activation Shaping for Out-of-Distribution Detection](http://arxiv.org/abs/2209.09858) | ICLR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_ash_tf.ipynb) or [torch](docs/notebooks/torch/demo_ash_torch.ipynb) |
318
+ | SCALE | [Scaling for Training Time and Post-hoc Out-of-distribution Detection Enhancement](https://arxiv.org/abs/2310.00227) | ICLR 2024 | avail [tensorflow](docs/notebooks/tensorflow/demo_scale_tf.ipynb) or [torch](docs/notebooks/torch/demo_scale_torch.ipynb) |
298
319
 
299
320
 
300
321
 
@@ -344,6 +365,19 @@ This project received funding from the French ”Investing for the Future – PI
344
365
 
345
366
  The library was created by Paul Novello to streamline DEEL research on post-hoc deep OOD methods and foster their adoption by DEEL industrial partners. He was soon joined by Yann Pequignot, Yannick Prudent, Corentin Friedrich and Matthieu Le Goff.
346
367
 
368
+ # Citation
369
+
370
+ If you use OODEEL for your research project, please consider citing:
371
+ ```
372
+ @misc{oodeel,
373
+ author = {Novello, Paul and Prudent, Yannick and Friedrich, Corentin and Pequignot, Yann and Le Goff, Matthieu},
374
+ title = {OODEEL, a simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.},
375
+ year = {2023},
376
+ publisher = {GitHub},
377
+ journal = {GitHub repository},
378
+ howpublished = {\url{https://github.com/deel-ai/oodeel}},
379
+ }
380
+ ```
347
381
  # License
348
382
 
349
383
  The package is released under [MIT license](LICENSE).
@@ -0,0 +1,57 @@
1
+ oodeel/__init__.py,sha256=-3tTNyrjeIvLCDPR0kCXup2TQX0KHn7R0emJDOKUYxI,1343
2
+ oodeel/datasets/__init__.py,sha256=COhGN25e3TsYtIz5rJgc0aNr_k3JjGXE3cCYd5RajTQ,1398
3
+ oodeel/datasets/data_handler.py,sha256=mWg1_Jv19rpz7p8o2DOE859J8evTRrWTiW5U_prw1sQ,10532
4
+ oodeel/datasets/tf_data_handler.py,sha256=H3BN_N9fzT19vtJtjhe0XJG4aA-7BauGrEY4OASxc90,23102
5
+ oodeel/datasets/torch_data_handler.py,sha256=jzQ1DvHC3xM0DOLHWGtV5OX-C63eqhabDyVDKmg11cc,24792
6
+ oodeel/datasets/deprecated/DEPRECATED_data_handler.py,sha256=fGK3_YSbNdHuljBtrjt7xbi6ESfNpnQV_pUSn9Uot2k,7910
7
+ oodeel/datasets/deprecated/DEPRECATED_ooddataset.py,sha256=Ad2otHV_vClK7ZY2D3-gW4QO_B3ir6DzJM1kYeNCJpw,13294
8
+ oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py,sha256=heK9c1g829LZof8CkAmmxpTvE0EpoLZOLXuYKMrL8b4,24811
9
+ oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py,sha256=eWCOlSPbblS8KZZAlu-M10QgY2Hv7mOq1uawLcGs0FE,26986
10
+ oodeel/datasets/deprecated/__init__.py,sha256=VYVsYZ-_TgGyd5ls5rudytQGPocm9I7N5_M9rHcV91w,1506
11
+ oodeel/eval/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
12
+ oodeel/eval/metrics.py,sha256=9fg7fVTT10_XIKrcsan6HQOlhUdFJhZ3f5VAprTKsjM,8839
13
+ oodeel/eval/plots/__init__.py,sha256=YmcFh8RUGvljc-vCew6EiIFMDn0YA_UOfDj4eAv5_Yk,1487
14
+ oodeel/eval/plots/features.py,sha256=fP4soXDg3BAtkBySvfUIH3Z_E3ofuU3QenH6SHRnzlM,11821
15
+ oodeel/eval/plots/metrics.py,sha256=3QvLqEB1pAggNHeQJzUpLqL_Ro0MOJ_bPrLcLmf_qpk,4189
16
+ oodeel/eval/plots/plotly.py,sha256=XHPFuXARrRIGVrRB7d0Od36iCn3JBP387tDrwRvPhU0,6040
17
+ oodeel/extractor/__init__.py,sha256=Ew8gLh-xanZsWJe_gKvTa_7HciZ5yTZ-cADKuj7-PCg,1538
18
+ oodeel/extractor/feature_extractor.py,sha256=WFNOrZjmCPt2mPr3G4ctcSqiSuvFjMZkozrwazv1zwg,6789
19
+ oodeel/extractor/keras_feature_extractor.py,sha256=bjJfh_wgt2Y0NOGBesZ8uBjtcyiXtHYJzsWCS_tsIFg,12412
20
+ oodeel/extractor/torch_feature_extractor.py,sha256=D_c_AV29Ciu-d_bWXk1ABad66qdX8wdddxrCgLQ-5G4,15779
21
+ oodeel/methods/__init__.py,sha256=BkpwY0-N6SAvJ8iUxe1g0J8On-tbmLCCYTJIFI3M7i0,1721
22
+ oodeel/methods/base.py,sha256=zeizvn6pMvdaP4FnI5TcvgrCdFARmxrwCfyU3D9LS1U,14831
23
+ oodeel/methods/dknn.py,sha256=6eyoHAKzb3s3pkukr6r5v9_UyK8_Rj269iHQxv1PoFI,4982
24
+ oodeel/methods/energy.py,sha256=u3TSoGDll1myCiA0FkuYZZNmGEXTTFPRSpT7-J8Q6Ec,4505
25
+ oodeel/methods/entropy.py,sha256=dcG7oC6velSN1sky9DJAD9eD32EEPqgM4R45_iU-R18,4172
26
+ oodeel/methods/gen.py,sha256=JafCnih09929I2KgOVaDfn8T1O4mbKdZgz3lByKvOJQ,4505
27
+ oodeel/methods/gram.py,sha256=35yH3BYl5CgsC2fzw4GLY21MWfyLNvY0mp3tAfmJkJ4,12220
28
+ oodeel/methods/mahalanobis.py,sha256=C6t2Gj7Ckp5KIhGFiOpk9uWY1RWIkM6h_qfLQkVGETc,7393
29
+ oodeel/methods/mls.py,sha256=FxN1XYlTkOR2uZs07krsD8qCAfJr9UIpdKH6FYEn37Q,4399
30
+ oodeel/methods/odin.py,sha256=Ty1YItWIsvplad72JmGuNZ_MDIBxg9iJVbjAcn33K3U,5548
31
+ oodeel/methods/rmds.py,sha256=ji6r8_AGhIc_75-u174Mj6a8YpHC6EHDm4mFBSOsiFM,4970
32
+ oodeel/methods/she.py,sha256=qJftwTM1UD_dtBvGH66RtejWqpkxR69ttpBy5KbNvHQ,7254
33
+ oodeel/methods/vim.py,sha256=PMy04hMZ08THlbi66ZWg15JL9NNGBG6vndz40sez5Ps,9475
34
+ oodeel/preprocess/__init__.py,sha256=65M9hKYHYzZ6lACA6xo7ODKofaAoqfFZ4aC5ZSJeN2I,1484
35
+ oodeel/preprocess/tf_preprocess.py,sha256=TRaEA7KrVjWFB81vlnNk9hN-G0tGhDa97VvrBEaP9vM,4048
36
+ oodeel/preprocess/torch_preprocess.py,sha256=BTjOyEHPfqx_CSv6Lw3zBi2wKEjAIdPN1DuBptcebA0,3976
37
+ oodeel/types/__init__.py,sha256=9TTXjSBfbaDIVMRnclInHI-CBr4L6VZTi61rCJKcTw8,2484
38
+ oodeel/utils/__init__.py,sha256=Lue9BysqeJf5Ej0szafgC9na8IZZR7hWGInJxoEiHUg,1696
39
+ oodeel/utils/general_utils.py,sha256=xc6e7q19ALgMxdCgS7TIyDiMUIGF4ih-aTK1kSlqWoQ,3292
40
+ oodeel/utils/operator.py,sha256=ETAFJ_oYhiD1Rawjooueq5KDl4SNzJR5fQDUU05uMz8,8262
41
+ oodeel/utils/tf_operator.py,sha256=gHJZlD6SKOBxDtwv0oy6u93oXYiS5gylKv5lbVWlzx4,9385
42
+ oodeel/utils/tf_training_tools.py,sha256=31cPCpXKmO7lMRrefmY93m6cSn0DRaHxVFeYgUiE6kI,8090
43
+ oodeel/utils/torch_operator.py,sha256=lO_Onb4km6UNyqYjRkFPi3OcoT2MUwE1ktzP_JMDmFY,10199
44
+ oodeel/utils/torch_training_tools.py,sha256=ggL_iDwyquTw9CtQAg--IODVdS1BFsBw4U5IOVlAsK8,11192
45
+ oodeel-0.3.0.dist-info/licenses/LICENSE,sha256=XrlZ0uYNVeUAF-iEVX21J3CTJjYPgIZUagYSy3Hf0jk,1265
46
+ tests/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
47
+ tests/tools_operator.py,sha256=YTU_FppXZPUyi3nldQsDkCMQ3Bvz1Hn_V7L45EEwftc,5328
48
+ tests/tests_tensorflow/__init__.py,sha256=VuiDSdOBB2jUeobzAW-XTtHJQrcyZpp8i2DyFu0A2RI,1710
49
+ tests/tests_tensorflow/tf_methods_utils.py,sha256=wHjsvKJj1EhcmAQhDgSWfRbVZR_jJwknT8vyez0YByw,5277
50
+ tests/tests_tensorflow/tools_tf.py,sha256=Z_MEzhwCOF7E7uT-tnfRS3Po9hg0MFzbESP7sMAFBkY,3499
51
+ tests/tests_torch/__init__.py,sha256=3mVxix2Ecn2wUo9DxvzyJBmyOAkv4fVWHHp6uLQHoic,1738
52
+ tests/tests_torch/tools_torch.py,sha256=lS5_kctnyiB6eP-NTfVUsbj5Mn-AiRO0Jo75p8djmwg,4975
53
+ tests/tests_torch/torch_methods_utils.py,sha256=PAxBs2smjsDtQjgwrwdbDzvVrboczc0pQOKtNfl82J4,5432
54
+ oodeel-0.3.0.dist-info/METADATA,sha256=YNvMUC4osqrw2FWlqqNMJiapvjBGFwxp834EOkNIYtg,20299
55
+ oodeel-0.3.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
56
+ oodeel-0.3.0.dist-info/top_level.txt,sha256=zkYRty1FGJ1dkpk-5MU_4uFfBFmcxXoqSwej73xELDs,13
57
+ oodeel-0.3.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.2)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -66,7 +66,8 @@ def load_blobs_data(batch_size=128, num_samples=10000, train_ratio=0.8):
66
66
  def load_blob_mlp():
67
67
  model_path_blobs = tf.keras.utils.get_file(
68
68
  "blobs_mlp.h5",
69
- origin="https://share.deel.ai/s/bc5jx9HQAGYya9m/download/blobs_mlp.h5",
69
+ origin="https://github.com/deel-ai/oodeel/blob/assets/test_models/"
70
+ + "blobs_mlp.h5?raw=True",
70
71
  cache_dir=model_path,
71
72
  cache_subdir="",
72
73
  )