oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of oodeel might be problematic. Click here for more details.
- oodeel/__init__.py +1 -1
- oodeel/datasets/__init__.py +2 -1
- oodeel/datasets/data_handler.py +162 -94
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +105 -167
- oodeel/datasets/torch_data_handler.py +109 -181
- oodeel/eval/metrics.py +7 -2
- oodeel/eval/plots/features.py +2 -2
- oodeel/eval/plots/plotly.py +2 -2
- oodeel/extractor/feature_extractor.py +30 -9
- oodeel/extractor/keras_feature_extractor.py +70 -13
- oodeel/extractor/torch_feature_extractor.py +120 -33
- oodeel/methods/__init__.py +17 -1
- oodeel/methods/base.py +103 -17
- oodeel/methods/dknn.py +22 -9
- oodeel/methods/energy.py +8 -0
- oodeel/methods/entropy.py +8 -0
- oodeel/methods/gen.py +118 -0
- oodeel/methods/gram.py +307 -0
- oodeel/methods/mahalanobis.py +14 -12
- oodeel/methods/mls.py +8 -0
- oodeel/methods/odin.py +8 -0
- oodeel/methods/rmds.py +122 -0
- oodeel/methods/she.py +197 -0
- oodeel/methods/vim.py +5 -5
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/utils/operator.py +72 -2
- oodeel/utils/tf_operator.py +72 -4
- oodeel/utils/tf_training_tools.py +26 -3
- oodeel/utils/torch_operator.py +75 -4
- oodeel/utils/torch_training_tools.py +31 -2
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
- oodeel-0.3.0.dist-info/RECORD +57 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
- tests/tests_tensorflow/tf_methods_utils.py +2 -1
- tests/tests_torch/tools_torch.py +9 -9
- tests/tests_torch/torch_methods_utils.py +34 -27
- tests/tools_operator.py +10 -1
- oodeel-0.1.1.dist-info/RECORD +0 -46
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
- {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
tests/tests_torch/tools_torch.py
CHANGED
|
@@ -35,14 +35,14 @@ def almost_equal(arr1, arr2, epsilon=1e-6):
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class Net(nn.Module):
|
|
38
|
-
def __init__(self):
|
|
38
|
+
def __init__(self, num_classes=10):
|
|
39
39
|
super().__init__()
|
|
40
40
|
self.conv1 = nn.Conv2d(3, 6, 5)
|
|
41
41
|
self.pool = nn.MaxPool2d(2, 2)
|
|
42
42
|
self.conv2 = nn.Conv2d(6, 16, 5)
|
|
43
43
|
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
|
44
44
|
self.fc2 = nn.Linear(120, 84)
|
|
45
|
-
self.fc3 = nn.Linear(84,
|
|
45
|
+
self.fc3 = nn.Linear(84, num_classes)
|
|
46
46
|
|
|
47
47
|
def forward(self, x):
|
|
48
48
|
x = self.pool(F.relu(self.conv1(x)))
|
|
@@ -55,7 +55,7 @@ class Net(nn.Module):
|
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
class ComplexNet(nn.Module):
|
|
58
|
-
def __init__(self):
|
|
58
|
+
def __init__(self, num_classes=10):
|
|
59
59
|
super().__init__()
|
|
60
60
|
|
|
61
61
|
self.feature_extractor = nn.Sequential(
|
|
@@ -77,7 +77,7 @@ class ComplexNet(nn.Module):
|
|
|
77
77
|
[
|
|
78
78
|
("fc1", nn.Linear(16 * 5 * 5, 120)),
|
|
79
79
|
("fc2", nn.Linear(120, 84)),
|
|
80
|
-
("fc3", nn.Linear(84,
|
|
80
|
+
("fc3", nn.Linear(84, num_classes)),
|
|
81
81
|
]
|
|
82
82
|
)
|
|
83
83
|
)
|
|
@@ -88,7 +88,7 @@ class ComplexNet(nn.Module):
|
|
|
88
88
|
return x
|
|
89
89
|
|
|
90
90
|
|
|
91
|
-
def sequential_model():
|
|
91
|
+
def sequential_model(num_classes=10):
|
|
92
92
|
return nn.Sequential(
|
|
93
93
|
nn.Conv2d(3, 6, 5),
|
|
94
94
|
nn.ReLU(),
|
|
@@ -99,11 +99,11 @@ def sequential_model():
|
|
|
99
99
|
nn.Flatten(),
|
|
100
100
|
nn.Linear(16 * 5 * 5, 120),
|
|
101
101
|
nn.Linear(120, 84),
|
|
102
|
-
nn.Linear(84,
|
|
102
|
+
nn.Linear(84, num_classes),
|
|
103
103
|
)
|
|
104
104
|
|
|
105
105
|
|
|
106
|
-
def named_sequential_model():
|
|
106
|
+
def named_sequential_model(num_classes=10):
|
|
107
107
|
return nn.Sequential(
|
|
108
108
|
OrderedDict(
|
|
109
109
|
[
|
|
@@ -116,13 +116,13 @@ def named_sequential_model():
|
|
|
116
116
|
("flatten", nn.Flatten()),
|
|
117
117
|
("fc1", nn.Linear(16 * 5 * 5, 120)),
|
|
118
118
|
("fc2", nn.Linear(120, 84)),
|
|
119
|
-
("fc3", nn.Linear(84,
|
|
119
|
+
("fc3", nn.Linear(84, num_classes)),
|
|
120
120
|
]
|
|
121
121
|
)
|
|
122
122
|
)
|
|
123
123
|
|
|
124
124
|
|
|
125
|
-
def simplest_mlp(num_features, num_classes):
|
|
125
|
+
def simplest_mlp(num_features, num_classes=10):
|
|
126
126
|
return nn.Sequential(
|
|
127
127
|
nn.Linear(num_features, 64),
|
|
128
128
|
nn.ReLU(),
|
|
@@ -23,12 +23,11 @@
|
|
|
23
23
|
import os
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
|
-
import requests
|
|
27
26
|
import torch
|
|
28
27
|
from sklearn.datasets import make_blobs
|
|
29
28
|
from sklearn.model_selection import train_test_split
|
|
30
29
|
|
|
31
|
-
from oodeel.datasets import
|
|
30
|
+
from oodeel.datasets import load_data_handler
|
|
32
31
|
from oodeel.eval.metrics import bench_metrics
|
|
33
32
|
|
|
34
33
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -53,31 +52,26 @@ def load_blobs_data(batch_size=128, num_samples=10000, train_ratio=0.8):
|
|
|
53
52
|
)
|
|
54
53
|
|
|
55
54
|
# === id / ood split ===
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
55
|
+
handler = load_data_handler("torch")
|
|
56
|
+
blobs_train = handler.load_dataset((X_train, y_train))
|
|
57
|
+
blobs_test = handler.load_dataset((X_test, y_test))
|
|
58
|
+
oods_fit, _ = handler.split_by_class(blobs_train, in_labels, out_labels)
|
|
59
|
+
oods_in, oods_out = handler.split_by_class(blobs_test, in_labels, out_labels)
|
|
60
60
|
|
|
61
61
|
# === prepare data (shuffle, batch) => torch dataloaders ===
|
|
62
|
-
ds_fit =
|
|
63
|
-
ds_in =
|
|
64
|
-
ds_out =
|
|
62
|
+
ds_fit = handler.prepare(oods_fit, batch_size=batch_size, shuffle=True)
|
|
63
|
+
ds_in = handler.prepare(oods_in, batch_size=batch_size)
|
|
64
|
+
ds_out = handler.prepare(oods_out, batch_size=batch_size)
|
|
65
65
|
return ds_fit, ds_in, ds_out
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
def load_blob_mlp():
|
|
69
|
-
model_path_blob = os.path.join(model_path, "blobs_mlp.pt")
|
|
70
|
-
|
|
71
|
-
# if model not in local, download it
|
|
72
|
-
if not os.path.exists(model_path_blob):
|
|
73
|
-
data = requests.get(
|
|
74
|
-
"https://share.deel.ai/s/xcyk3ET8fzfTp8S/download/blobs_mlp.pt"
|
|
75
|
-
)
|
|
76
|
-
with open(model_path_blob, "wb") as file:
|
|
77
|
-
file.write(data.content)
|
|
78
|
-
|
|
79
69
|
# load model
|
|
80
|
-
model = torch.
|
|
70
|
+
model = torch.hub.load_state_dict_from_url(
|
|
71
|
+
"https://github.com/deel-ai/oodeel/blob/assets/"
|
|
72
|
+
+ "test_models/blobs_mlp.pt?raw=True",
|
|
73
|
+
map_location=device,
|
|
74
|
+
)
|
|
81
75
|
model.eval()
|
|
82
76
|
return model
|
|
83
77
|
|
|
@@ -125,14 +119,27 @@ def eval_detector_on_blobs(
|
|
|
125
119
|
|
|
126
120
|
# react specific test
|
|
127
121
|
# /!\ do it at the end of the test because it may affect the detector's behaviour
|
|
122
|
+
|
|
128
123
|
if check_react_clipping:
|
|
129
124
|
assert detector.react_threshold is not None
|
|
125
|
+
|
|
130
126
|
penult_feat_extractor = detector._load_feature_extractor(
|
|
131
|
-
model=model, feature_layers_id=[-
|
|
132
|
-
)
|
|
133
|
-
penult_features = penult_feat_extractor.predict(ds_fit)[0][0]
|
|
134
|
-
assert torch.max(penult_features) <= detector.react_threshold, (
|
|
135
|
-
f"Maximum value of penultimate features ({torch.max(penult_features)})"
|
|
136
|
-
+ " should be less than or equal to the react threshold value"
|
|
137
|
-
+ f" ({detector.react_threshold})"
|
|
127
|
+
model=model, feature_layers_id=[-1]
|
|
138
128
|
)
|
|
129
|
+
# penult_feat_extractor._prepare_ood_handles()
|
|
130
|
+
|
|
131
|
+
def hook(_, input):
|
|
132
|
+
penult_feat_extractor._features["test"] = input[0]
|
|
133
|
+
|
|
134
|
+
penult_feat_extractor.model[-1].register_forward_pre_hook(hook)
|
|
135
|
+
for x, y in ds_fit:
|
|
136
|
+
_ = penult_feat_extractor.predict_tensor(x)
|
|
137
|
+
assert (
|
|
138
|
+
torch.max(penult_feat_extractor._features["test"])
|
|
139
|
+
<= detector.react_threshold
|
|
140
|
+
), (
|
|
141
|
+
"Maximum value of penultimate features"
|
|
142
|
+
+ f" ({torch.max(penult_feat_extractor._features)})"
|
|
143
|
+
+ " should be less than or equal to the react threshold value"
|
|
144
|
+
+ f" ({detector.react_threshold})"
|
|
145
|
+
)
|
tests/tools_operator.py
CHANGED
|
@@ -127,7 +127,7 @@ def check_common_operators(backend):
|
|
|
127
127
|
assert operator.flatten(x).shape == (25, 12 * 6)
|
|
128
128
|
|
|
129
129
|
# Transpose
|
|
130
|
-
assert operator.
|
|
130
|
+
assert operator.t(x[0]).shape == (6, 12)
|
|
131
131
|
|
|
132
132
|
# Diag
|
|
133
133
|
assert operator.diag(x[0]).shape == (6,)
|
|
@@ -142,3 +142,12 @@ def check_common_operators(backend):
|
|
|
142
142
|
|
|
143
143
|
# Pinv
|
|
144
144
|
assert operator.pinv(x[0]).shape == (6, 12)
|
|
145
|
+
|
|
146
|
+
# einsum
|
|
147
|
+
ein = operator.einsum("bij,jk->bik", x, operator.t(x[0]))
|
|
148
|
+
assert ein.shape == (25, 12, 12)
|
|
149
|
+
|
|
150
|
+
# tril
|
|
151
|
+
triangle = operator.tril(ein, diagonal=-1)
|
|
152
|
+
assert triangle.shape == (25, 12, 12)
|
|
153
|
+
assert np.sum([m[0, 0] + m[11, 11] + m[0, 11] for m in triangle]) == 0
|
oodeel-0.1.1.dist-info/RECORD
DELETED
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
oodeel/__init__.py,sha256=2q31pS-ZP2CnwNQfZZH3G8iwjYBT2qOwS9Fpl9NERC4,1343
|
|
2
|
-
oodeel/datasets/__init__.py,sha256=BdpBDtWoGS3fgJCRU6TTCcYodzVGYANv5q4Abqhp1RU,1332
|
|
3
|
-
oodeel/datasets/data_handler.py,sha256=kxWKmUjdKKtOOWIPXbt96Bhg8SxVcFUhBCD2XEHx5Ac,7904
|
|
4
|
-
oodeel/datasets/ooddataset.py,sha256=tq3ZIjLNe3iWhjFkmpI4jomLJJ4-CfmsuHoTFR25gBY,13360
|
|
5
|
-
oodeel/datasets/tf_data_handler.py,sha256=FmqWoV05yIQ0lHNBdDE8pTgSdma4jbdjT3YvQOS9KnY,24794
|
|
6
|
-
oodeel/datasets/torch_data_handler.py,sha256=SMa8RvMYp2-f0AI3m5MKeZVFh9G6CMVJl_MNZmiLK74,26814
|
|
7
|
-
oodeel/eval/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
|
|
8
|
-
oodeel/eval/metrics.py,sha256=ywU9Y4E_PYckdxllAzUd17T1NtCKq4SfRPigrw78PHI,8589
|
|
9
|
-
oodeel/eval/plots/__init__.py,sha256=YmcFh8RUGvljc-vCew6EiIFMDn0YA_UOfDj4eAv5_Yk,1487
|
|
10
|
-
oodeel/eval/plots/features.py,sha256=tEV1j_1zgqVmC7MOZa7-LyaSQMqL6egH8_FjUwqIhVQ,11815
|
|
11
|
-
oodeel/eval/plots/metrics.py,sha256=3QvLqEB1pAggNHeQJzUpLqL_Ro0MOJ_bPrLcLmf_qpk,4189
|
|
12
|
-
oodeel/eval/plots/plotly.py,sha256=lc22a880TJgUf4H8CSYnXslYRl0I0QqjYsFrf1XjZyA,6034
|
|
13
|
-
oodeel/extractor/__init__.py,sha256=Ew8gLh-xanZsWJe_gKvTa_7HciZ5yTZ-cADKuj7-PCg,1538
|
|
14
|
-
oodeel/extractor/feature_extractor.py,sha256=PE6YczZwRhrd3ZAtEwyWhHl1KmDdvhAOiZHqtYIAnpA,5649
|
|
15
|
-
oodeel/extractor/keras_feature_extractor.py,sha256=IzPDYKS6nHMp-7EIur3mtK5FEVAdPBT_4OTOKTSBRc8,9732
|
|
16
|
-
oodeel/extractor/torch_feature_extractor.py,sha256=VYPoTJfauYkrL2IYmxH9hCT_O8aZKYwhMeB2bZ1wpa4,12465
|
|
17
|
-
oodeel/methods/__init__.py,sha256=TPB-l1dbtvS9YpSG64mFN28Squ6hWNwNMHOFpX5inU0,1556
|
|
18
|
-
oodeel/methods/base.py,sha256=MvjIk7flOGlcufH8UmgWdORJbSIwgPHWUzgBEE8pHxE,11075
|
|
19
|
-
oodeel/methods/dknn.py,sha256=Cbv2y80u8FPho1glb7abC6MGNk2UEjHvEzzMuorBmcc,4264
|
|
20
|
-
oodeel/methods/energy.py,sha256=xcSVs0PibzjFT6pftMuKyZlYkmMdqWJA2WQxw60NHHM,4211
|
|
21
|
-
oodeel/methods/entropy.py,sha256=F5w312f6wpvk9DBHi7LDDcndrMOoPY8OuFCO0ddgOU4,3878
|
|
22
|
-
oodeel/methods/mahalanobis.py,sha256=IIiuyAAvlLkuSZRAmEUGoNANMMuM6KJ-YWb2kc42IsA,7237
|
|
23
|
-
oodeel/methods/mls.py,sha256=BJajmhl3x11Kr-8RwYq59kAn-rveA5ydyoko7ajF_60,4105
|
|
24
|
-
oodeel/methods/odin.py,sha256=E4I2olYH9xvnkgau_2rcPnWL3FFWAh-jxG60-CGxlJA,5254
|
|
25
|
-
oodeel/methods/vim.py,sha256=Z6VeAe_Aaza_rvWL56z1JLVBiRBfmVm-k4ifbrFeH_Q,9485
|
|
26
|
-
oodeel/types/__init__.py,sha256=9TTXjSBfbaDIVMRnclInHI-CBr4L6VZTi61rCJKcTw8,2484
|
|
27
|
-
oodeel/utils/__init__.py,sha256=Lue9BysqeJf5Ej0szafgC9na8IZZR7hWGInJxoEiHUg,1696
|
|
28
|
-
oodeel/utils/general_utils.py,sha256=xc6e7q19ALgMxdCgS7TIyDiMUIGF4ih-aTK1kSlqWoQ,3292
|
|
29
|
-
oodeel/utils/operator.py,sha256=2_jwROZbg0BkZzi6j7SyoGh7iulC1iI-QzZHGZQKHNM,6167
|
|
30
|
-
oodeel/utils/tf_operator.py,sha256=kOp5BvXyY7yf6LdmJg2vRLZHPTFV2HLhlR1na_hZfzI,7212
|
|
31
|
-
oodeel/utils/tf_training_tools.py,sha256=Knih0NGIpBe8G25hV0tpbZi1D2UkFTQCBuQBJYlz3-c,7315
|
|
32
|
-
oodeel/utils/torch_operator.py,sha256=ieiiLlUoexPo5AcFrYG91o2Xe3uzxocWuNce-YOjjHM,7958
|
|
33
|
-
oodeel/utils/torch_training_tools.py,sha256=yOSuLwUw8ENEUUrMJL4PFC35V_XJYmguta2cFjuhpy8,10224
|
|
34
|
-
tests/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
|
|
35
|
-
tests/tools_operator.py,sha256=bMCNRUC6uIwVPewTBPrSjPASLxFg6GFzEPLwFobnLIw,5047
|
|
36
|
-
tests/tests_tensorflow/__init__.py,sha256=VuiDSdOBB2jUeobzAW-XTtHJQrcyZpp8i2DyFu0A2RI,1710
|
|
37
|
-
tests/tests_tensorflow/tf_methods_utils.py,sha256=JJqreJaHvhDaGEgw1w7wUhy8YZqYBgkleNC__lYg2Xs,5246
|
|
38
|
-
tests/tests_tensorflow/tools_tf.py,sha256=Z_MEzhwCOF7E7uT-tnfRS3Po9hg0MFzbESP7sMAFBkY,3499
|
|
39
|
-
tests/tests_torch/__init__.py,sha256=3mVxix2Ecn2wUo9DxvzyJBmyOAkv4fVWHHp6uLQHoic,1738
|
|
40
|
-
tests/tests_torch/tools_torch.py,sha256=9ParAhU54AFe4CDYuVo3VGVMypLcMN4Nrk5yz93j-SE,4876
|
|
41
|
-
tests/tests_torch/torch_methods_utils.py,sha256=60dkvArDX3mq9gSato2ghp9BCZeXZ_U64WosCjgdGbE,5234
|
|
42
|
-
oodeel-0.1.1.dist-info/LICENSE,sha256=XrlZ0uYNVeUAF-iEVX21J3CTJjYPgIZUagYSy3Hf0jk,1265
|
|
43
|
-
oodeel-0.1.1.dist-info/METADATA,sha256=WHzwlYu1UvHUd2Qh0REjo-HPGmizFR2fFrzmoROEA9k,17490
|
|
44
|
-
oodeel-0.1.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
|
|
45
|
-
oodeel-0.1.1.dist-info/top_level.txt,sha256=zkYRty1FGJ1dkpk-5MU_4uFfBFmcxXoqSwej73xELDs,13
|
|
46
|
-
oodeel-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|