dataeval 0.86.8__py3-none-any.whl → 0.87.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_version.py +2 -2
  3. dataeval/config.py +4 -19
  4. dataeval/data/_metadata.py +56 -27
  5. dataeval/data/_split.py +1 -1
  6. dataeval/data/selections/_classbalance.py +4 -3
  7. dataeval/data/selections/_classfilter.py +5 -5
  8. dataeval/data/selections/_indices.py +2 -2
  9. dataeval/data/selections/_prioritize.py +249 -29
  10. dataeval/data/selections/_reverse.py +1 -1
  11. dataeval/data/selections/_shuffle.py +2 -2
  12. dataeval/detectors/ood/__init__.py +2 -1
  13. dataeval/detectors/ood/base.py +38 -1
  14. dataeval/detectors/ood/knn.py +95 -0
  15. dataeval/metrics/bias/_balance.py +28 -21
  16. dataeval/metrics/bias/_diversity.py +4 -4
  17. dataeval/metrics/bias/_parity.py +2 -2
  18. dataeval/metrics/stats/_hashstats.py +19 -2
  19. dataeval/outputs/_workflows.py +20 -7
  20. dataeval/typing.py +14 -2
  21. dataeval/utils/__init__.py +2 -2
  22. dataeval/utils/_bin.py +7 -6
  23. dataeval/utils/data/__init__.py +2 -0
  24. dataeval/utils/data/_dataset.py +13 -6
  25. dataeval/utils/data/_validate.py +169 -0
  26. dataeval/workflows/sufficiency.py +53 -10
  27. {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/METADATA +5 -17
  28. {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/RECORD +30 -39
  29. dataeval/utils/datasets/__init__.py +0 -19
  30. dataeval/utils/datasets/_antiuav.py +0 -189
  31. dataeval/utils/datasets/_base.py +0 -262
  32. dataeval/utils/datasets/_cifar10.py +0 -201
  33. dataeval/utils/datasets/_fileio.py +0 -142
  34. dataeval/utils/datasets/_milco.py +0 -197
  35. dataeval/utils/datasets/_mixin.py +0 -54
  36. dataeval/utils/datasets/_mnist.py +0 -202
  37. dataeval/utils/datasets/_ships.py +0 -144
  38. dataeval/utils/datasets/_types.py +0 -48
  39. dataeval/utils/datasets/_voc.py +0 -583
  40. {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/WHEEL +0 -0
  41. /dataeval-0.86.8.dist-info/licenses/LICENSE.txt → /dataeval-0.87.0.dist-info/licenses/LICENSE +0 -0
dataeval/data/selections/_prioritize.py
@@ -32,8 +32,8 @@ class _Clusters:
         self.cluster_centers = cluster_centers
         self.unique_labels = np.unique(labels)
 
-    def _dist2center(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
-        dist = np.zeros(self.labels.shape)
+    def _dist2center(self, X: NDArray[np.floating[Any]]) -> NDArray[np.float32]:
+        dist = np.zeros(self.labels.shape, dtype=np.float32)
         for lab in self.unique_labels:
             dist[self.labels == lab] = np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[lab, :], axis=1)
         return dist
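
For context, the change above only alters the dtype of the returned distances; the computation is unchanged. A minimal self-contained sketch of what _dist2center computes (the array names here are illustrative, not from the package):

    import numpy as np

    X = np.random.rand(6, 2).astype(np.float32)   # sample embeddings
    labels = np.array([0, 0, 1, 1, 2, 2])         # cluster assignment per sample
    centers = np.stack([X[labels == lab].mean(axis=0) for lab in np.unique(labels)])

    # distance from each sample to the center of its own cluster, as float32
    dist = np.zeros(labels.shape, dtype=np.float32)
    for lab in np.unique(labels):
        dist[labels == lab] = np.linalg.norm(X[labels == lab, :] - centers[lab, :], axis=1)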
@@ -75,6 +75,8 @@ class _Clusters:
 
 
 class _Sorter(ABC):
+    scores: NDArray[np.float32] | None = None
+
     @abstractmethod
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]: ...
 
@@ -95,11 +97,12 @@ class _KNNSorter(_Sorter):
 
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
         if reference is None:
-            dists = pairwise_distances(embeddings, embeddings)
+            dists = pairwise_distances(embeddings, embeddings).astype(np.float32)
             np.fill_diagonal(dists, np.inf)
         else:
-            dists = pairwise_distances(embeddings, reference)
-        return np.argsort(np.sort(dists, axis=1)[:, self._k])
+            dists = pairwise_distances(embeddings, reference).astype(np.float32)
+        self.scores = np.sort(dists, axis=1)[:, self._k]
+        return np.argsort(self.scores)
 
 
 class _KMeansSorter(_Sorter):
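
The new scores attribute persists the k-th-nearest-neighbor distances so the rerank policies added later in this file can reuse them. A standalone sketch of the scoring, using sklearn's pairwise_distances as this module does (array names and the value of k are illustrative):

    import numpy as np
    from sklearn.metrics import pairwise_distances

    embeddings = np.random.rand(100, 16)
    k = 10  # illustrative; the sorter derives a default when k is None

    dists = pairwise_distances(embeddings, embeddings).astype(np.float32)
    np.fill_diagonal(dists, np.inf)        # a sample cannot be its own neighbor
    scores = np.sort(dists, axis=1)[:, k]  # distance to the k-th nearest neighbor
    order = np.argsort(scores)             # ascending: prototypical ("easy") samples first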
@@ -123,7 +126,8 @@ class _KMeansSorter(_Sorter):
 class _KMeansDistanceSorter(_KMeansSorter):
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
         clst = self._get_clusters(embeddings if reference is None else reference)
-        return np.argsort(clst._dist2center(embeddings))
+        self.scores = clst._dist2center(embeddings)
+        return np.argsort(self.scores)
 
 
 class _KMeansComplexitySorter(_KMeansSorter):
@@ -134,11 +138,11 @@ class _KMeansComplexitySorter(_KMeansSorter):
 
 class Prioritize(Selection[Any]):
     """
-    Prioritizes the dataset by sort order in the embedding space.
+    Sort the dataset indices in order of highest priority data in the embedding space.
 
     Parameters
     ----------
-    model : torch.nn.Module
+    model : torch.nn.Module | None
         Model to use for encoding images
     batch_size : int
         Batch size to use when encoding images
@@ -146,10 +150,23 @@ class Prioritize(Selection[Any]):
         Device to use for encoding images
     method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
         Method to use for prioritization
-    k : int | None, default None
-        Number of nearest neighbors to use for prioritization (knn only)
-    c : int | None, default None
-        Number of clusters to use for prioritization (kmeans only)
+    k : int or None, default None
+        Number of nearest neighbors to use for prioritization.
+        If None, uses the square root of the number of samples. Only used for method="knn", ignored otherwise.
+    c : int or None, default None
+        Number of clusters to use for prioritization. If None, uses the square root of the number of samples.
+        Only used for method="kmeans_*", ignored otherwise.
+
+    Notes
+    -----
+    1. `k` is only used for method ["knn"].
+    2. `c` is only used for methods ["kmeans_distance", "kmeans_complexity"].
+
+    Raises
+    ------
+    ValueError
+        If method not in supported methods
+
     """
 
     stage = SelectionStage.ORDER
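
A usage sketch of the expanded constructor, assuming the public import paths implied by the module layout; model, train_dataset, and the parameter values are placeholders, and Select is the driver from dataeval.data that applies selections:

    from dataeval.data import Select
    from dataeval.data.selections import Prioritize

    prioritize = Prioritize(
        model,                # torch.nn.Module used to embed images
        batch_size=64,
        device=None,
        method="knn",
        policy="stratified",  # or "hard_first", "easy_first", "class_balance"
        k=16,                 # consulted only when method="knn"
    )
    prioritized_dataset = Select(train_dataset, selections=[prioritize])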
@@ -157,55 +174,95 @@ class Prioritize(Selection[Any]):
     @overload
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["knn"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None: ...
 
     @overload
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
+        *,
+        c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module | None,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["class_balance"],
+        *,
+        k: int | None = None,
+        c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module | None,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified"],
         *,
+        k: int | None = None,
         c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None: ...
 
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None:
-        if method not in ("knn", "kmeans_distance", "kmeans_complexity"):
+        if method not in {"knn", "kmeans_distance", "kmeans_complexity"}:
             raise ValueError(f"Invalid prioritization method: {method}")
+        if policy not in ("hard_first", "easy_first", "stratified", "class_balance"):
+            raise ValueError(f"Invalid selection policy: {policy}")
         self._model = model
         self._batch_size = batch_size
         self._device = device
         self._method = method
+        self._policy = policy
         self._embeddings: Embeddings | None = None
         self._reference: Embeddings | None = None
         self._k = k
         self._c = c
+        self.class_label = class_label
 
     @overload
     @classmethod
     def using(
         cls,
         method: Literal["knn"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> Prioritize: ...
 
     @overload
@@ -213,49 +270,72 @@ class Prioritize(Selection[Any]):
     def using(
         cls,
         method: Literal["kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         c: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> Prioritize: ...
 
     @classmethod
     def using(
         cls,
         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         c: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> Prioritize:
         """
-        Prioritizes the dataset by sort order in the embedding space using existing
-        embeddings and/or reference dataset embeddings.
+        Use precalculated embeddings to sort the dataset indices in order of
+        highest priority data in the embedding space.
 
         Parameters
         ----------
         method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
-            Method to use for prioritization
+            Method to use for sample scoring during prioritization.
+        policy : Literal["hard_first", "easy_first", "stratified", "class_balance"]
+            Selection policy for prioritizing scored samples.
         embeddings : Embeddings or None, default None
-            Embeddings to use for prioritization
+            Embeddings to use during prioritization. If None, `reference` must be set.
         reference : Embeddings or None, default None
-            Reference embeddings to prioritize relative to
+            Reference embeddings used to prioritize the calculated dataset embeddings relative to them.
+            If `embeddings` is None, this will be used instead.
         k : int or None, default None
-            Number of nearest neighbors to use for prioritization (knn only)
+            Number of nearest neighbors to use for prioritization.
+            If None, uses the square root of the number of samples. Only used for method="knn", ignored otherwise.
         c : int or None, default None
-            Number of clusters to use for prioritization (kmeans, cluster only)
+            Number of clusters to use for prioritization. If None, uses the square root of the number of samples.
+            Only used for method="kmeans_*", ignored otherwise.
 
         Notes
         -----
-        At least one of `embeddings` or `reference` must be provided.
+        1. `k` is only used for method ["knn"].
+        2. `c` is only used for methods ["kmeans_distance", "kmeans_complexity"].
+
+        Raises
+        ------
+        ValueError
+            If both `embeddings` and `reference` are None
+
         """
         emb_params: Embeddings | None = embeddings if embeddings is not None else reference
         if emb_params is None:
             raise ValueError("Must provide at least embeddings or reference embeddings.")
-        prioritize = Prioritize(emb_params._model, emb_params.batch_size, emb_params.device, method)
-        prioritize._k = k
-        prioritize._c = c
+        prioritize = Prioritize(
+            emb_params._model,
+            emb_params.batch_size,
+            emb_params.device,
+            method,
+            policy,
+            k=k,
+            c=c,
+            class_label=class_label,
+        )
        prioritize._embeddings = embeddings
        prioritize._reference = reference
        return prioritize
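
A corresponding sketch of the classmethod path, assuming train_embeddings is a precomputed dataeval.data.Embeddings and train_labels is an integer array aligned with it (both names are placeholders):

    prioritize = Prioritize.using(
        "kmeans_distance",
        "class_balance",
        c=8,
        embeddings=train_embeddings,
        class_label=train_labels,  # required by the "class_balance" policy
    )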
@@ -265,9 +345,148 @@ class Prioritize(Selection[Any]):
             return _KNNSorter(samples, self._k)
         if self._method == "kmeans_distance":
             return _KMeansDistanceSorter(samples, self._c)
-        # self._method == "kmeans_complexity"
         return _KMeansComplexitySorter(samples, self._c)
 
+    def _compute_bin_extents(self, scores: NDArray[np.floating[Any]]) -> tuple[np.float64, np.float64]:
+        """
+        Compute min/max bin extents for `scores`, padding outward by epsilon.
+
+        Parameters
+        ----------
+        scores : NDArray[np.float64]
+            Array of floats to bin
+
+        Returns
+        -------
+        tuple[np.float64, np.float64]
+            (min, max) scores padded outward by epsilon = 1e-6 * range(scores).
+        """
+        # ensure binning captures all samples in range
+        scores = scores.astype(np.float64)
+        min_score = np.min(scores)
+        max_score = np.max(scores)
+        rng = max_score - min_score
+        eps = rng * 1e-6
+        return min_score - eps, max_score + eps
+
+    def _select_ordered_by_label(self, labels: NDArray[np.integer[Any]]) -> NDArray[np.intp]:
+        """
+        Given labels (class, group, bin, etc.) sorted with decreasing priority,
+        rerank so that we have approximate class/group balance. This function
+        is used for both the stratified and class-balance rerank methods.
+
+        We could require and return prioritization scores and re-sorted class
+        labels, but it is more compact to return indices. This allows us to
+        resort other quantities, as well, outside the function.
+
+        Parameters
+        ----------
+        labels : NDArray[np.integer[Any]]
+            Class label or group ID per instance in order of decreasing priority
+
+        Returns
+        -------
+        NDArray[np.intp]
+            Indices that sort samples according to uniform class balance or
+            group membership while respecting priority of the initial ordering.
+        """
+        labels = np.array(labels)
+        num_samp = labels.shape[0]
+        selected = np.zeros(num_samp, dtype=bool)
+        # preserve ordering
+        _, index = np.unique(labels, return_index=True)
+        u_lab = labels[np.sort(index)]
+        n_cls = len(u_lab)
+
+        resort_inds = []
+        cls_idx = 0
+        n = 0
+        while len(resort_inds) < num_samp:
+            c0 = u_lab[cls_idx % n_cls]
+            samples_available = (~selected) * (labels == c0)
+            if any(samples_available):
+                i0 = np.argmax(samples_available)  # selects first occurrence
+                resort_inds.append(i0)
+                selected[i0] = True
+            cls_idx += 1
+            n += 1
+        return np.array(resort_inds).astype(np.intp)
+
+    def _stratified_rerank(
+        self,
+        scores: NDArray[np.floating[Any]],
+        indices: NDArray[np.integer[Any]],
+        num_bins: int = 50,
+    ) -> NDArray[np.intp]:
+        """
+        Re-rank samples by sampling uniformly over binned scores. This
+        de-weights selection of samples with similar scores and encourages both
+        prototypical and challenging samples near the decision boundary.
+
+        Parameters
+        ----------
+        scores : NDArray[float]
+            Prioritization scores sorted in order of decreasing priority
+        indices : NDArray[int]
+            Indices to be re-sorted according to stratified sampling of scores.
+            Indices are ordered by decreasing priority.
+        num_bins : int, default 50
+            Number of bins over which scores are stratified
+
+        Returns
+        -------
+        NDArray[np.intp]
+            Re-ranked indices
+        """
+        mn, mx = self._compute_bin_extents(scores)
+        bin_edges = np.linspace(mn, mx, num=num_bins + 1, endpoint=True)
+        bin_label = np.digitize(scores, bin_edges)
+        srt_inds = self._select_ordered_by_label(bin_label)
+        return indices[srt_inds].astype(np.intp)
+
+    def _rerank(
+        self,
+        indices: NDArray[np.integer[Any]],
+    ) -> NDArray[np.intp]:
+        """
+        Re-rank samples according to the re-rank policy, self._policy. Values
+        from the 'indices' and optional 'scores' and 'class_label' variables are
+        assumed to correspond by index, i.e. indices[i], scores[i], and
+        class_label[i] should all refer to the same instance in the dataset.
+
+        Note: indices are assumed to be sorted with easy/prototypical samples
+        first, which is increasing order for most prioritization scoring methods.
+
+        Parameters
+        ----------
+        indices : NDArray[np.intp]
+            Indices that sort samples by increasing prioritization score, where
+            low scores indicate high prototypicality ('easy') and high scores
+            indicate challenging samples near the decision boundary ('hard').
+        """
+
+        if self._policy == "easy_first":
+            return indices.astype(np.intp)
+        if self._policy == "stratified":
+            if self._sorter.scores is None:
+                raise ValueError(
+                    "Prioritization scores are necessary in order to use "
+                    "stratified re-rank. Use 'knn' or 'kmeans_distance' "
+                    "methods to populate scores."
+                )
+            return self._stratified_rerank(self._sorter.scores[::-1], indices[::-1])
+        if self._policy == "class_balance":
+            if self.class_label is None:
+                raise ValueError("Class labels are necessary in order to use class_balance re-rank")
+            indices_reversed = self._select_ordered_by_label(self.class_label[indices[::-1]]).astype(np.int32)
+            n = len(indices_reversed)
+            return (n - 1 - indices_reversed).astype(np.intp)
+        # elif self._policy == "hard_first" (default)
+        return indices[::-1].astype(np.intp)
+
     def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
         emb: NDArray[Any] = embeddings.to_numpy(selection)
         emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
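
A worked, self-contained illustration of the round-robin behavior _select_ordered_by_label implements; the logic is transcribed from the hunk above, while the standalone function name and the sample labels are illustrative:

    import numpy as np

    def select_ordered_by_label(labels: np.ndarray) -> np.ndarray:
        selected = np.zeros(labels.shape[0], dtype=bool)
        _, index = np.unique(labels, return_index=True)
        u_lab = labels[np.sort(index)]  # classes in order of first appearance
        resort_inds = []
        cls_idx = 0
        while len(resort_inds) < labels.shape[0]:
            c0 = u_lab[cls_idx % len(u_lab)]
            available = (~selected) & (labels == c0)
            if available.any():
                i0 = int(np.argmax(available))  # first remaining sample of class c0
                resort_inds.append(i0)
                selected[i0] = True
            cls_idx += 1
        return np.array(resort_inds, dtype=np.intp)

    # labels sorted by decreasing priority; classes first appear in order 0, 1, 2
    print(select_ordered_by_label(np.array([0, 0, 1, 0, 1, 2])))
    # [0 2 5 1 4 3]: one sample per class per round, priority kept within each class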
@@ -290,4 +509,5 @@ class Prioritize(Selection[Any]):
         emb = self._to_normalized_ndarray(embeddings, dataset._selection)
         ref = None if self._reference is None else self._to_normalized_ndarray(self._reference)
         # Sort indices
-        dataset._selection = self._sorter._sort(emb, ref).tolist()
+        indices = self._sorter._sort(emb, ref)
+        dataset._selection = indices[self._rerank(indices)].astype(int).tolist()
dataeval/data/selections/_reverse.py
@@ -9,7 +9,7 @@ from dataeval.data._selection import Select, Selection, SelectionStage
 
 class Reverse(Selection[Any]):
     """
-    Reverse the selection order of the dataset.
+    Select dataset indices in reverse order.
     """
 
     stage = SelectionStage.ORDER
dataeval/data/selections/_shuffle.py
@@ -15,12 +15,12 @@ from dataeval.utils._array import as_numpy
 
 class Shuffle(Selection[Any]):
     """
-    Shuffle the dataset using a seed.
+    Select dataset indices in a random order.
 
     Parameters
     ----------
     seed : int, ArrayLike, SeedSequence, BitGenerator, Generator or None, default None
-        Seed for the random number generator.
+        Seed for the random number generator. If None, results are not reproducible.
 
     See Also
     --------
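
A one-line usage sketch, assuming the Select driver from dataeval.data applies selections as in the Prioritize example above (dataset is a placeholder):

    from dataeval.data import Select
    from dataeval.data.selections import Shuffle

    shuffled = Select(dataset, selections=[Shuffle(seed=42)])  # seeded, so reproducible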
dataeval/detectors/ood/__init__.py
@@ -2,7 +2,8 @@
 Out-of-distribution (OOD) detectors identify data that is different from the data used to train a particular model.
 """
 
-__all__ = ["OODOutput", "OODScoreOutput", "OOD_AE"]
+__all__ = ["OODOutput", "OODScoreOutput", "OOD_AE", "OOD_KNN"]
 
 from dataeval.detectors.ood.ae import OOD_AE
+from dataeval.detectors.ood.knn import OOD_KNN
 from dataeval.outputs._ood import OODOutput, OODScoreOutput
dataeval/detectors/ood/base.py
@@ -10,11 +10,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Callable, cast
+from abc import ABC, abstractmethod
+from typing import Any, Callable, cast
 
+import numpy as np
 import torch
+from numpy.typing import NDArray
 
 from dataeval.config import DeviceLike, get_device
+from dataeval.data import Embeddings
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import to_numpy
@@ -93,3 +97,36 @@ class OODBaseGMM(OODBase, OODGMMMixin[GaussianMixtureModelParams]):
         # Calculate the GMM parameters
         _, z, gamma = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self.model(x_ref))
         self._gmm_params = gmm_params(z, gamma)
+
+
+class EmbeddingBasedOODBase(OODBaseMixin[Callable[[Any], Any]], ABC):
+    """
+    Base class for embedding-based OOD detection methods.
+
+    These methods work directly on embedding representations,
+    using distance metrics or density estimation in embedding space.
+    Inherits from OODBaseMixin to get automatic thresholding.
+    """
+
+    def __init__(self) -> None:
+        """Initialize embedding-based OOD detector."""
+        # Pass a dummy callable as model since we don't use it
+        super().__init__(lambda x: x)
+
+    def _get_data_info(self, X: NDArray) -> tuple[tuple, type]:
+        """Override to skip [0-1] validation for embeddings."""
+        if not isinstance(X, np.ndarray):
+            raise TypeError("Dataset should be of type: `NDArray`.")
+        # Skip the [0-1] range check for embeddings
+        return X.shape[1:], X.dtype.type
+
+    @abstractmethod
+    def fit_embeddings(self, embeddings: Embeddings, threshold_perc: float = 95.0) -> None:
+        """
+        Fit using reference embeddings.
+
+        Args:
+            embeddings: Reference (in-distribution) embeddings
+            threshold_perc: Percentage of reference data considered normal
+        """
+        pass
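
A hypothetical minimal subclass, sketched only to show the contract the new base class defines: fit_embeddings builds whatever state _score needs and populates the reference scores used for automatic thresholding, mirroring the pattern OOD_KNN follows below. The OOD_Centroid name and its scoring rule are invented for illustration:

    import numpy as np

    from dataeval.data import Embeddings
    from dataeval.detectors.ood.base import EmbeddingBasedOODBase
    from dataeval.outputs._ood import OODScoreOutput

    class OOD_Centroid(EmbeddingBasedOODBase):
        """Toy detector: score = distance to the mean reference embedding."""

        def fit_embeddings(self, embeddings: Embeddings, threshold_perc: float = 95.0) -> None:
            ref = embeddings.to_numpy()
            self._centroid = ref.mean(axis=0)
            self._ref_score = OODScoreOutput(instance_score=self._score(ref).instance_score)
            self._threshold_perc = threshold_perc
            self._data_info = self._get_data_info(ref)

        def _score(self, X: np.ndarray, batch_size: int = int(1e10)) -> OODScoreOutput:
            return OODScoreOutput(instance_score=np.linalg.norm(X - self._centroid, axis=1))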
dataeval/detectors/ood/knn.py (new file)
@@ -0,0 +1,95 @@
+from typing import Literal
+
+import numpy as np
+from sklearn.neighbors import NearestNeighbors
+
+from dataeval.data import Embeddings
+from dataeval.detectors.ood.base import EmbeddingBasedOODBase
+from dataeval.outputs._ood import OODScoreOutput
+from dataeval.typing import ArrayLike
+
+
+class OOD_KNN(EmbeddingBasedOODBase):
+    """
+    K-Nearest Neighbors Out-of-Distribution detector.
+
+    Uses average cosine distance to k nearest neighbors in embedding space to detect OOD samples.
+    Samples with larger average distances to their k nearest neighbors in the
+    reference (in-distribution) set are considered more likely to be OOD.
+
+    Based on the methodology from:
+    "Back to the Basics: Revisiting Out-of-Distribution Detection Baselines"
+    (Kuan & Mueller, 2022)
+
+    As referenced in:
+    "Safe AI for coral reefs: Benchmarking out-of-distribution detection
+    algorithms for coral reef image surveys"
+    """
+
+    def __init__(self, k: int = 10, distance_metric: Literal["cosine", "euclidean"] = "cosine") -> None:
+        """
+        Initialize KNN OOD detector.
+
+        Args:
+            k: Number of nearest neighbors to consider (default: 10)
+            distance_metric: Distance metric to use ('cosine' or 'euclidean')
+        """
+        super().__init__()
+        self.k = k
+        self.distance_metric = distance_metric
+        self._nn_model: NearestNeighbors
+        self.reference_embeddings: ArrayLike
+
+    def fit_embeddings(self, embeddings: Embeddings, threshold_perc: float = 95.0) -> None:
+        """
+        Fit the detector using reference (in-distribution) embeddings.
+
+        Builds a k-NN index for efficient nearest neighbor search and
+        computes reference scores for automatic thresholding.
+
+        Args:
+            embeddings: Reference embeddings from in-distribution data
+            threshold_perc: Percentage of reference data considered normal
+        """
+        self.reference_embeddings = embeddings.to_numpy()
+
+        if self.k >= len(self.reference_embeddings):
+            raise ValueError(
+                f"k ({self.k}) must be less than number of reference embeddings ({len(self.reference_embeddings)})"
+            )
+
+        # Build k-NN index using sklearn
+        self._nn_model = NearestNeighbors(
+            n_neighbors=self.k,
+            metric=self.distance_metric,
+            algorithm="auto",  # Let sklearn choose the best algorithm
+        )
+        self._nn_model.fit(self.reference_embeddings)
+
+        # Efficiently compute reference scores for automatic thresholding
+        ref_scores = self._compute_reference_scores()
+        self._ref_score = OODScoreOutput(instance_score=ref_scores)
+        self._threshold_perc = threshold_perc
+        self._data_info = self._get_data_info(self.reference_embeddings)
+
+    def _compute_reference_scores(self) -> np.ndarray:
+        """Efficiently compute reference scores by excluding self-matches."""
+        # Find k+1 neighbors (including self) for reference points
+        distances, _ = self._nn_model.kneighbors(self.reference_embeddings, n_neighbors=self.k + 1)
+        # Skip first neighbor (self with distance 0) and average the rest
+        return np.mean(distances[:, 1:], axis=1)
+
+    def _score(self, X: np.ndarray, batch_size: int = int(1e10)) -> OODScoreOutput:
+        """
+        Compute OOD scores for input embeddings.
+
+        Args:
+            X: Input embeddings to score
+            batch_size: Batch size (not used, kept for interface compatibility)
+
+        Returns:
+            OODScoreOutput containing instance-level scores
+        """
+        # Compute OOD scores using sklearn's efficient k-NN search
+        distances, _ = self._nn_model.kneighbors(X)
+        return OODScoreOutput(instance_score=np.mean(distances, axis=1))
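
A usage sketch for the new detector. Hedges: the Embeddings constructor arguments and the score/predict calls are assumed to follow the pattern of the existing dataeval OOD detectors, and encoder, train_dataset, and test_dataset are placeholders:

    from dataeval.data import Embeddings
    from dataeval.detectors.ood import OOD_KNN

    ref_emb = Embeddings(train_dataset, batch_size=64, model=encoder)

    detector = OOD_KNN(k=10, distance_metric="cosine")
    detector.fit_embeddings(ref_emb, threshold_perc=95.0)

    test_emb = Embeddings(test_dataset, batch_size=64, model=encoder).to_numpy()
    scores = detector.score(test_emb)    # per-instance OOD scores
    preds = detector.predict(test_emb)   # thresholded OOD flags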