oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
@@ -0,0 +1,118 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import matplotlib.pyplot as plt
24
+ import numpy as np
25
+ import seaborn as sns
26
+
27
+ from oodeel.eval.metrics import bench_metrics
28
+ from oodeel.eval.metrics import get_curve
29
+
30
+ sns.set_style("darkgrid")
31
+
32
+
33
def plot_ood_scores(
    scores_in: np.ndarray,
    scores_out: np.ndarray,
    log_scale: bool = False,
    title: str = None,
):
    """Draw overlaid density histograms (with KDE) of OOD detection scores for the ID
    and OOD distributions, using matplotlib and seaborn.

    A dashed red vertical line is drawn at the 5th percentile of the OOD scores,
    i.e. the threshold above which 95% of the OOD samples lie (labelled "TPR=95%").

    Args:
        scores_in (np.ndarray): OOD detection scores for ID data.
        scores_out (np.ndarray): OOD detection scores for OOD data.
        log_scale (bool, optional): If True, apply a log scale on x axis. Defaults to
            False.
        title (str, optional): Custom figure title. If None (or empty) a default one
            is provided. Defaults to None.
    """
    if not title:
        title = "Histograms of OOD detection scores"

    # Both histograms share every styling option; only the data and label differ.
    shared_hist_kwargs = dict(
        alpha=0.5,
        stat="density",
        log_scale=log_scale,
        bins=100,
        kde=True,
    )
    distributions = ((scores_in, "ID data"), (scores_out, "OOD data"))
    axes = [
        sns.histplot(data=scores, label=label, **shared_hist_kwargs)
        for scores, label in distributions
    ]
    top = max(ax.get_ylim()[1] for ax in axes)

    # 95% of the OOD scores are >= the 5th percentile of the OOD score distribution.
    tpr95_threshold = np.percentile(scores_out, q=5.0)
    plt.vlines(
        x=[tpr95_threshold],
        ymin=0,
        ymax=top,
        colors=["red"],
        linestyles=["dashed"],
        alpha=0.7,
        label="TPR=95%",
    )

    plt.xlabel("OOD score")
    plt.legend()
    plt.title(title, weight="bold").set_fontsize(11)
83
+
84
+
85
def plot_roc_curve(scores_in: np.ndarray, scores_out: np.ndarray, title: str = None):
    """Plot the ROC curve of an OOD detection task, using matplotlib.

    ID samples are labelled 0 and OOD samples 1 when building the curve. The point
    where TPR reaches 95% is highlighted in red together with its FPR.

    Args:
        scores_in (np.ndarray): OOD detection scores for ID data.
        scores_out (np.ndarray): OOD detection scores for OOD data.
        title (str, optional): Custom figure title. If None a default one (showing the
            AUROC) is provided. Defaults to None.
    """
    # Accept array-likes (e.g. lists) as well as numpy arrays.
    scores_in = np.asarray(scores_in)
    scores_out = np.asarray(scores_out)

    # compute AUROC and FPR@95TPR
    metrics = bench_metrics(
        (scores_in, scores_out),
        metrics=["auroc", "fpr95tpr"],
    )
    auroc, fpr95tpr = metrics["auroc"], metrics["fpr95tpr"]

    # roc
    # Build labels with zeros_like / ones_like rather than `scores * 0 + k`:
    # an infinite score would otherwise produce a NaN label (inf * 0 == nan).
    # The *_like constructors keep the same shape as the scores.
    labels = np.concatenate(
        [
            np.zeros_like(scores_in, dtype=np.float64),
            np.ones_like(scores_out, dtype=np.float64),
        ]
    )
    fpr, tpr, _, _, _ = get_curve(
        scores=np.concatenate([scores_in, scores_out]),
        labels=labels,
    )

    # plot roc
    title = title or "ROC curve (AuC = {:.3f})".format(auroc)
    plt.plot(fpr, tpr)
    plt.fill_between(fpr, tpr, np.zeros_like(tpr), alpha=0.5)
    # dashed guide lines from the axes to the (FPR@95TPR, 0.95) operating point
    plt.plot([fpr95tpr, fpr95tpr, 0], [0, 0.95, 0.95], "--", color="red", alpha=0.7)
    plt.scatter([fpr95tpr], [0.95], marker="o", alpha=0.7, color="red", label="TPR=95%")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.legend()
    plt.title(title, weight="bold").set_fontsize(11)
@@ -0,0 +1,162 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import numpy as np
24
+ import pandas as pd
25
+ import plotly.express as px
26
+ import seaborn as sns
27
+ from sklearn.decomposition import PCA
28
+ from sklearn.manifold import TSNE
29
+
30
+ from ...types import Callable
31
+ from ...types import DatasetType
32
+ from ...types import Union
33
+ from ...utils import import_backend_specific_stuff
34
+
35
+ sns.set_style("darkgrid")
36
+
37
def _tsne_default_kwargs() -> dict:
    """Build the default t-SNE keyword arguments for the installed scikit-learn.

    scikit-learn 1.5 renamed ``TSNE(n_iter=...)`` to ``max_iter`` and deprecated
    ``n_iter`` (later removed), so the iteration-count keyword is chosen from the
    signature of the installed ``TSNE`` class.
    """
    import inspect

    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_keyword = "max_iter" if "max_iter" in tsne_params else "n_iter"
    return {"perplexity": 30.0, iter_keyword: 800, "random_state": 0}


# Registry of available projection methods: display name, estimator class and the
# default constructor kwargs (callers should copy these before updating them).
PROJ_DICT = {
    "TSNE": {
        "name": "t-SNE",
        "class": TSNE,
        "default_kwargs": _tsne_default_kwargs(),
    },
    "PCA": {"name": "PCA", "class": PCA, "default_kwargs": dict()},
}
45
+
46
+
47
def plotly_3D_features(
    model: Callable,
    in_dataset: DatasetType,
    output_layer_id: Union[int, str],
    out_dataset: DatasetType = None,
    proj_method: str = "TSNE",
    max_samples: int = 4000,
    title: str = None,
    **proj_kwargs,
):
    """Visualize ID and OOD features of a model in a 3D space, using dimensionality
    reduction and an interactive plotly 3D scatter plot. Available projection
    methods: TSNE, PCA. This function requires the package plotly to be installed.

    Args:
        model (Callable): Torch or Keras model.
        in_dataset (DatasetType): In-distribution dataset (torch dataloader or tf
            dataset) that will be projected on the model feature space. It is
            iterated as (inputs, labels) pairs to recover the class labels.
        output_layer_id (Union[int, str]): Identifier for the layer to inspect.
        out_dataset (DatasetType, optional): Out-of-distribution dataset (torch
            dataloader or tf dataset) that will be projected on the model feature space
            if not equal to None. Defaults to None.
        proj_method (str, optional): Projection method for 3D dimensionality reduction.
            Defaults to "TSNE", alternative: "PCA".
        max_samples (int, optional): Max samples to display on the scatter plot. When
            out_dataset is given, half of this budget goes to ID and half to OOD.
            Defaults to 4000.
        title (str, optional): Custom figure title. Defaults to None.
        **proj_kwargs: Extra keyword arguments forwarded to the projection class
            constructor; they override the defaults stored in PROJ_DICT.
    """
    # split the sample budget evenly between ID and OOD data when both are shown
    max_samples = max_samples if out_dataset is None else max_samples // 2

    # feature extractor
    _, _, op, FeatureExtractorClass = import_backend_specific_stuff(model)
    feature_extractor = FeatureExtractorClass(model, [output_layer_id])

    # === extract id features ===
    # features
    in_features, _ = feature_extractor.predict(in_dataset)
    in_features = op.convert_to_numpy(op.flatten(in_features[0]))[:max_samples]

    # labels (second element of each batch of in_dataset)
    in_labels = []
    for _, batch_y in in_dataset:
        in_labels.append(op.convert_to_numpy(batch_y))
    in_labels = np.concatenate(in_labels)[:max_samples]
    in_labels = [f"class {label}" for label in in_labels]

    # === extract ood features ===
    if out_dataset is not None:
        # features
        out_features, _ = feature_extractor.predict(out_dataset)
        out_features = op.convert_to_numpy(op.flatten(out_features[0]))[:max_samples]

        # labels
        out_labels = np.array(["unknown"] * len(out_features))

        # concatenate id and ood items
        features = np.concatenate([out_features, in_features])
        labels = np.concatenate([out_labels, in_labels])
        data_type = np.array(["OOD"] * len(out_labels) + ["ID"] * len(in_labels))
        points_size = np.array([1] * len(out_labels) + [3] * len(in_labels))
    else:
        features = in_features
        labels = in_labels
        data_type = np.array(["ID"] * len(in_labels))
        points_size = np.array([3] * len(in_labels))

    # === project on 3d space using tsne or pca ===
    proj_class = PROJ_DICT[proj_method]["class"]
    # Copy the defaults: updating PROJ_DICT's dict in place would leak this call's
    # proj_kwargs into every later call.
    p_kwargs = dict(PROJ_DICT[proj_method]["default_kwargs"])
    p_kwargs.update(proj_kwargs)
    projector = proj_class(
        n_components=3,
        **p_kwargs,
    )
    features_proj = projector.fit_transform(features)

    # === plot 3d features ===
    features_dim = features.shape[1]
    method_str = PROJ_DICT[proj_method]["name"]
    # plotly titles are HTML-like: line breaks need "<br>", "\n" is not rendered
    title = title or (
        f"{method_str} 3D projection<br>"
        + f"[layer {output_layer_id}, dim: {features_dim}]"
    )

    x, y, z = features_proj.T
    df = pd.DataFrame(
        {
            "dim 1": x,
            "dim 2": y,
            "dim 3": z,
            "class": labels,
            "data type": data_type,
            "size": points_size,
        }
    )

    # 3D projection
    fig = px.scatter_3d(
        data_frame=df,
        x="dim 1",
        y="dim 2",
        z="dim 3",
        color="class",
        symbol="data type",
        size="size",
        opacity=1,
        category_orders={"class": np.unique(df["class"]).tolist()},
        symbol_map={"OOD": "circle", "ID": "diamond"},
    )

    fig.update_layout(
        title={"text": title, "y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"}
    )
    fig.show()
@@ -0,0 +1,35 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ try:
25
+ import tensorflow as tf
26
+ from .keras_feature_extractor import KerasFeatureExtractor
27
+ except ImportError:
28
+ pass
29
+
30
+ try:
31
+ import torch
32
+ from .torch_feature_extractor import TorchFeatureExtractor
33
+ from .hf_torch_feature_extractor import HFTorchFeatureExtractor
34
+ except ImportError:
35
+ pass
@@ -0,0 +1,187 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from abc import ABC
24
+ from abc import abstractmethod
25
+
26
+ from ..types import Callable
27
+ from ..types import DatasetType
28
+ from ..types import ItemType
29
+ from ..types import List
30
+ from ..types import Optional
31
+ from ..types import TensorType
32
+ from ..types import Tuple
33
+ from ..types import Union
34
+
35
+
36
class FeatureExtractor(ABC):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers, or the output of the model
    (softmax/logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            A single identifier (or a tuple of identifiers) is converted to a list.
            Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            Defaults to -1
        input_layer_id: input layer of the feature extractor (to avoid useless forwards
            when working on the feature space without finetuning the bottom of the
            model).
            Defaults to [0].
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023. Defaults to None.
        return_penultimate (bool): if True, the penultimate values are returned,
            i.e. the input to the head_layer. Defaults to False.
    """

    def __init__(
        self,
        model: Callable,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str, List[Union[int, str]]]] = [0],
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        # The list defaults above are kept for interface compatibility, but they are
        # copied here so that no instance aliases the shared default objects (or a
        # list owned by the caller) — mutating one instance's ids must not affect
        # other instances.
        if isinstance(feature_layers_id, (list, tuple)):
            feature_layers_id = list(feature_layers_id)
        else:
            feature_layers_id = [feature_layers_id]
        if isinstance(input_layer_id, list):
            input_layer_id = list(input_layer_id)

        self.feature_layers_id = feature_layers_id
        self.head_layer_id = head_layer_id
        self.input_layer_id = input_layer_id
        self.react_threshold = react_threshold
        self.scale_percentile = scale_percentile
        self.ash_percentile = ash_percentile
        self.return_penultimate = return_penultimate
        self.model = model
        # the object returned by the backend-specific prepare_extractor is kept here
        self.extractor = self.prepare_extractor()

    @abstractmethod
    def prepare_extractor(self) -> None:
        """
        Prepare FeatureExtractor for feature extraction
        (the way to achieve this depends on the underlying library).

        Whatever this method returns is stored in `self.extractor` by `__init__`.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_weights(self, layer_id: Union[str, int]) -> List[TensorType]:
        """
        Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            weights matrix
        """
        raise NotImplementedError()

    @abstractmethod
    def predict_tensor(
        self,
        tensor: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
    ) -> Tuple[List[TensorType], TensorType]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            tensor (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply
                to each feature immediately after forward. Default to None.

        Returns:
            Tuple[List[TensorType], TensorType]: features, logits
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(
        self,
        dataset: Union[ItemType, DatasetType],
        postproc_fns: Optional[List[Callable]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> Tuple[List[TensorType], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[ItemType, DatasetType]): input dataset
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply
                to each feature immediately after forward. Default to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction

        Returns:
            List[TensorType], dict: features and extra information (logits, labels) as a
                dictionary.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def find_layer(
        model: Callable,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[Callable, Tuple[Callable, str]]:
        """Find a layer in a model either by its name or by its index.

        Args:
            model (Callable): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before (negative
                offset) or after (positive offset) the identified layer
            return_id (bool): if True, the layer will be returned with its id

        Returns:
            Union[Callable, Tuple[Callable, str]]: the corresponding layer and its id if
                return_id is True.
        """
        raise NotImplementedError()

    def __call__(self, inputs: TensorType) -> Tuple[List[TensorType], TensorType]:
        """
        Convenience wrapper for predict_tensor().

        Args:
            inputs (TensorType): input tensor

        Returns:
            Tuple[List[TensorType], TensorType]: features, logits (whatever
                predict_tensor returns)
        """
        return self.predict_tensor(inputs)
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from typing import Optional
24
+
25
+ import torch
26
+ from torch import nn
27
+
28
+ from ..types import Callable
29
+ from ..types import List
30
+ from ..types import TensorType
31
+ from ..types import Tuple
32
+ from ..types import Union
33
+ from ..utils.torch_operator import sanitize_input
34
+ from .torch_feature_extractor import TorchFeatureExtractor
35
+
36
+
37
class HFTorchFeatureExtractor(TorchFeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers,
    or the output of the model (logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            Important: for HFTorchFeatureExtractor, we use features from the
            hidden states returned by model(input, output_hidden_states=True) in
            addition to other features computed like in TorchFeatureExtractor.
            To select the hidden states as feature, identify the layer by hidden_i,
            with i the index of the hidden state.
            Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list
            If str, the name of the layer.
            We recommend to keep the default value for HFTorchFeatureExtractor unless
            you know what you are doing.
            Defaults to -1
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023. Defaults to None.
        return_penultimate (bool): if True, the penultimate values are returned,
            i.e. the input to the head_layer. Defaults to False.

    """

    def __init__(
        self,
        model: nn.Module,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        # Copy the list so the instance never aliases the shared default list or a
        # list owned by the caller.
        if isinstance(feature_layers_id, list):
            feature_layers_id = list(feature_layers_id)

        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            head_layer_id=head_layer_id,
            react_threshold=react_threshold,
            scale_percentile=scale_percentile,
            ash_percentile=ash_percentile,
            return_penultimate=return_penultimate,
        )

        # Feature buffers are only kept for hooked layers: hidden-state features are
        # read from the HF model outputs in predict_tensor instead.
        self._features = {layer: torch.empty(0) for layer in self._hook_layers_id}

    @staticmethod
    def _is_hf_hidden_state_id(feature_layer_id) -> bool:
        """Return True if feature_layer_id refers to a HF hidden state ("hidden_<i>").

        Args:
            feature_layer_id (Union[int, str]): feature layer id to check.

        Returns:
            bool: True for strings starting with "hidden_", False otherwise.
        """
        return isinstance(feature_layer_id, str) and feature_layer_id.startswith(
            "hidden_"
        )

    def _parse_hf_hidden_state(self, feature_layer_id) -> Tuple[bool, Union[int, str]]:
        """Parse the feature_layer_id to check if it is a hidden state from HF model.
        If it is, return True and the index of the hidden state.
        If it is not, return False and the feature_layer_id.

        Args:
            feature_layer_id (Union[int, str]): feature layer id to parse.

        Returns:
            Tuple[bool, Union[int, str]]: is_hf_hidden_state, feature_layer_id

        Raises:
            ValueError: if the id starts with "hidden_" but the rest is not an int.
        """
        if self._is_hf_hidden_state_id(feature_layer_id):
            return True, int(feature_layer_id[len("hidden_") :])

        return False, feature_layer_id

    @property
    def _hook_layers_id(self) -> List[Union[int, str]]:
        """Get the list of hook layer ids to be used for feature extraction.
        This list excludes HF hidden states because their feature extraction is
        already handled by HF transformers in that case.
        """
        return [
            feature_layer_id
            for feature_layer_id in self.feature_layers_id
            if not self._is_hf_hidden_state_id(feature_layer_id)
        ]

    @sanitize_input
    def predict_tensor(
        self,
        x: TensorType,
        postproc_fns: Optional[List[Callable]] = None,
        detach: bool = True,
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            x (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[Callable]]): postprocessing function to apply to
                each feature immediately after forward. Default to None.
            detach (bool): if True, return features detached from the computational
                graph. Defaults to True.

        Returns:
            List[torch.Tensor], torch.Tensor: features, logits
        """
        if x.device != self._device:
            x = x.to(self._device)
        # hidden states are requested so that "hidden_<i>" features can be read below
        outputs = self.model(x, output_hidden_states=True, return_dict=True)

        features = []
        for feature_layer_id in self.feature_layers_id:
            is_hf_hidden_state, feature_layer_id = self._parse_hf_hidden_state(
                feature_layer_id
            )
            if is_hf_hidden_state:
                feature = outputs["hidden_states"][feature_layer_id]
            else:
                # filled during the forward pass above (hooked layers)
                feature = self._features[feature_layer_id]
            features.append(feature.detach() if detach else feature)

        logits = outputs.logits.detach() if detach else outputs.logits

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits