oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import matplotlib.pyplot as plt
|
|
24
|
+
import numpy as np
|
|
25
|
+
import seaborn as sns
|
|
26
|
+
|
|
27
|
+
from oodeel.eval.metrics import bench_metrics
|
|
28
|
+
from oodeel.eval.metrics import get_curve
|
|
29
|
+
|
|
30
|
+
sns.set_style("darkgrid")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def plot_ood_scores(
    scores_in: np.ndarray,
    scores_out: np.ndarray,
    log_scale: bool = False,
    title: str = None,
):
    """Draw overlaid density histograms of ID and OOD detection scores.

    Both score sets are rendered with seaborn (100 bins, KDE overlay). A red dashed
    vertical line, labelled "TPR=95%", is added at the 5th percentile of the OOD
    scores. No axes object is passed explicitly, so drawing targets the default
    axes picked by seaborn / pyplot.

    Args:
        scores_in (np.ndarray): OOD detection scores for ID data.
        scores_out (np.ndarray): OOD detection scores for OOD data.
        log_scale (bool, optional): If True, apply a log scale on x axis. Defaults to
            False.
        title (str, optional): Custom figure title. If None (or empty) a default one
            is provided. Defaults to None.
    """
    if not title:
        title = "Histograms of OOD detection scores"

    # options shared by the two histograms
    shared_style = dict(
        alpha=0.5,
        stat="density",
        log_scale=log_scale,
        bins=100,
        kde=True,
    )
    axes = [
        sns.histplot(data=scores, label=legend, **shared_style)
        for scores, legend in ((scores_in, "ID data"), (scores_out, "OOD data"))
    ]

    # the y-limits are read only once both histograms have been drawn
    upper = max(ax.get_ylim()[1] for ax in axes)

    # 5% of the OOD scores lie below this value
    tpr95_threshold = np.percentile(scores_out, q=5.0)
    plt.vlines(
        x=[tpr95_threshold],
        ymin=0,
        ymax=upper,
        colors=["red"],
        linestyles=["dashed"],
        alpha=0.7,
        label="TPR=95%",
    )

    plt.xlabel("OOD score")
    plt.legend()
    heading = plt.title(title, weight="bold")
    heading.set_fontsize(11)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def plot_roc_curve(scores_in: np.ndarray, scores_out: np.ndarray, title: str = None):
    """Plot ROC curve for OOD detection task, using matplotlib.

    The AUROC and FPR@95%TPR values are obtained from `bench_metrics`, the curve
    itself from `get_curve`. The area under the curve is shaded and the
    (FPR@95%TPR, 0.95) operating point is highlighted in red.

    Args:
        scores_in (np.ndarray): OOD detection scores for ID data.
        scores_out (np.ndarray): OOD detection scores for OOD data.
        title (str, optional): Custom figure title. If None (or empty) a default one
            showing the AUROC is provided. Defaults to None.
    """
    # summary metrics
    metrics = bench_metrics(
        (scores_in, scores_out),
        metrics=["auroc", "fpr95tpr"],
    )
    auroc, fpr95tpr = metrics["auroc"], metrics["fpr95tpr"]

    # Ground-truth labels: 0 for ID samples, 1 for OOD samples. They are built
    # independently of the score values (rather than as `scores * 0 + k`), so that
    # an infinite or NaN score cannot turn its own label into NaN.
    labels = np.concatenate([np.zeros_like(scores_in), np.ones_like(scores_out)])

    # roc
    fpr, tpr, _, _, _ = get_curve(
        scores=np.concatenate([scores_in, scores_out]),
        labels=labels,
    )

    # plot roc
    title = title or f"ROC curve (AuC = {auroc:.3f})"
    plt.plot(fpr, tpr)
    plt.fill_between(fpr, tpr, np.zeros_like(tpr), alpha=0.5)
    # dashed guide lines from both axes to the TPR=95% operating point
    plt.plot([fpr95tpr, fpr95tpr, 0], [0, 0.95, 0.95], "--", color="red", alpha=0.7)
    plt.scatter([fpr95tpr], [0.95], marker="o", alpha=0.7, color="red", label="TPR=95%")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    # small margins so that points lying on the borders stay visible
    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.legend()
    plt.title(title, weight="bold").set_fontsize(11)
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import numpy as np
|
|
24
|
+
import pandas as pd
|
|
25
|
+
import plotly.express as px
|
|
26
|
+
import seaborn as sns
|
|
27
|
+
from sklearn.decomposition import PCA
|
|
28
|
+
from sklearn.manifold import TSNE
|
|
29
|
+
|
|
30
|
+
from ...types import Callable
|
|
31
|
+
from ...types import DatasetType
|
|
32
|
+
from ...types import Union
|
|
33
|
+
from ...utils import import_backend_specific_stuff
|
|
34
|
+
|
|
35
|
+
sns.set_style("darkgrid")
|
|
36
|
+
|
|
37
|
+
# Registry of the available projection methods, keyed by the `proj_method` argument.
# Each entry holds:
#   - "name": label used in the default figure title,
#   - "class": projector class, instantiated with `n_components` and the kwargs,
#   - "default_kwargs": keyword arguments used unless overridden by the caller.
# NOTE(review): recent scikit-learn releases renamed TSNE's `n_iter` to `max_iter`
# -- confirm against the scikit-learn version this package is pinned to.
PROJ_DICT = {
    "TSNE": {
        "name": "t-SNE",
        "class": TSNE,
        "default_kwargs": {"perplexity": 30.0, "n_iter": 800, "random_state": 0},
    },
    "PCA": {
        "name": "PCA",
        "class": PCA,
        "default_kwargs": {},
    },
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def plotly_3D_features(
    model: Callable,
    in_dataset: DatasetType,
    output_layer_id: Union[int, str],
    out_dataset: DatasetType = None,
    proj_method: str = "TSNE",
    max_samples: int = 4000,
    title: str = None,
    **proj_kwargs,
):
    """Visualize ID and OOD features of a model on a 3D space using dimensionality
    reduction methods and the plotly `scatter_3d` function. Different projection
    methods are available: TSNE, PCA. This function requires the package plotly to be
    installed to run an interactive 3D scatter plot.

    The figure is displayed with `fig.show()`; nothing is returned.

    Args:
        model (Callable): Torch or Keras model.
        in_dataset (DatasetType): In-distribution dataset (torch dataloader or tf
            dataset) that will be projected on the model feature space. It is
            iterated as (inputs, labels) pairs to collect the class labels.
        output_layer_id (Union[int, str]): Identifier for the layer to inspect.
        out_dataset (DatasetType, optional): Out-of-distribution dataset (torch
            dataloader or tf dataset) that will be projected on the model feature space
            if not equal to None. Defaults to None.
        proj_method (str, optional): Projection method for 3d dimensionality reduction.
            Must be a key of PROJ_DICT. Defaults to "TSNE", alternative: "PCA".
        max_samples (int, optional): Max samples to display on the scatter plot. When
            an OOD dataset is given, half of this budget goes to each dataset.
            Defaults to 4000.
        title (str, optional): Custom figure title. Defaults to None.
        **proj_kwargs: Extra keyword arguments for the projector class; they override
            the defaults registered in PROJ_DICT for this call only.
    """
    # split the sample budget evenly between ID and OOD data when both are plotted
    max_samples = max_samples if out_dataset is None else max_samples // 2

    # feature extractor
    _, _, op, FeatureExtractorClass = import_backend_specific_stuff(model)
    feature_extractor = FeatureExtractorClass(model, [output_layer_id])

    # === extract id features ===
    # features
    in_features, _ = feature_extractor.predict(in_dataset)
    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]

    # labels
    in_labels = [op.convert_to_numpy(batch_y) for _, batch_y in in_dataset]
    in_labels = np.concatenate(in_labels)[:max_samples]
    in_labels = [f"class {x}" for x in in_labels]

    # === extract ood features ===
    if out_dataset is not None:
        # features
        out_features, _ = feature_extractor.predict(out_dataset)
        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]

        # labels
        out_labels = np.array(["unknown"] * len(out_features))

        # concatenate id and ood items
        features = np.concatenate([out_features, in_features])
        labels = np.concatenate([out_labels, in_labels])
        data_type = np.array(["OOD"] * len(out_labels) + ["ID"] * len(in_labels))
        points_size = np.array([1] * len(out_labels) + [3] * len(in_labels))
    else:
        features = in_features
        labels = in_labels
        data_type = np.array(["ID"] * len(in_labels))
        points_size = np.array([3] * len(in_labels))

    # === project on 3d space using tsne or pca ===
    proj_spec = PROJ_DICT[proj_method]
    # Merge into a fresh dict: updating the registered defaults in place would leak
    # this call's overrides into every later call.
    p_kwargs = {**proj_spec["default_kwargs"], **proj_kwargs}
    projector = proj_spec["class"](
        n_components=3,
        **p_kwargs,
    )
    features_proj = projector.fit_transform(features)

    # === plot 3d features ===
    features_dim = features.shape[1]
    method_str = proj_spec["name"]
    title = (
        title
        or f"{method_str} 3D projection\n"
        + f"[layer {output_layer_id}, dim: {features_dim}]"
    )

    x, y, z = features_proj.T
    df = pd.DataFrame(
        {
            "dim 1": x,
            "dim 2": y,
            "dim 3": z,
            "class": labels,
            "data type": data_type,
            "size": points_size,
        }
    )

    # 3D projection
    fig = px.scatter_3d(
        data_frame=df,
        x="dim 1",
        y="dim 2",
        z="dim 3",
        color="class",
        symbol="data type",
        size="size",
        opacity=1,
        category_orders={"class": np.unique(df["class"])},
        symbol_map={"OOD": "circle", "ID": "diamond"},
    )

    fig.update_layout(
        title={"text": title, "y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"}
    )
    fig.show()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
import tensorflow as tf
|
|
26
|
+
from .keras_feature_extractor import KerasFeatureExtractor
|
|
27
|
+
except ImportError:
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
import torch
|
|
32
|
+
from .torch_feature_extractor import TorchFeatureExtractor
|
|
33
|
+
from .hf_torch_feature_extractor import HFTorchFeatureExtractor
|
|
34
|
+
except ImportError:
|
|
35
|
+
pass
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
from abc import ABC
|
|
24
|
+
from abc import abstractmethod
|
|
25
|
+
|
|
26
|
+
from ..types import Callable
|
|
27
|
+
from ..types import DatasetType
|
|
28
|
+
from ..types import ItemType
|
|
29
|
+
from ..types import List
|
|
30
|
+
from ..types import Optional
|
|
31
|
+
from ..types import TensorType
|
|
32
|
+
from ..types import Tuple
|
|
33
|
+
from ..types import Union
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class FeatureExtractor(ABC):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers, or the output of the model
    (softmax/logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            A value that is not a list is wrapped into a one-element list.
            Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            Defaults to -1
        input_layer_id: input layer of the feature extractor (to avoid useless forwards
            when working on the feature space without finetuning the bottom of the
            model).
            The signature default is [0].
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023.
            Defaults to None.
        return_penultimate (bool): if True, the penultimate values are returned,
            i.e. the input to the head_layer. Defaults to False.
    """

    def __init__(
        self,
        model: Callable,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = [0],
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        if not isinstance(feature_layers_id, list):
            feature_layers_id = [feature_layers_id]

        # Lists are copied before being stored: the signature defaults are single
        # list objects created at definition time, so keeping a direct reference
        # would share (and let any instance corrupt) them across all instances.
        self.feature_layers_id = list(feature_layers_id)
        self.head_layer_id = head_layer_id
        self.input_layer_id = (
            list(input_layer_id) if isinstance(input_layer_id, list) else input_layer_id
        )
        self.react_threshold = react_threshold
        self.scale_percentile = scale_percentile
        self.ash_percentile = ash_percentile
        self.return_penultimate = return_penultimate
        self.model = model
        # must stay last: prepare_extractor() may read any attribute set above
        self.extractor = self.prepare_extractor()

    @abstractmethod
    def prepare_extractor(self) -> None:
        """
        prepare FeatureExtractor for feature extraction
        (the way to achieve this depends on the underlying library).
        The value returned by the implementation is stored in self.extractor.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_weights(self, layer_id: Union[str, int]) -> List[TensorType]:
        """
        Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            weights matrix
        """
        raise NotImplementedError()

    @abstractmethod
    def predict_tensor(
        self,
        tensor: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
    ) -> Tuple[List[TensorType], TensorType]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            tensor (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing functions to apply
                to each feature immediately after forward. Default to None.

        Returns:
            Tuple[List[TensorType], TensorType]: features, logits
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(
        self,
        dataset: Union[ItemType, DatasetType],
        postproc_fns: Optional[List[Callable]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> Tuple[List[TensorType], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[ItemType, DatasetType]): input dataset
            postproc_fns (Optional[List[Callable]]): postprocessing functions to apply
                to each feature immediately after forward. Default to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction

        Returns:
            List[TensorType], dict: features and extra information (logits, labels) as a
                dictionary.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def find_layer(
        model: Callable,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[Callable, Tuple[Callable, str]]:
        """Find a layer in a model either by its name or by its index.

        Args:
            model (Callable): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before (negative
                offset) or after (positive offset) the identified layer
            return_id (bool): if True, the layer will be returned with its id

        Returns:
            Union[Callable, Tuple[Callable, str]]: the corresponding layer and its id if
                return_id is True.
        """
        raise NotImplementedError()

    def __call__(self, inputs: TensorType) -> Tuple[List[TensorType], TensorType]:
        """
        Convenience wrapper for predict_tensor().

        Args:
            inputs (TensorType): input tensor

        Returns:
            Tuple[List[TensorType], TensorType]: whatever predict_tensor() returns,
                i.e. features, logits
        """
        return self.predict_tensor(inputs)
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
from typing import Optional
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
from torch import nn
|
|
27
|
+
|
|
28
|
+
from ..types import Callable
|
|
29
|
+
from ..types import List
|
|
30
|
+
from ..types import TensorType
|
|
31
|
+
from ..types import Tuple
|
|
32
|
+
from ..types import Union
|
|
33
|
+
from ..utils.torch_operator import sanitize_input
|
|
34
|
+
from .torch_feature_extractor import TorchFeatureExtractor
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class HFTorchFeatureExtractor(TorchFeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers,
    or the output of the model (logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            Important: for HFTorchFeatureExtractor, we use features from the
            hidden states returned by model(input, output_hidden_states=True) in
            addition to other features computed like in TorchFeatureExtractor.
            To select the hidden states as feature, identify the layer by hidden_i,
            with i the index of the hidden state. A name that starts with "hidden_"
            but is not followed by an integer is handled as a regular layer name.
            Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            We recommend to keep the default value for HFTorchFeatureExtractor unless
            you know what you are doing.
            Defaults to -1
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023.
            Defaults to None.
        return_penultimate (bool): if True, the penultimate values are returned,
            i.e. the input to the head_layer. Defaults to False.

    """

    # prefix marking a feature id as an index into the HF `hidden_states` output
    _HF_HIDDEN_PREFIX = "hidden_"

    def __init__(
        self,
        model: nn.Module,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        # Hand a copy to the parent so that the shared default list (created once
        # at definition time) can never be stored on, or mutated by, an instance.
        if isinstance(feature_layers_id, list):
            feature_layers_id = list(feature_layers_id)

        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            head_layer_id=head_layer_id,
            react_threshold=react_threshold,
            scale_percentile=scale_percentile,
            ash_percentile=ash_percentile,
            return_penultimate=return_penultimate,
        )

        # feature buffers exist only for hooked layers; hidden states are read
        # directly from the model outputs in predict_tensor()
        self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}

    def _parse_hf_hidden_state(self, feature_layer_id) -> Tuple[bool, Union[int, str]]:
        """Parse the feature_layer_id to check if it is a hidden state from HF model.
        If it is, return True and the index of the hidden state.
        If it is not, return False and the feature_layer_id.

        Args:
            feature_layer_id (Union[int, str]): feature layer id to parse.

        Returns:
            Tuple[bool, Union[int, str]]: is_hf_hidden_state, feature_layer_id
        """
        prefix = self._HF_HIDDEN_PREFIX
        if isinstance(feature_layer_id, str) and feature_layer_id.startswith(prefix):
            try:
                return True, int(feature_layer_id[len(prefix) :])
            except ValueError:
                # e.g. a module named "hidden_dropout": a regular layer name
                pass

        return False, feature_layer_id

    @property
    def _hook_layers_id(self) -> List[Union[int, str]]:
        """Get the list of hook layer ids to be used for feature extraction.
        This list excludes hf_hidden_states because their feature extraction is
        already handled by HF transformers in that case.
        """
        return [
            feature_layer_id
            for feature_layer_id in self.feature_layers_id
            if not self._parse_hf_hidden_state(feature_layer_id)[0]
        ]

    @sanitize_input
    def predict_tensor(
        self,
        x: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            x (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
                each feature immediately after forward. Functions are paired with the
                features in order. Default to None.
            detach (bool): if True, return features detached from the computational
                graph. Defaults to True.

        Returns:
            List[torch.Tensor], torch.Tensor: features, logits
        """

        def _maybe_detach(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.detach() if detach else tensor

        if x.device != self._device:
            x = x.to(self._device)
        outputs = self.model(x, output_hidden_states=True, return_dict=True)

        features = []
        for feature_layer_id in self.feature_layers_id:
            is_hf_hidden_state, feature_layer_id = self._parse_hf_hidden_state(
                feature_layer_id
            )
            if is_hf_hidden_state:
                # provided by the forward pass itself
                feature = outputs["hidden_states"][feature_layer_id]
            else:
                # read from the per-layer buffer
                # NOTE(review): presumably filled by the forward hooks registered
                # in TorchFeatureExtractor -- confirm in the parent class
                feature = self._features[feature_layer_id]
            features.append(_maybe_detach(feature))

        logits = _maybe_detach(outputs.logits)

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits
|