oodeel-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
oodeel/extractor/keras_feature_extractor.py
@@ -0,0 +1,409 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import get_args
from typing import Optional

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm

from ..datasets.tf_data_handler import TFDataHandler
from ..types import ItemType
from ..types import List
from ..types import TensorType
from ..types import Tuple
from ..types import Union
from ..utils.tf_operator import sanitize_input
from .feature_extractor import FeatureExtractor


class KerasFeatureExtractor(FeatureExtractor):
    """
    Feature extractor based on "model" to construct a feature space
    on which OOD detection is performed. The features can be the output
    activation values of internal model layers, or the output of the model
    (logits).

    Args:
        model: model to extract the features from
        feature_layers_id: list of str or int that identify features to output.
            If int, the rank of the layer in the layer list.
            If str, the name of the layer. Defaults to [].
        head_layer_id (int, str): identifier of the head layer.
            If int, the rank of the layer in the layer list.
            If str, the name of the layer.
            Defaults to -1.
        input_layer_id: input layer of the feature extractor (to avoid useless forwards
            when working on the feature space without finetuning the bottom of the
            model).
            Defaults to None.
        react_threshold: if not None, penultimate layer activations are clipped under
            this threshold value (useful for ReAct). Defaults to None.
        scale_percentile: if not None, the features are scaled
            following the method of Xu et al., ICLR 2024.
            Defaults to None.
        ash_percentile: if not None, the features are scaled following
            the method of Djurisic et al., ICLR 2023.
        return_penultimate (bool): if True, the penultimate values are returned,
            i.e. the input to the head_layer.
    """

    def __init__(
        self,
        model: tf.keras.Model,
        feature_layers_id: List[Union[int, str]] = [],
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        react_threshold: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        return_penultimate: Optional[bool] = False,
    ):
        if input_layer_id is None:
            input_layer_id = 0

        if return_penultimate:
            if isinstance(head_layer_id, str):
                head_layer_id = self.get_layer_index_by_name(model, head_layer_id)

            feature_layers_id.append(head_layer_id - 1)

        super().__init__(
            model=model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            head_layer_id=head_layer_id,
            react_threshold=react_threshold,
            scale_percentile=scale_percentile,
            ash_percentile=ash_percentile,
            return_penultimate=return_penultimate,
        )

        self.backend = "tensorflow"
        self.model.layers[-1].activation = getattr(tf.keras.activations, "linear")
        self._last_logits = None

    @staticmethod
    def find_layer(
        model: tf.keras.Model,
        layer_id: Union[str, int],
        index_offset: int = 0,
        return_id: bool = False,
    ) -> Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
        """Find a layer in a model either by its name or by its index.

        Args:
            model (tf.keras.Model): model whose identified layer will be returned
            layer_id (Union[str, int]): layer identifier
            index_offset (int): index offset to find layers located before (negative
                offset) or after (positive offset) the identified layer
            return_id (bool): if True, the layer will be returned with its id

        Raises:
            ValueError: if the layer is not found

        Returns:
            Union[tf.keras.layers.Layer, Tuple[tf.keras.layers.Layer, str]]:
                the corresponding layer and its id if return_id is True.
        """
        if isinstance(layer_id, str):
            layers_names = [layer.name for layer in model.layers]
            layer_id = layers_names.index(layer_id)
        if isinstance(layer_id, int):
            layer_id += index_offset
            layer = model.get_layer(index=layer_id)
        else:
            raise ValueError(f"Could not find any layer {layer_id}.")

        if return_id:
            return layer, layer_id
        else:
            return layer

    @staticmethod
    def get_layer_index_by_name(model: tf.keras.Model, layer_id: str) -> int:
        """Get the index of a layer by its name.

        Args:
            model (tf.keras.Model): The model containing the layers.
            layer_id (str): The name of the layer.

        Returns:
            int: The index of the layer.

        Raises:
            ValueError: If the layer with the given name is not found.
        """
        layers_names = [layer.name for layer in model.layers]
        if layer_id not in layers_names:
            raise ValueError(f"Layer with name '{layer_id}' not found in the model.")
        return layers_names.index(layer_id)

    # @tf.function
    # TODO check with Thomas about @tf.function
    def prepare_extractor(self) -> tf.keras.models.Model:
        """Constructs the feature extractor model

        Returns:
            tf.keras.models.Model: truncated model (extractor)
        """
        input_layer = self.find_layer(self.model, self.input_layer_id)
        new_input = tf.keras.layers.Input(tensor=input_layer.input)
        output_tensors = [
            self.find_layer(self.model, id).output for id in self.feature_layers_id
        ]

        # === If react method, clip activations from penultimate layer ===
        if self.react_threshold is not None:
            penultimate_layer = self.find_layer(
                self.model, self.head_layer_id, index_offset=-1
            )
            penult_extractor = tf.keras.models.Model(
                new_input, penultimate_layer.output
            )
            last_layer = self.find_layer(self.model, self.head_layer_id)

            # clip penultimate activations
            x = tf.clip_by_value(
                penult_extractor(new_input),
                clip_value_min=tf.float32.min,
                clip_value_max=self.react_threshold,
            )
            # apply ultimate layer on clipped activations
            output_tensors.append(last_layer(x))

        # === If SCALE method, scale activations from penultimate layer ===
        # === If ASH method, scale and prune activations from penultimate layer ===
        elif (self.scale_percentile is not None) or (self.ash_percentile is not None):
            penultimate_layer = self.find_layer(
                self.model, self.head_layer_id, index_offset=-1
            )
            penult_extractor = tf.keras.models.Model(
                new_input, penultimate_layer.output
            )
            last_layer = self.find_layer(self.model, self.head_layer_id)

            # apply scaling on penultimate activations
            penultimate = penult_extractor(new_input)
            if self.scale_percentile is not None:
                output_percentile = tfp.stats.percentile(
                    penultimate, 100 * self.scale_percentile, axis=1
                )
            else:
                output_percentile = tfp.stats.percentile(
                    penultimate, 100 * self.ash_percentile, axis=1
                )

            mask = penultimate > tf.reshape(output_percentile, (-1, 1))
            filtered_penultimate = tf.where(
                mask, penultimate, tf.zeros_like(penultimate)
            )
            s = tf.math.exp(
                tf.reduce_sum(penultimate, axis=1)
                / tf.reduce_sum(filtered_penultimate, axis=1)
            )

            if self.scale_percentile is not None:
                x = penultimate * tf.expand_dims(s, 1)
            else:
                x = filtered_penultimate * tf.expand_dims(s, 1)
            # apply ultimate layer on scaled activations
            output_tensors.append(last_layer(x))

        else:
            output_tensors.append(
                self.find_layer(self.model, self.head_layer_id).output
            )

        extractor = tf.keras.models.Model(new_input, output_tensors)
        return extractor

    @sanitize_input
    def predict_tensor(
        self,
        tensor: TensorType,
        postproc_fns: Optional[List[tf.keras.Model]] = None,
    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
        """Get the projection of tensor in the feature space of self.model

        Args:
            tensor (TensorType): input tensor (or dataset elem)
            postproc_fns (Optional[List[tf.keras.Model]]): postprocessing functions
                to apply to each feature immediately after forward. Defaults to None.

        Returns:
            Tuple[List[tf.Tensor], tf.Tensor]: features, logits
        """
        features = self.forward(tensor)

        if type(features) is not list:
            features = [features]

        # split features and logits
        logits = features.pop()

        if postproc_fns is not None:
            features = [
                postproc_fn(feature)
                for feature, postproc_fn in zip(features, postproc_fns)
            ]

        self._last_logits = logits
        return features, logits

    @tf.function
    def forward(self, tensor: TensorType) -> List[tf.Tensor]:
        return self.extractor(tensor, training=False)

    def predict(
        self,
        dataset: Union[ItemType, tf.data.Dataset],
        postproc_fns: Optional[List[tf.keras.Model]] = None,
        verbose: bool = False,
        numpy_concat: bool = False,
        **kwargs,
    ) -> Tuple[List[tf.Tensor], dict]:
        """Get the projection of the dataset in the feature space of self.model

        Args:
            dataset (Union[ItemType, tf.data.Dataset]): input dataset
            postproc_fns (Optional[List[tf.keras.Model]]): postprocessing functions
                to apply to each feature immediately after forward. Defaults to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            numpy_concat (bool): if True, each mini-batch is immediately moved
                to CPU and converted to a NumPy array before concatenation.
                That keeps GPU memory constant at one batch, at the cost of a small
                host-device transfer overhead. Defaults to False.
            kwargs (dict): additional arguments not considered for prediction

        Returns:
            List[tf.Tensor], dict: features and extra information (logits, labels) as a
                dictionary.
        """
        labels = None

        if isinstance(dataset, get_args(ItemType)):
            tensor = TFDataHandler.get_input_from_dataset_item(dataset)
            features, logits = self.predict_tensor(tensor, postproc_fns)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)):
                labels = TFDataHandler.get_label_from_dataset_item(dataset)
        else:
            # Determine if dataset yields labels
            contains_labels = TFDataHandler.get_item_length(dataset) > 1

            # Buffers for accumulation
            features_per_layer = [[] for _ in self.feature_layers_id]
            logits_list = []
            labels_list = [] if contains_labels else None

            # Iterate through dataset
            for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
                tensor = TFDataHandler.get_input_from_dataset_item(elem)
                feats_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)

                # To host/NumPy if requested
                if numpy_concat:
                    feats_batch = [f.numpy() for f in feats_batch]
                    logits_batch = (
                        logits_batch.numpy() if logits_batch is not None else None
                    )

                # Accumulate per layer
                for i, f in enumerate(feats_batch):
                    features_per_layer[i].append(f)
                if logits_batch is not None:
                    logits_list.append(logits_batch)
                if contains_labels:
                    lbl = TFDataHandler.get_label_from_dataset_item(elem)
                    labels_list.append(lbl)

            # Concatenate
            labels = tf.concat(labels_list, axis=0) if labels_list is not None else None

            if numpy_concat:
                features = [np.concatenate(lst, axis=0) for lst in features_per_layer]
                logits = np.concatenate(logits_list, axis=0) if logits_list else None
                labels = labels.numpy() if labels is not None else None
            else:
                features = [tf.concat(lst, axis=0) for lst in features_per_layer]
                logits = tf.concat(logits_list, axis=0) if logits_list else None

        info = {"labels": labels, "logits": logits}
        return features, info

    def get_weights(self, layer_id: Union[int, str]) -> List[tf.Tensor]:
        """Get the weights of a layer

        Args:
            layer_id (Union[int, str]): layer identifier

        Returns:
            List[tf.Tensor]: weights and biases matrices
        """
        return self.find_layer(self.model, layer_id).get_weights()

    def _default_postproc_fn(self, feat: tf.Tensor) -> tf.Tensor:
        """
        Default postprocessing function to apply to each feature immediately after
        forward pass.

        This function applies global average pooling if the input tensor has rank 4
        (e.g., [batch, height, width, channels]) or rank 3
        (e.g., [batch, sequence_length, features]). If the tensor is already 2D, it is
        returned as is.

        Args:
            feat (tf.Tensor): Input tensor.

        Returns:
            tf.Tensor: Postprocessed tensor where spatial (or temporal) dimensions have
                been averaged out.

        Raises:
            NotImplementedError: If the input tensor has a rank other than 2, 3, or 4.
        """
        tensor_rank = len(feat.shape)

        if tensor_rank == 4:
            # Assumes input is in channels_last format: [batch, height, width, channels]
            # Applies global average pooling over height and width.
            pooled = tf.keras.layers.GlobalAveragePooling2D()(feat)
            # The resulting tensor has shape [batch, channels]
            return pooled
        elif tensor_rank == 3:
            # Assumes input is in channels_last format: [batch, seq_length, features]
            # Applies global average pooling over the sequence_length dimension.
            pooled = tf.keras.layers.GlobalAveragePooling1D()(feat)
            # The resulting tensor has shape [batch, features]
            return pooled
        elif tensor_rank == 2:
            # If the tensor is already 2D, no further processing is needed.
            return feat
        else:
            raise NotImplementedError(
                "Postprocessing function not implemented for tensors with"
                + " rank {}.".format(tensor_rank)
            )