oodeel 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of oodeel might be problematic; see the package's registry page for more details.

Files changed (47)
  1. oodeel/__init__.py +1 -1
  2. oodeel/datasets/__init__.py +2 -1
  3. oodeel/datasets/data_handler.py +162 -94
  4. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  5. oodeel/datasets/{ooddataset.py → deprecated/DEPRECATED_ooddataset.py} +14 -13
  6. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  7. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  8. oodeel/datasets/deprecated/__init__.py +31 -0
  9. oodeel/datasets/tf_data_handler.py +105 -167
  10. oodeel/datasets/torch_data_handler.py +109 -181
  11. oodeel/eval/metrics.py +7 -2
  12. oodeel/eval/plots/features.py +2 -2
  13. oodeel/eval/plots/plotly.py +2 -2
  14. oodeel/extractor/feature_extractor.py +30 -9
  15. oodeel/extractor/keras_feature_extractor.py +70 -13
  16. oodeel/extractor/torch_feature_extractor.py +120 -33
  17. oodeel/methods/__init__.py +17 -1
  18. oodeel/methods/base.py +103 -17
  19. oodeel/methods/dknn.py +22 -9
  20. oodeel/methods/energy.py +8 -0
  21. oodeel/methods/entropy.py +8 -0
  22. oodeel/methods/gen.py +118 -0
  23. oodeel/methods/gram.py +307 -0
  24. oodeel/methods/mahalanobis.py +14 -12
  25. oodeel/methods/mls.py +8 -0
  26. oodeel/methods/odin.py +8 -0
  27. oodeel/methods/rmds.py +122 -0
  28. oodeel/methods/she.py +197 -0
  29. oodeel/methods/vim.py +5 -5
  30. oodeel/preprocess/__init__.py +31 -0
  31. oodeel/preprocess/tf_preprocess.py +95 -0
  32. oodeel/preprocess/torch_preprocess.py +97 -0
  33. oodeel/utils/operator.py +72 -2
  34. oodeel/utils/tf_operator.py +72 -4
  35. oodeel/utils/tf_training_tools.py +26 -3
  36. oodeel/utils/torch_operator.py +75 -4
  37. oodeel/utils/torch_training_tools.py +31 -2
  38. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/METADATA +141 -107
  39. oodeel-0.3.0.dist-info/RECORD +57 -0
  40. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/WHEEL +1 -1
  41. tests/tests_tensorflow/tf_methods_utils.py +2 -1
  42. tests/tests_torch/tools_torch.py +9 -9
  43. tests/tests_torch/torch_methods_utils.py +34 -27
  44. tests/tools_operator.py +10 -1
  45. oodeel-0.1.1.dist-info/RECORD +0 -46
  46. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info/licenses}/LICENSE +0 -0
  47. {oodeel-0.1.1.dist-info → oodeel-0.3.0.dist-info}/top_level.txt +0 -0
@@ -35,14 +35,14 @@ def almost_equal(arr1, arr2, epsilon=1e-6):
35
35
 
36
36
 
37
37
  class Net(nn.Module):
38
- def __init__(self):
38
+ def __init__(self, num_classes=10):
39
39
  super().__init__()
40
40
  self.conv1 = nn.Conv2d(3, 6, 5)
41
41
  self.pool = nn.MaxPool2d(2, 2)
42
42
  self.conv2 = nn.Conv2d(6, 16, 5)
43
43
  self.fc1 = nn.Linear(16 * 5 * 5, 120)
44
44
  self.fc2 = nn.Linear(120, 84)
45
- self.fc3 = nn.Linear(84, 10)
45
+ self.fc3 = nn.Linear(84, num_classes)
46
46
 
47
47
  def forward(self, x):
48
48
  x = self.pool(F.relu(self.conv1(x)))
@@ -55,7 +55,7 @@ class Net(nn.Module):
55
55
 
56
56
 
57
57
  class ComplexNet(nn.Module):
58
- def __init__(self):
58
+ def __init__(self, num_classes=10):
59
59
  super().__init__()
60
60
 
61
61
  self.feature_extractor = nn.Sequential(
@@ -77,7 +77,7 @@ class ComplexNet(nn.Module):
77
77
  [
78
78
  ("fc1", nn.Linear(16 * 5 * 5, 120)),
79
79
  ("fc2", nn.Linear(120, 84)),
80
- ("fc3", nn.Linear(84, 10)),
80
+ ("fc3", nn.Linear(84, num_classes)),
81
81
  ]
82
82
  )
83
83
  )
@@ -88,7 +88,7 @@ class ComplexNet(nn.Module):
88
88
  return x
89
89
 
90
90
 
91
- def sequential_model():
91
+ def sequential_model(num_classes=10):
92
92
  return nn.Sequential(
93
93
  nn.Conv2d(3, 6, 5),
94
94
  nn.ReLU(),
@@ -99,11 +99,11 @@ def sequential_model():
99
99
  nn.Flatten(),
100
100
  nn.Linear(16 * 5 * 5, 120),
101
101
  nn.Linear(120, 84),
102
- nn.Linear(84, 10),
102
+ nn.Linear(84, num_classes),
103
103
  )
104
104
 
105
105
 
106
- def named_sequential_model():
106
+ def named_sequential_model(num_classes=10):
107
107
  return nn.Sequential(
108
108
  OrderedDict(
109
109
  [
@@ -116,13 +116,13 @@ def named_sequential_model():
116
116
  ("flatten", nn.Flatten()),
117
117
  ("fc1", nn.Linear(16 * 5 * 5, 120)),
118
118
  ("fc2", nn.Linear(120, 84)),
119
- ("fc3", nn.Linear(84, 10)),
119
+ ("fc3", nn.Linear(84, num_classes)),
120
120
  ]
121
121
  )
122
122
  )
123
123
 
124
124
 
125
- def simplest_mlp(num_features, num_classes):
125
+ def simplest_mlp(num_features, num_classes=10):
126
126
  return nn.Sequential(
127
127
  nn.Linear(num_features, 64),
128
128
  nn.ReLU(),
@@ -23,12 +23,11 @@
23
23
  import os
24
24
 
25
25
  import numpy as np
26
- import requests
27
26
  import torch
28
27
  from sklearn.datasets import make_blobs
29
28
  from sklearn.model_selection import train_test_split
30
29
 
31
- from oodeel.datasets import OODDataset
30
+ from oodeel.datasets import load_data_handler
32
31
  from oodeel.eval.metrics import bench_metrics
33
32
 
34
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -53,31 +52,26 @@ def load_blobs_data(batch_size=128, num_samples=10000, train_ratio=0.8):
53
52
  )
54
53
 
55
54
  # === id / ood split ===
56
- blobs_train = OODDataset((X_train, y_train), backend="torch")
57
- blobs_test = OODDataset((X_test, y_test), backend="torch")
58
- oods_fit, _ = blobs_train.split_by_class(in_labels, out_labels)
59
- oods_in, oods_out = blobs_test.split_by_class(in_labels, out_labels)
55
+ handler = load_data_handler("torch")
56
+ blobs_train = handler.load_dataset((X_train, y_train))
57
+ blobs_test = handler.load_dataset((X_test, y_test))
58
+ oods_fit, _ = handler.split_by_class(blobs_train, in_labels, out_labels)
59
+ oods_in, oods_out = handler.split_by_class(blobs_test, in_labels, out_labels)
60
60
 
61
61
  # === prepare data (shuffle, batch) => torch dataloaders ===
62
- ds_fit = oods_fit.prepare(batch_size=batch_size, shuffle=True)
63
- ds_in = oods_in.prepare(batch_size=batch_size)
64
- ds_out = oods_out.prepare(batch_size=batch_size)
62
+ ds_fit = handler.prepare(oods_fit, batch_size=batch_size, shuffle=True)
63
+ ds_in = handler.prepare(oods_in, batch_size=batch_size)
64
+ ds_out = handler.prepare(oods_out, batch_size=batch_size)
65
65
  return ds_fit, ds_in, ds_out
66
66
 
67
67
 
68
68
  def load_blob_mlp():
69
- model_path_blob = os.path.join(model_path, "blobs_mlp.pt")
70
-
71
- # if model not in local, download it
72
- if not os.path.exists(model_path_blob):
73
- data = requests.get(
74
- "https://share.deel.ai/s/xcyk3ET8fzfTp8S/download/blobs_mlp.pt"
75
- )
76
- with open(model_path_blob, "wb") as file:
77
- file.write(data.content)
78
-
79
69
  # load model
80
- model = torch.load(model_path_blob, map_location=device)
70
+ model = torch.hub.load_state_dict_from_url(
71
+ "https://github.com/deel-ai/oodeel/blob/assets/"
72
+ + "test_models/blobs_mlp.pt?raw=True",
73
+ map_location=device,
74
+ )
81
75
  model.eval()
82
76
  return model
83
77
 
@@ -125,14 +119,27 @@ def eval_detector_on_blobs(
125
119
 
126
120
  # react specific test
127
121
  # /!\ do it at the end of the test because it may affect the detector's behaviour
122
+
128
123
  if check_react_clipping:
129
124
  assert detector.react_threshold is not None
125
+
130
126
  penult_feat_extractor = detector._load_feature_extractor(
131
- model=model, feature_layers_id=[-2, -1]
132
- )
133
- penult_features = penult_feat_extractor.predict(ds_fit)[0][0]
134
- assert torch.max(penult_features) <= detector.react_threshold, (
135
- f"Maximum value of penultimate features ({torch.max(penult_features)})"
136
- + " should be less than or equal to the react threshold value"
137
- + f" ({detector.react_threshold})"
127
+ model=model, feature_layers_id=[-1]
138
128
  )
129
+ # penult_feat_extractor._prepare_ood_handles()
130
+
131
+ def hook(_, input):
132
+ penult_feat_extractor._features["test"] = input[0]
133
+
134
+ penult_feat_extractor.model[-1].register_forward_pre_hook(hook)
135
+ for x, y in ds_fit:
136
+ _ = penult_feat_extractor.predict_tensor(x)
137
+ assert (
138
+ torch.max(penult_feat_extractor._features["test"])
139
+ <= detector.react_threshold
140
+ ), (
141
+ "Maximum value of penultimate features"
142
+ + f" ({torch.max(penult_feat_extractor._features)})"
143
+ + " should be less than or equal to the react threshold value"
144
+ + f" ({detector.react_threshold})"
145
+ )
tests/tools_operator.py CHANGED
@@ -127,7 +127,7 @@ def check_common_operators(backend):
127
127
  assert operator.flatten(x).shape == (25, 12 * 6)
128
128
 
129
129
  # Transpose
130
- assert operator.transpose(x[0]).shape == (6, 12)
130
+ assert operator.t(x[0]).shape == (6, 12)
131
131
 
132
132
  # Diag
133
133
  assert operator.diag(x[0]).shape == (6,)
@@ -142,3 +142,12 @@ def check_common_operators(backend):
142
142
 
143
143
  # Pinv
144
144
  assert operator.pinv(x[0]).shape == (6, 12)
145
+
146
+ # einsum
147
+ ein = operator.einsum("bij,jk->bik", x, operator.t(x[0]))
148
+ assert ein.shape == (25, 12, 12)
149
+
150
+ # tril
151
+ triangle = operator.tril(ein, diagonal=-1)
152
+ assert triangle.shape == (25, 12, 12)
153
+ assert np.sum([m[0, 0] + m[11, 11] + m[0, 11] for m in triangle]) == 0
@@ -1,46 +0,0 @@
1
- oodeel/__init__.py,sha256=2q31pS-ZP2CnwNQfZZH3G8iwjYBT2qOwS9Fpl9NERC4,1343
2
- oodeel/datasets/__init__.py,sha256=BdpBDtWoGS3fgJCRU6TTCcYodzVGYANv5q4Abqhp1RU,1332
3
- oodeel/datasets/data_handler.py,sha256=kxWKmUjdKKtOOWIPXbt96Bhg8SxVcFUhBCD2XEHx5Ac,7904
4
- oodeel/datasets/ooddataset.py,sha256=tq3ZIjLNe3iWhjFkmpI4jomLJJ4-CfmsuHoTFR25gBY,13360
5
- oodeel/datasets/tf_data_handler.py,sha256=FmqWoV05yIQ0lHNBdDE8pTgSdma4jbdjT3YvQOS9KnY,24794
6
- oodeel/datasets/torch_data_handler.py,sha256=SMa8RvMYp2-f0AI3m5MKeZVFh9G6CMVJl_MNZmiLK74,26814
7
- oodeel/eval/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
8
- oodeel/eval/metrics.py,sha256=ywU9Y4E_PYckdxllAzUd17T1NtCKq4SfRPigrw78PHI,8589
9
- oodeel/eval/plots/__init__.py,sha256=YmcFh8RUGvljc-vCew6EiIFMDn0YA_UOfDj4eAv5_Yk,1487
10
- oodeel/eval/plots/features.py,sha256=tEV1j_1zgqVmC7MOZa7-LyaSQMqL6egH8_FjUwqIhVQ,11815
11
- oodeel/eval/plots/metrics.py,sha256=3QvLqEB1pAggNHeQJzUpLqL_Ro0MOJ_bPrLcLmf_qpk,4189
12
- oodeel/eval/plots/plotly.py,sha256=lc22a880TJgUf4H8CSYnXslYRl0I0QqjYsFrf1XjZyA,6034
13
- oodeel/extractor/__init__.py,sha256=Ew8gLh-xanZsWJe_gKvTa_7HciZ5yTZ-cADKuj7-PCg,1538
14
- oodeel/extractor/feature_extractor.py,sha256=PE6YczZwRhrd3ZAtEwyWhHl1KmDdvhAOiZHqtYIAnpA,5649
15
- oodeel/extractor/keras_feature_extractor.py,sha256=IzPDYKS6nHMp-7EIur3mtK5FEVAdPBT_4OTOKTSBRc8,9732
16
- oodeel/extractor/torch_feature_extractor.py,sha256=VYPoTJfauYkrL2IYmxH9hCT_O8aZKYwhMeB2bZ1wpa4,12465
17
- oodeel/methods/__init__.py,sha256=TPB-l1dbtvS9YpSG64mFN28Squ6hWNwNMHOFpX5inU0,1556
18
- oodeel/methods/base.py,sha256=MvjIk7flOGlcufH8UmgWdORJbSIwgPHWUzgBEE8pHxE,11075
19
- oodeel/methods/dknn.py,sha256=Cbv2y80u8FPho1glb7abC6MGNk2UEjHvEzzMuorBmcc,4264
20
- oodeel/methods/energy.py,sha256=xcSVs0PibzjFT6pftMuKyZlYkmMdqWJA2WQxw60NHHM,4211
21
- oodeel/methods/entropy.py,sha256=F5w312f6wpvk9DBHi7LDDcndrMOoPY8OuFCO0ddgOU4,3878
22
- oodeel/methods/mahalanobis.py,sha256=IIiuyAAvlLkuSZRAmEUGoNANMMuM6KJ-YWb2kc42IsA,7237
23
- oodeel/methods/mls.py,sha256=BJajmhl3x11Kr-8RwYq59kAn-rveA5ydyoko7ajF_60,4105
24
- oodeel/methods/odin.py,sha256=E4I2olYH9xvnkgau_2rcPnWL3FFWAh-jxG60-CGxlJA,5254
25
- oodeel/methods/vim.py,sha256=Z6VeAe_Aaza_rvWL56z1JLVBiRBfmVm-k4ifbrFeH_Q,9485
26
- oodeel/types/__init__.py,sha256=9TTXjSBfbaDIVMRnclInHI-CBr4L6VZTi61rCJKcTw8,2484
27
- oodeel/utils/__init__.py,sha256=Lue9BysqeJf5Ej0szafgC9na8IZZR7hWGInJxoEiHUg,1696
28
- oodeel/utils/general_utils.py,sha256=xc6e7q19ALgMxdCgS7TIyDiMUIGF4ih-aTK1kSlqWoQ,3292
29
- oodeel/utils/operator.py,sha256=2_jwROZbg0BkZzi6j7SyoGh7iulC1iI-QzZHGZQKHNM,6167
30
- oodeel/utils/tf_operator.py,sha256=kOp5BvXyY7yf6LdmJg2vRLZHPTFV2HLhlR1na_hZfzI,7212
31
- oodeel/utils/tf_training_tools.py,sha256=Knih0NGIpBe8G25hV0tpbZi1D2UkFTQCBuQBJYlz3-c,7315
32
- oodeel/utils/torch_operator.py,sha256=ieiiLlUoexPo5AcFrYG91o2Xe3uzxocWuNce-YOjjHM,7958
33
- oodeel/utils/torch_training_tools.py,sha256=yOSuLwUw8ENEUUrMJL4PFC35V_XJYmguta2cFjuhpy8,10224
34
- tests/__init__.py,sha256=lQIUQjczeiRtfIqH2uLNJGubKUN6vPM93mTfY1Qz3bc,1297
35
- tests/tools_operator.py,sha256=bMCNRUC6uIwVPewTBPrSjPASLxFg6GFzEPLwFobnLIw,5047
36
- tests/tests_tensorflow/__init__.py,sha256=VuiDSdOBB2jUeobzAW-XTtHJQrcyZpp8i2DyFu0A2RI,1710
37
- tests/tests_tensorflow/tf_methods_utils.py,sha256=JJqreJaHvhDaGEgw1w7wUhy8YZqYBgkleNC__lYg2Xs,5246
38
- tests/tests_tensorflow/tools_tf.py,sha256=Z_MEzhwCOF7E7uT-tnfRS3Po9hg0MFzbESP7sMAFBkY,3499
39
- tests/tests_torch/__init__.py,sha256=3mVxix2Ecn2wUo9DxvzyJBmyOAkv4fVWHHp6uLQHoic,1738
40
- tests/tests_torch/tools_torch.py,sha256=9ParAhU54AFe4CDYuVo3VGVMypLcMN4Nrk5yz93j-SE,4876
41
- tests/tests_torch/torch_methods_utils.py,sha256=60dkvArDX3mq9gSato2ghp9BCZeXZ_U64WosCjgdGbE,5234
42
- oodeel-0.1.1.dist-info/LICENSE,sha256=XrlZ0uYNVeUAF-iEVX21J3CTJjYPgIZUagYSy3Hf0jk,1265
43
- oodeel-0.1.1.dist-info/METADATA,sha256=WHzwlYu1UvHUd2Qh0REjo-HPGmizFR2fFrzmoROEA9k,17490
44
- oodeel-0.1.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
45
- oodeel-0.1.1.dist-info/top_level.txt,sha256=zkYRty1FGJ1dkpk-5MU_4uFfBFmcxXoqSwej73xELDs,13
46
- oodeel-0.1.1.dist-info/RECORD,,