oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
oodeel/methods/base.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import inspect
24
+ from abc import ABC
25
+ from typing import Dict
26
+ from typing import get_args
27
+
28
+ import numpy as np
29
+ from tqdm import tqdm
30
+
31
+ from ..aggregator import BaseAggregator
32
+ from ..aggregator import StdNormalizedAggregator
33
+ from ..extractor.feature_extractor import FeatureExtractor
34
+ from ..types import Callable
35
+ from ..types import DatasetType
36
+ from ..types import ItemType
37
+ from ..types import List
38
+ from ..types import Optional
39
+ from ..types import TensorType
40
+ from ..types import Union
41
+ from ..utils import import_backend_specific_stuff
42
+
43
+
44
class OODBaseDetector(ABC):
    """OODBaseDetector is an abstract base class for Out-of-Distribution (OOD)
    detection.

    Attributes:
        feature_extractor (FeatureExtractor): The feature extractor instance.
        use_react (bool): Flag to indicate if ReAct method is used.
        use_scale (bool): Flag to indicate if scaling method is used.
        use_ash (bool): Flag to indicate if ASH method is used.
        react_quantile (float): Quantile value for ReAct threshold.
        scale_percentile (float): Percentile value for scaling.
        ash_percentile (float): Percentile value for ASH.
        eps (float): Perturbation noise for input perturbation.
        temperature (float): Temperature parameter for input perturbation.
        react_threshold (Optional[float]): Threshold for ReAct clipping.

    Public Methods:
        - fit(): Prepare the detector by setting up feature extraction and calibrating.
        - score(): Compute OOD scores on input data (batched or single item).
        - __call__(): Shorthand for score().

    Internal Methods (for subclassing or advanced usage):
        - _fit_to_dataset(): Optional calibration routine on a dataset.
        - _score_tensor(): Compute scores for a single batch of input data.
        - _load_feature_extractor(): Initialize feature extraction pipeline.
        - _sanitize_posproc_fns(): Normalize post-processing function list.
        - _compute_react_threshold(): Calibrate ReAct clipping threshold.
        - _input_perturbation(): Apply perturbation to input data.

    Abstract Properties:
        - requires_to_fit_dataset: Whether fit_dataset is mandatory for calibration.
        - requires_internal_features: Whether the detector uses internal features.
    """

    def __init__(
        self,
        use_react: Optional[bool] = False,
        use_scale: Optional[bool] = False,
        use_ash: Optional[bool] = False,
        react_quantile: Optional[float] = None,
        scale_percentile: Optional[float] = None,
        ash_percentile: Optional[float] = None,
        eps: float = 0.0,
        temperature: float = 1000.0,
    ):
        """Store the detector configuration.

        Raises:
            ValueError: if more than one of ReAct / scale / ASH is enabled, since
                these activation-shaping methods are mutually exclusive.
        """
        self.feature_extractor: FeatureExtractor = None
        self.use_react = use_react
        self.use_scale = use_scale
        self.use_ash = use_ash
        self.react_quantile = react_quantile
        self.scale_percentile = scale_percentile
        self.ash_percentile = ash_percentile
        self.eps = eps
        self.temperature = temperature
        # Filled in by _compute_react_threshold() during fit() when use_react.
        self.react_threshold = None

        if use_scale and use_react:
            raise ValueError("Cannot use both ReAct and scale at the same time")
        if use_scale and use_ash:
            raise ValueError("Cannot use both ASH and scale at the same time")
        if use_ash and use_react:
            raise ValueError("Cannot use both ReAct and ASH at the same time")

    # === Public API ===
    def fit(
        self,
        model: Callable,
        fit_dataset: Optional[Union[ItemType, DatasetType]] = None,
        feature_layers_id: Optional[List[Union[int, str]]] = None,
        postproc_fns: Optional[List[Callable]] = None,
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Prepare the detector for scoring:
        * Constructs the feature extractor based on the model
        * Calibrates the detector on ID data "fit_dataset" if needed,
            using self._fit_to_dataset

        Args:
            model: model to extract the features from
            fit_dataset: dataset to fit the detector on
            feature_layers_id (Optional[List[Union[int, str]]]): list of str or
                int that identify features to output.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None (no internal
                feature layer), which is rejected for feature-based detectors.
            postproc_fns (Optional[List[Callable]]): list of postproc functions,
                one per output layer. Defaults to None.
                If None, identity function is used.
            head_layer_id (int, str): identifier of the head layer.
                If int, the rank of the layer in the layer list
                If str, the name of the layer.
                Defaults to -1
            input_layer_id (int, str): identifier of the input layer of the
                feature extractor.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None.
            verbose (bool): if True, display a progress bar. Defaults to False.
            **kwargs: forwarded to self._fit_to_dataset.

        Raises:
            ValueError: if fit_dataset is required (by the method or by ReAct)
                but missing, or if a feature-based detector gets no
                feature_layers_id.
        """
        # Avoid a shared mutable default argument.
        if feature_layers_id is None:
            feature_layers_id = []

        (
            self.backend,
            self.data_handler,
            self.op,
            self.FeatureExtractorClass,
        ) = import_backend_specific_stuff(model)

        # if required by the method, check that fit_dataset is not None
        if self.requires_to_fit_dataset and fit_dataset is None:
            raise ValueError(
                "`fit_dataset` argument must be provided for this OOD detector"
            )

        # Validate cheap arguments before the (possibly expensive) ReAct pass.
        if len(feature_layers_id) == 0 and isinstance(self, FeatureBasedDetector):
            raise ValueError(
                "Explicitly specify feature_layers_id=[layer0, layer1,...], "
                + "where layer0, layer1,... are the names of the desired output "
                + "layers of your model. These can be int or str (even though str"
                + " is safer). To know what to put, have a look at model.summary() "
                + "with keras or model.named_modules() with pytorch"
            )

        # The feature extractor does not exist yet at this point, so the
        # postproc_fns length is checked against the requested layers.
        self.postproc_fns = self._sanitize_posproc_fns(
            postproc_fns, feature_layers_id=feature_layers_id
        )

        # react: compute threshold (activation percentiles)
        if self.use_react:
            if fit_dataset is None:
                raise ValueError(
                    "if react quantile is not None, fit_dataset must be"
                    " provided to compute react activation threshold"
                )
            self._compute_react_threshold(
                model, fit_dataset, verbose=verbose, head_layer_id=head_layer_id
            )

        self.feature_extractor = self._load_feature_extractor(
            model, feature_layers_id, head_layer_id, input_layer_id
        )

        if fit_dataset is not None:
            # Only forward `verbose` when the (possibly overridden) fitting
            # routine accepts it.
            if "verbose" in inspect.signature(self._fit_to_dataset).parameters:
                kwargs.update({"verbose": verbose})
            self._fit_to_dataset(fit_dataset, **kwargs)

    def score(
        self,
        dataset: Union[ItemType, DatasetType],
        verbose: bool = False,
    ) -> tuple:
        """
        Computes an OOD score for input samples "inputs".

        Args:
            dataset (Union[ItemType, DatasetType]): dataset or tensors to score
            verbose (bool): if True, display a progress bar. Defaults to False.

        Returns:
            tuple: (scores, info) where scores are the OOD scores and info is a
                dictionary with keys "labels" (numpy array or None when no label
                is available) and "logits" (numpy array).

        Raises:
            NotImplementedError: if `dataset` is neither an item nor a dataset.
        """
        assert self.feature_extractor is not None, "Call .fit() before .score()"
        labels = None
        # Case 1: dataset is neither a tf.data.Dataset nor a torch.DataLoader
        if isinstance(dataset, get_args(ItemType)):
            tensor = self.data_handler.get_input_from_dataset_item(dataset)
            scores = self._score_tensor(tensor)
            logits = self.op.convert_to_numpy(self.feature_extractor._last_logits)

            # Get labels if dataset is a tuple/list
            if isinstance(dataset, (list, tuple)):
                labels = self.data_handler.get_label_from_dataset_item(dataset)
                labels = self.op.convert_to_numpy(labels)

        # Case 2: dataset is a tf.data.Dataset or a torch.DataLoader
        elif isinstance(dataset, get_args(DatasetType)):
            # Collect per-batch results and concatenate once at the end
            # (repeated np.append would copy the accumulated arrays each batch).
            score_batches = []
            logits_batches = []
            label_batches = []

            for item in tqdm(dataset, desc="Scoring", disable=not verbose):
                tensor = self.data_handler.get_input_from_dataset_item(item)
                score_batches.append(np.ravel(self._score_tensor(tensor)))
                logits_batches.append(
                    self.op.convert_to_numpy(self.feature_extractor._last_logits)
                )

                # get the label if available
                if len(item) > 1:
                    labels_batch = self.data_handler.get_label_from_dataset_item(item)
                    # Convert every batch (including the first) to numpy.
                    label_batches.append(
                        np.atleast_1d(self.op.convert_to_numpy(labels_batch))
                    )

            # The leading empty float64 array reproduces np.append's dtype
            # promotion and gives an empty array for an empty dataset.
            scores = np.concatenate([np.array([])] + score_batches)
            logits = (
                np.concatenate(logits_batches, axis=0) if logits_batches else None
            )
            if label_batches:
                labels = np.concatenate(label_batches, axis=0)

        else:
            raise NotImplementedError(
                f"OODBaseDetector.score() not implemented for {type(dataset)}"
            )

        info = dict(labels=labels, logits=logits)
        return scores, info

    def __call__(self, inputs: Union[ItemType, DatasetType]) -> tuple:
        """
        Convenience wrapper for score

        Args:
            inputs (Union[ItemType, DatasetType]): dataset or tensors to score.

        Returns:
            tuple: the (scores, info) pair returned by `score`.
        """
        return self.score(inputs)

    # === Internal: Feature Extractor ===
    def _load_feature_extractor(
        self,
        model: Callable,
        feature_layers_id: Optional[List[Union[int, str]]] = None,
        head_layer_id: Optional[Union[int, str]] = -1,
        input_layer_id: Optional[Union[int, str]] = None,
        return_penultimate: bool = False,
    ) -> Callable:
        """
        Loads feature extractor

        Args:
            model: a model (Keras or PyTorch) to load.
            feature_layers_id (Optional[List[Union[int, str]]]): list of str or
                int that identify features to output.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None (treated as []).
            head_layer_id (int): identifier of the head layer.
                -1 when the last layer is the head
                -2 when the last layer is a softmax activation layer
                ...
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to -1
            input_layer_id (int, str): identifier of the input layer of the
                feature extractor.
                If int, the rank of the layer in the layer list
                If str, the name of the layer. Defaults to None.
            return_penultimate (bool): if True, the penultimate values are returned,
                i.e. the input to the head_layer.

        Returns:
            FeatureExtractor: a feature extractor instance
        """
        if feature_layers_id is None:
            feature_layers_id = []
        # Disabled shaping methods must not leak their percentile to the extractor.
        if not self.use_ash:
            self.ash_percentile = None
        if not self.use_scale:
            self.scale_percentile = None

        feature_extractor = self.FeatureExtractorClass(
            model,
            feature_layers_id=feature_layers_id,
            input_layer_id=input_layer_id,
            head_layer_id=head_layer_id,
            react_threshold=self.react_threshold,
            scale_percentile=self.scale_percentile,
            ash_percentile=self.ash_percentile,
            return_penultimate=return_penultimate,
        )
        return feature_extractor

    def _sanitize_posproc_fns(
        self,
        postproc_fns: Union[List[Callable], None],
        feature_layers_id: Optional[List[Union[int, str]]] = None,
    ) -> Optional[List[Callable]]:
        """Sanitize postproc fns used at each layer output of the feature extractor.

        Args:
            postproc_fns (Optional[List[Callable]], optional): List of postproc
                functions, one per output layer. Defaults to None.
            feature_layers_id (Optional[List[Union[int, str]]]): layers the
                functions must match one-to-one. If None, the layers of the
                already-loaded feature extractor are used.

        Returns:
            Optional[List[Callable]]: Sanitized postproc_fns list (None entries
                replaced by identity), or None if postproc_fns is None.
        """
        if postproc_fns is None:
            return None

        if feature_layers_id is None:
            feature_layers_id = self.feature_extractor.feature_layers_id
        assert len(postproc_fns) == len(
            feature_layers_id
        ), "len of postproc_fns and output_layers_id must match"

        def identity(x):
            return x

        return [identity if fn is None else fn for fn in postproc_fns]

    # === Internal: ODIN input perturbation ===
    def _input_perturbation(
        self, inputs: TensorType, eps: float, temperature: float = 1000
    ) -> TensorType:
        """Apply the ODIN gradient-based perturbation to the inputs.

        Args:
            inputs: Batch of samples to perturb.
            eps: Magnitude of the perturbation. If zero, `inputs` are
                returned unchanged.
            temperature: Temperature used for the softmax in the loss.

        Returns:
            Perturbed input tensor of the same shape as `inputs`.
        """

        if eps == 0:
            return inputs

        if self.feature_extractor.backend == "torch":
            inputs = inputs.to(self.feature_extractor._device)

        # Use the model's own predictions as pseudo-labels for the loss.
        preds = self.feature_extractor.model(inputs)
        labels = self.op.argmax(preds, dim=1)
        gradients = self.op.gradient(
            self._temperature_loss, inputs, labels, temperature
        )
        return inputs - eps * self.op.sign(gradients)

    def _temperature_loss(
        self, inputs: TensorType, labels: TensorType, temperature: float
    ) -> TensorType:
        """Cross-entropy loss used for ODIN input perturbation."""

        preds = self.feature_extractor.model(inputs) / temperature
        loss = self.op.CrossEntropyLoss(reduction="sum")(inputs=preds, targets=labels)
        return loss

    # === Internal: Fitting logic ===
    def _fit_to_dataset(
        self,
        fit_dataset: DatasetType,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Optional fitting routine for detectors using a calibration dataset."""

        return None

    # === Internal: Scoring logic ===
    def _score_tensor(self, inputs: TensorType) -> np.ndarray:
        """Compute an OOD score for a batch of inputs.

        Child classes must implement this method. It should return one score per
        sample of `inputs`.
        """

        raise NotImplementedError("_score_tensor must be implemented in subclasses")

    # === Internal calibration methods ===
    def _compute_react_threshold(
        self,
        model: Callable,
        fit_dataset: DatasetType,
        verbose: bool = False,
        head_layer_id: Union[int, str] = -1,
    ):
        """Set self.react_threshold to the `react_quantile` quantile of the
        penultimate activations computed on `fit_dataset`."""
        penult_feat_extractor = self._load_feature_extractor(
            model, head_layer_id=head_layer_id, return_penultimate=True
        )
        unclipped_features, _ = penult_feat_extractor.predict(
            fit_dataset,
            verbose=verbose,
            postproc_fns=self.postproc_fns,
            numpy_concat=True,
        )
        self.react_threshold = np.quantile(unclipped_features[0], self.react_quantile)

    # === Properties ===
    @property
    def requires_to_fit_dataset(self) -> bool:
        """
        Whether an OOD detector needs a `fit_dataset` argument in the fit function.

        Returns:
            bool: True if `fit_dataset` is required else False.
        """
        raise NotImplementedError(
            "Property `requires_to_fit_dataset` is not implemented. It should return"
            + " a True or False boolean."
        )

    @property
    def requires_internal_features(self) -> bool:
        """
        Whether an OOD detector acts on internal model features.

        Returns:
            bool: True if the detector perform computations on an intermediate layer
                else False.
        """
        raise NotImplementedError(
            "Property `requires_internal_features` is not implemented. It should return"
            + " a True or False boolean."
        )
448
+
449
+
450
class FeatureBasedDetector(OODBaseDetector):
    """Base class for detectors operating on internal feature representations.

    Subclasses implement `_fit_layer` (per-layer calibration) and
    `_score_layer` (per-layer scoring). When more than one feature layer is
    used, per-layer scores are combined by `self.aggregator`.
    """

    def __init__(self, aggregator: BaseAggregator = None, *args, **kwargs):
        """Initialize the feature-based OOD detector.

        Args:
            aggregator (BaseAggregator, optional): Aggregator to normalize scores
                across multiple feature layers. Defaults to None.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.aggregator = aggregator
        self.postproc_fns = None

    # === Internal: Fitting logic ===
    def _fit_to_dataset(
        self,
        fit_dataset: DatasetType,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Extract features from `fit_dataset` and compute layer statistics which will
        be used for scoring in _score_layer.

        Child classes must implement :func:`_fit_layer` to compute the
        statistics required by the detector on each feature layer. If an
        `aggregator` attribute is present, or more than one layer is used (in
        which case a StdNormalizedAggregator is created), the per-layer scores
        of the fit data are fed to it for normalization. Scoring of the fit
        data is batched with `kwargs["batch_size"]` (default 128).
        """

        n_layers = len(self.feature_extractor.feature_layers_id)

        if self.postproc_fns is None:
            self.postproc_fns = [self.feature_extractor._default_postproc_fn] * n_layers

        feats, info = self.feature_extractor.predict(
            fit_dataset,
            postproc_fns=self.postproc_fns,
            verbose=verbose,
            return_labels=True,
            numpy_concat=True,
        )

        aggregator = getattr(self, "aggregator", None)
        if aggregator is None and n_layers > 1:
            aggregator = StdNormalizedAggregator()
            self.aggregator = aggregator

        # batch the scoring to avoid memory issues
        batch_size = kwargs.get("batch_size", 128)
        per_layer_scores = []
        for idx in range(n_layers):
            self._fit_layer(idx, feats[idx], info, **kwargs)

            # Per-layer fit scores are only needed to calibrate an aggregator.
            if aggregator is None:
                continue

            n_samples = feats[idx].shape[0]
            scores = []
            for start in range(0, n_samples, batch_size):
                end = min(start + batch_size, n_samples)
                batch_feats = self.op.from_numpy(feats[idx][start:end])
                # Entries may be None (e.g. no labels available): keep them as-is.
                batch_info = {
                    k: None if v is None else v[start:end] for k, v in info.items()
                }
                scores.append(
                    self._score_layer(
                        idx, batch_feats, batch_info, fit=True, **kwargs
                    )
                )
            per_layer_scores.append(np.concatenate(scores, axis=0))

        if aggregator is not None and per_layer_scores:
            aggregator.fit(per_layer_scores)

    def _fit_layer(
        self,
        layer_id: int,
        layer_features: np.ndarray,
        info: Dict[str, TensorType],
        **kwargs,
    ) -> None:
        """Compute statistics for a single feature layer."""

        raise NotImplementedError

    # === Internal: Scoring logic ===
    def _score_tensor(self, inputs: TensorType) -> np.ndarray:
        """Compute the OOD score for a batch using internal features.

        If `eps > 0`, inputs are first perturbed with `_input_perturbation`
        using the detector's configured `temperature`. Scores of several layers
        are combined by the aggregator if present, else averaged.
        """

        eps = getattr(self, "eps", 0)
        if eps > 0:
            # Honour the user-configured temperature instead of the helper's
            # default (the defaults coincide at 1000).
            inputs = self._input_perturbation(
                inputs, eps, temperature=getattr(self, "temperature", 1000.0)
            )

        feats, logits = self.feature_extractor.predict_tensor(
            inputs, postproc_fns=self.postproc_fns
        )

        info: Dict[str, TensorType] = {"logits": logits}
        per_layer_scores = [
            self._score_layer(idx, layer_feats, info)
            for idx, layer_feats in enumerate(feats)
        ]

        aggregator = getattr(self, "aggregator", None)
        if aggregator is not None and len(per_layer_scores) > 1:
            return aggregator.aggregate(per_layer_scores)
        if len(per_layer_scores) > 1:
            return np.mean(np.stack(per_layer_scores, axis=1), axis=1)
        return per_layer_scores[0]

    def _score_layer(
        self,
        layer_id: int,
        layer_features: TensorType,
        info: Dict[str, TensorType],
        fit: bool = False,
        **kwargs,
    ) -> np.ndarray:
        """Score samples for a single feature layer."""

        raise NotImplementedError