dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/_version.py +2 -2
  4. dataeval/config.py +4 -19
  5. dataeval/data/_embeddings.py +78 -35
  6. dataeval/data/_images.py +41 -8
  7. dataeval/data/_metadata.py +348 -66
  8. dataeval/data/_selection.py +22 -7
  9. dataeval/data/_split.py +3 -2
  10. dataeval/data/selections/_classbalance.py +4 -3
  11. dataeval/data/selections/_classfilter.py +9 -8
  12. dataeval/data/selections/_indices.py +4 -3
  13. dataeval/data/selections/_prioritize.py +249 -29
  14. dataeval/data/selections/_reverse.py +1 -1
  15. dataeval/data/selections/_shuffle.py +5 -4
  16. dataeval/detectors/drift/_base.py +2 -1
  17. dataeval/detectors/drift/_mmd.py +2 -1
  18. dataeval/detectors/drift/_nml/_base.py +1 -1
  19. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  20. dataeval/detectors/drift/_nml/_result.py +3 -2
  21. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  22. dataeval/detectors/drift/_uncertainty.py +2 -1
  23. dataeval/detectors/linters/duplicates.py +2 -1
  24. dataeval/detectors/linters/outliers.py +4 -3
  25. dataeval/detectors/ood/__init__.py +2 -1
  26. dataeval/detectors/ood/ae.py +1 -1
  27. dataeval/detectors/ood/base.py +39 -1
  28. dataeval/detectors/ood/knn.py +95 -0
  29. dataeval/detectors/ood/mixin.py +2 -1
  30. dataeval/metadata/_utils.py +1 -1
  31. dataeval/metrics/bias/_balance.py +29 -22
  32. dataeval/metrics/bias/_diversity.py +4 -4
  33. dataeval/metrics/bias/_parity.py +2 -2
  34. dataeval/metrics/stats/_base.py +3 -29
  35. dataeval/metrics/stats/_boxratiostats.py +2 -1
  36. dataeval/metrics/stats/_dimensionstats.py +2 -1
  37. dataeval/metrics/stats/_hashstats.py +21 -3
  38. dataeval/metrics/stats/_pixelstats.py +2 -1
  39. dataeval/metrics/stats/_visualstats.py +2 -1
  40. dataeval/outputs/_base.py +2 -3
  41. dataeval/outputs/_bias.py +2 -1
  42. dataeval/outputs/_estimators.py +1 -1
  43. dataeval/outputs/_linters.py +3 -3
  44. dataeval/outputs/_stats.py +3 -3
  45. dataeval/outputs/_utils.py +1 -1
  46. dataeval/outputs/_workflows.py +49 -31
  47. dataeval/typing.py +23 -9
  48. dataeval/utils/__init__.py +2 -2
  49. dataeval/utils/_array.py +3 -2
  50. dataeval/utils/_bin.py +9 -7
  51. dataeval/utils/_method.py +2 -3
  52. dataeval/utils/_multiprocessing.py +34 -0
  53. dataeval/utils/_plot.py +2 -1
  54. dataeval/utils/data/__init__.py +6 -5
  55. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  56. dataeval/utils/data/_validate.py +170 -0
  57. dataeval/utils/data/collate.py +2 -1
  58. dataeval/utils/torch/_internal.py +2 -1
  59. dataeval/utils/torch/trainer.py +1 -1
  60. dataeval/workflows/sufficiency.py +13 -9
  61. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
  62. dataeval-0.88.0.dist-info/RECORD +105 -0
  63. dataeval/utils/data/_dataset.py +0 -246
  64. dataeval/utils/datasets/__init__.py +0 -21
  65. dataeval/utils/datasets/_antiuav.py +0 -189
  66. dataeval/utils/datasets/_base.py +0 -266
  67. dataeval/utils/datasets/_cifar10.py +0 -201
  68. dataeval/utils/datasets/_fileio.py +0 -142
  69. dataeval/utils/datasets/_milco.py +0 -197
  70. dataeval/utils/datasets/_mixin.py +0 -54
  71. dataeval/utils/datasets/_mnist.py +0 -202
  72. dataeval/utils/datasets/_seadrone.py +0 -512
  73. dataeval/utils/datasets/_ships.py +0 -144
  74. dataeval/utils/datasets/_types.py +0 -48
  75. dataeval/utils/datasets/_voc.py +0 -583
  76. dataeval-0.86.9.dist-info/RECORD +0 -115
  77. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  78. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
dataeval/data/_selection.py CHANGED
@@ -2,8 +2,9 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Iterator, Sequence
 from enum import IntEnum
-from typing import Generic, Iterator, Sequence, TypeVar
+from typing import Generic, TypeVar
 
 from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
@@ -31,14 +32,21 @@ class Subselection(Generic[_TDatum]):
 
 class Select(AnnotatedDataset[_TDatum]):
     """
-    Wraps a dataset and applies selection criteria to it.
+    Dataset wrapper that applies selection criteria for filtering.
+
+    Wraps an existing dataset and applies one or more selection filters to
+    create a subset view without modifying the original dataset. Supports
+    chaining multiple selection criteria for complex filtering operations.
 
     Parameters
    ----------
-    dataset : Dataset
-        The dataset to wrap.
-    selections : Selection or list[Selection], optional
-        The selection criteria to apply to the dataset.
+    dataset : AnnotatedDataset[_TDatum]
+        Source dataset to wrap and filter. Must implement the AnnotatedDataset
+        interface with indexed access to data tuples.
+    selections : Selection or Sequence[Selection] or None, default None
+        Selection criteria to apply for filtering the dataset. When None,
+        returns all items from the source dataset. Default None creates an
+        unfiltered view for a consistent interface.
 
     Examples
     --------
@@ -49,7 +57,7 @@ class Select(AnnotatedDataset[_TDatum]):
     >>> # - f"data_{idx}", one_hot_encoded(idx % class_count), {"id": idx}
     >>> dataset = SampleDataset(size=100, class_count=10)
 
-    >>> # Apply a selection criteria to the dataset
+    >>> # Apply selection criteria to the dataset
     >>> selections = [Limit(size=5), ClassFilter(classes=[0, 2])]
     >>> selected_dataset = Select(dataset, selections=selections)
 
@@ -61,6 +69,12 @@ class Select(AnnotatedDataset[_TDatum]):
     (data_10, 0, {'id': 10})
     (data_12, 2, {'id': 12})
     (data_20, 0, {'id': 20})
+
+    Notes
+    -----
+    Selection criteria are applied in the order provided, allowing for
+    efficient sequential filtering. The wrapper maintains all metadata
+    and interface compatibility with the original dataset.
     """
 
     _dataset: AnnotatedDataset[_TDatum]
@@ -91,6 +105,7 @@ class Select(AnnotatedDataset[_TDatum]):
 
     @property
     def metadata(self) -> DatasetMetadata:
+        """Dataset metadata information including identifier and configuration."""
         return self._metadata
 
     def __str__(self) -> str:
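A minimal sketch of the selection flow described in the updated docstring, assuming the public exports Select, Limit, and ClassFilter (SampleDataset is the hypothetical dataset from the docstring example):

    from dataeval.data import Select
    from dataeval.data.selections import ClassFilter, Limit

    # SampleDataset stands in for any AnnotatedDataset implementation.
    dataset = SampleDataset(size=100, class_count=10)

    # Produce a view of at most 5 items drawn from classes 0 and 2,
    # mirroring the docstring example; the original dataset is unmodified.
    selected = Select(dataset, selections=[Limit(size=5), ClassFilter(classes=[0, 2])])

    for datum, target, metadata in selected:
        print(datum, target, metadata)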
dataeval/data/_split.py CHANGED
@@ -4,7 +4,8 @@ __all__ = []
 
 import logging
 import warnings
-from typing import Any, Iterator, Protocol, Sequence
+from collections.abc import Iterator, Sequence
+from typing import Any, Protocol
 
 import numpy as np
 from numpy.typing import NDArray
@@ -208,7 +209,7 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
 
     split_set = set(split_on)
     indices = [i for i, name in enumerate(metadata.factor_names) if name in split_set]
-    binned_features = metadata.discretized_data[:, indices]
+    binned_features = metadata.binned_data[:, indices]
     return np.unique(binned_features, axis=0, return_inverse=True)[1]
 
 
dataeval/data/selections/_classbalance.py CHANGED
@@ -11,12 +11,13 @@ from dataeval.utils._array import as_numpy
 
 class ClassBalance(Selection[ImageClassificationDatum]):
     """
-    Balance the dataset by class.
+    Select indices of a dataset that will equalize the occurrences of all classes.
 
     Note
     ----
-    The total number of instances of each class will be equalized which may result
+    1. The total number of instances of each class will be equalized which may result
    in a lower total number of instances than specified by the selection limit.
+    2. This selection currently only supports classification tasks
    """
 
    stage = SelectionStage.FILTER
@@ -29,7 +30,7 @@ class ClassBalance(Selection[ImageClassificationDatum]):
                label = int(np.argmax(as_numpy(target)))
            else:
                # ObjectDetectionTarget and SegmentationTarget not supported yet
-                raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
+                raise TypeError("ClassBalance only supports classification targets as an array of class probabilities.")
            class_indices.setdefault(label, []).append(i)
 
        per_class_limit = min(min(len(c) for c in class_indices.values()), dataset._size_limit // len(class_indices))
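The per_class_limit expression above takes the smaller of the rarest class count and an even share of the selection limit, which is why the new Note warns that the result can fall below the limit. A worked example with assumed tallies:

    # Hypothetical tallies: class label -> collected indices.
    class_indices = {0: list(range(40)), 1: list(range(25)), 2: list(range(60))}
    size_limit = 90  # stands in for dataset._size_limit

    per_class_limit = min(min(len(c) for c in class_indices.values()), size_limit // len(class_indices))
    print(per_class_limit)  # 25: the rarest class beats 90 // 3 = 30,
                            # so only 75 of the allowed 90 instances are kept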
dataeval/data/selections/_classfilter.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterable, Mapping, Sequence, Sized, TypeVar, cast
+from collections.abc import Iterable, Mapping, Sequence, Sized
+from typing import Any, Generic, TypeVar, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -14,12 +15,12 @@ from dataeval.utils._array import as_numpy
 
 class ClassFilter(Selection[Any]):
     """
-    Filter the dataset by class.
+    Select dataset indices based on class labels, keeping only those present in `classes`.
 
     Parameters
     ----------
     classes : Sequence[int]
-        The classes to filter by.
+        The sequence of classes to keep.
     filter_detections : bool, default True
         Whether to filter detections from targets for object detection and segmentation datasets.
     """
@@ -41,16 +42,16 @@ class ClassFilter(Selection[Any]):
             if isinstance(target, Array):
                 # Get the label for the image
                 label = int(np.argmax(as_numpy(target)))
-                # Check to see if the label is in the classes to filter for
+                # Check to see if the label is in the classes to keep
                 if label in self.classes:
-                    # Include the image
+                    # Include the image index
                     selection.append(idx)
-            elif isinstance(target, (ObjectDetectionTarget, SegmentationTarget)):
+            elif isinstance(target, ObjectDetectionTarget | SegmentationTarget):
                 # Get the set of labels from the target
                 labels = set(target.labels if isinstance(target.labels, Iterable) else [target.labels])
                 # Check to see if any labels are in the classes to filter for
                 if labels.intersection(self.classes):
-                    # Include the image
+                    # Include the image index
                     selection.append(idx)
                 # If we are filtering out other labels and there are other labels, add a subselection filter
                 if self.filter_detections and labels.difference(self.classes):
@@ -68,7 +69,7 @@ _TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)
 
 
 def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
-    if not isinstance(obj, (str, bytes, bytearray)) and isinstance(obj, (Sequence, Array)) and len(obj) == len(mask):
+    if not isinstance(obj, str | bytes | bytearray) and isinstance(obj, Sequence | Array) and len(obj) == len(mask):
        return obj[mask] if isinstance(obj, Array) else cast(_T, [item for i, item in enumerate(obj) if mask[i]])
    return obj
 
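A hedged usage sketch of the keep-by-class behavior documented above; detection_dataset is a placeholder for any object detection dataset:

    from dataeval.data import Select
    from dataeval.data.selections import ClassFilter

    # Keep only images containing class 1 or 3. With filter_detections=True
    # (the default), detections of other classes are also removed from each
    # kept image's target via the subselection filter shown above.
    selected = Select(detection_dataset, selections=[ClassFilter(classes=[1, 3], filter_detections=True)])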
dataeval/data/selections/_indices.py CHANGED
@@ -2,19 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 from dataeval.data._selection import Select, Selection, SelectionStage
 
 
 class Indices(Selection[Any]):
     """
-    Selects specific indices from the dataset.
+    Selects only the given indices from the dataset.
 
     Parameters
     ----------
     indices : Sequence[int]
-        The indices to select from the dataset.
+        The specific indices to select.
     """
 
     stage = SelectionStage.FILTER
dataeval/data/selections/_prioritize.py CHANGED
@@ -32,8 +32,8 @@ class _Clusters:
         self.cluster_centers = cluster_centers
         self.unique_labels = np.unique(labels)
 
-    def _dist2center(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
-        dist = np.zeros(self.labels.shape)
+    def _dist2center(self, X: NDArray[np.floating[Any]]) -> NDArray[np.float32]:
+        dist = np.zeros(self.labels.shape, dtype=np.float32)
         for lab in self.unique_labels:
             dist[self.labels == lab] = np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[lab, :], axis=1)
         return dist
@@ -75,6 +75,8 @@
 
 
 class _Sorter(ABC):
+    scores: NDArray[np.float32] | None = None
+
     @abstractmethod
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]: ...
 
@@ -95,11 +97,12 @@ class _KNNSorter(_Sorter):
 
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
         if reference is None:
-            dists = pairwise_distances(embeddings, embeddings)
+            dists = pairwise_distances(embeddings, embeddings).astype(np.float32)
             np.fill_diagonal(dists, np.inf)
         else:
-            dists = pairwise_distances(embeddings, reference)
-        return np.argsort(np.sort(dists, axis=1)[:, self._k])
+            dists = pairwise_distances(embeddings, reference).astype(np.float32)
+        self.scores = np.sort(dists, axis=1)[:, self._k]
+        return np.argsort(self.scores)
 
 
 class _KMeansSorter(_Sorter):
@@ -123,7 +126,8 @@ class _KMeansSorter(_Sorter):
 class _KMeansDistanceSorter(_KMeansSorter):
     def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
         clst = self._get_clusters(embeddings if reference is None else reference)
-        return np.argsort(clst._dist2center(embeddings))
+        self.scores = clst._dist2center(embeddings)
+        return np.argsort(self.scores)
 
 
 class _KMeansComplexitySorter(_KMeansSorter):
@@ -134,11 +138,11 @@ class _KMeansComplexitySorter(_KMeansSorter):
 
 class Prioritize(Selection[Any]):
     """
-    Prioritizes the dataset by sort order in the embedding space.
+    Sort the dataset indices in order of highest priority data in the embedding space.
 
     Parameters
     ----------
-    model : torch.nn.Module
+    model : torch.nn.Module | None
         Model to use for encoding images
     batch_size : int
         Batch size to use when encoding images
@@ -146,10 +150,23 @@
         Device to use for encoding images
     method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
         Method to use for prioritization
-    k : int | None, default None
-        Number of nearest neighbors to use for prioritization (knn only)
-    c : int | None, default None
-        Number of clusters to use for prioritization (kmeans only)
+    k : int or None, default None
+        Number of nearest neighbors to use for prioritization.
+        If None, uses the square root of the number of samples. Only used for method="knn", ignored otherwise.
+    c : int or None, default None
+        Number of clusters to use for prioritization. If None, uses the square root of the number of samples.
+        Only used for method="kmeans_*", ignored otherwise.
+
+    Notes
+    -----
+    1. `k` is only used for method ["knn"].
+    2. `c` is only used for methods ["kmeans_distance", "kmeans_complexity"].
+
+    Raises
+    ------
+    ValueError
+        If method not in supported methods
+
     """
 
     stage = SelectionStage.ORDER
@@ -157,55 +174,95 @@
     @overload
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["knn"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None: ...
 
     @overload
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
+        *,
+        c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module | None,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["class_balance"],
+        *,
+        k: int | None = None,
+        c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module | None,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified"],
         *,
+        k: int | None = None,
         c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None: ...
 
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | None,
         batch_size: int,
         device: DeviceLike | None,
         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         c: int | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> None:
-        if method not in ("knn", "kmeans_distance", "kmeans_complexity"):
+        if method not in {"knn", "kmeans_distance", "kmeans_complexity"}:
             raise ValueError(f"Invalid prioritization method: {method}")
+        if policy not in ("hard_first", "easy_first", "stratified", "class_balance"):
+            raise ValueError(f"Invalid selection policy: {policy}")
         self._model = model
         self._batch_size = batch_size
         self._device = device
         self._method = method
+        self._policy = policy
         self._embeddings: Embeddings | None = None
         self._reference: Embeddings | None = None
         self._k = k
         self._c = c
+        self.class_label = class_label
 
     @overload
     @classmethod
     def using(
         cls,
         method: Literal["knn"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> Prioritize: ...
 
     @overload
@@ -213,49 +270,72 @@
     def using(
         cls,
         method: Literal["kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         c: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
    ) -> Prioritize: ...
 
     @classmethod
     def using(
         cls,
         method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        policy: Literal["hard_first", "easy_first", "stratified", "class_balance"],
         *,
         k: int | None = None,
         c: int | None = None,
         embeddings: Embeddings | None = None,
         reference: Embeddings | None = None,
+        class_label: NDArray[np.integer[Any]] | None = None,
     ) -> Prioritize:
         """
-        Prioritizes the dataset by sort order in the embedding space using existing
-        embeddings and/or reference dataset embeddings.
+        Use precalculated embeddings to sort the dataset indices in order of
+        highest priority data in the embedding space.
 
         Parameters
         ----------
         method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
-            Method to use for prioritization
+            Method to use for sample scoring during prioritization.
+        policy : Literal["hard_first", "easy_first", "stratified", "class_balance"]
+            Selection policy for prioritizing scored samples.
         embeddings : Embeddings or None, default None
-            Embeddings to use for prioritization
+            Embeddings to use during prioritization. If None, `reference` must be set.
         reference : Embeddings or None, default None
-            Reference embeddings to prioritize relative to
+            Reference embeddings used to prioritize the calculated dataset embeddings relative to them.
+            If `embeddings` is None, this will be used instead.
         k : int or None, default None
-            Number of nearest neighbors to use for prioritization (knn only)
+            Number of nearest neighbors to use for prioritization.
+            If None, uses the square root of the number of samples. Only used for method="knn", ignored otherwise.
         c : int or None, default None
-            Number of clusters to use for prioritization (kmeans, cluster only)
+            Number of clusters to use for prioritization. If None, uses the square root of the number of samples.
+            Only used for method="kmeans_*", ignored otherwise.
 
         Notes
         -----
-        At least one of `embeddings` or `reference` must be provided.
+        1. `k` is only used for method ["knn"].
+        2. `c` is only used for methods ["kmeans_distance", "kmeans_complexity"].
+
+        Raises
+        ------
+        ValueError
+            If both `embeddings` and `reference` are None
+
         """
         emb_params: Embeddings | None = embeddings if embeddings is not None else reference
         if emb_params is None:
             raise ValueError("Must provide at least embeddings or reference embeddings.")
-        prioritize = Prioritize(emb_params._model, emb_params.batch_size, emb_params.device, method)
-        prioritize._k = k
-        prioritize._c = c
+        prioritize = Prioritize(
+            emb_params._model,
+            emb_params.batch_size,
+            emb_params.device,
+            method,
+            policy,
+            k=k,
+            c=c,
+            class_label=class_label,
+        )
         prioritize._embeddings = embeddings
         prioritize._reference = reference
         return prioritize
@@ -265,9 +345,148 @@
             return _KNNSorter(samples, self._k)
         if self._method == "kmeans_distance":
             return _KMeansDistanceSorter(samples, self._c)
-        # self._method == "kmeans_complexity"
         return _KMeansComplexitySorter(samples, self._c)
 
+    def _compute_bin_extents(self, scores: NDArray[np.floating[Any]]) -> tuple[np.float64, np.float64]:
+        """
+        Compute min/max bin extents for `scores`, padding outward by epsilon.
+
+        Parameters
+        ----------
+        scores : NDArray[np.float64]
+            Array of floats to bin
+
+        Returns
+        -------
+        tuple[np.float64, np.float64]
+            (min, max) scores padded outward by epsilon = 1e-6 * range(scores).
+        """
+        # ensure binning captures all samples in range
+        scores = scores.astype(np.float64)
+        min_score = np.min(scores)
+        max_score = np.max(scores)
+        rng = max_score - min_score
+        eps = rng * 1e-6
+        return min_score - eps, max_score + eps
+
+    def _select_ordered_by_label(self, labels: NDArray[np.integer[Any]]) -> NDArray[np.intp]:
+        """
+        Given labels (class, group, bin, etc.) sorted with decreasing priority,
+        rerank so that we have approximate class/group balance. This function
+        is used for both stratified and class-balance rerank methods.
+
+        We could require and return prioritization scores and re-sorted class
+        labels, but it is more compact to return indices. This allows us to
+        resort other quantities, as well, outside the function.
+
+        Parameters
+        ----------
+        labels : NDArray[np.integer[Any]]
+            Class label or group ID per instance in order of decreasing priority
+
+        Returns
+        -------
+        NDArray[np.intp]
+            Indices that sort samples according to uniform class balance or
+            group membership while respecting priority of the initial ordering.
+        """
+        labels = np.array(labels)
+        num_samp = labels.shape[0]
+        selected = np.zeros(num_samp, dtype=bool)
+        # preserve ordering
+        _, index = np.unique(labels, return_index=True)
+        u_lab = labels[np.sort(index)]
+        n_cls = len(u_lab)
+
+        resort_inds = []
+        cls_idx = 0
+        n = 0
+        while len(resort_inds) < num_samp:
+            c0 = u_lab[cls_idx % n_cls]
+            samples_available = (~selected) * (labels == c0)
+            if any(samples_available):
+                i0 = np.argmax(samples_available)  # selects first occurrence
+                resort_inds.append(i0)
+                selected[i0] = True
+            cls_idx += 1
+            n += 1
+        return np.array(resort_inds).astype(np.intp)
+
+    def _stratified_rerank(
+        self,
+        scores: NDArray[np.floating[Any]],
+        indices: NDArray[np.integer[Any]],
+        num_bins: int = 50,
+    ) -> NDArray[np.intp]:
+        """
+        Re-rank samples by sampling uniformly over binned scores. This
+        de-weights selection of samples with similar scores and encourages both
+        prototypical and challenging samples near the decision boundary.
+
+        Parameters
+        ----------
+        scores : NDArray[float]
+            Prioritization scores sorted in order of decreasing priority
+        indices : NDArray[int]
+            Indices to be re-sorted according to stratified sampling of scores.
+            Indices are ordered by decreasing priority.
+        num_bins : int
+            Number of bins over which to stratify the scores.
+
+        Returns
+        -------
+        NDArray[int]
+            Re-ranked indices
+        """
+        mn, mx = self._compute_bin_extents(scores)
+        bin_edges = np.linspace(mn, mx, num=num_bins + 1, endpoint=True)
+        bin_label = np.digitize(scores, bin_edges)
+        srt_inds = self._select_ordered_by_label(bin_label)
+        return indices[srt_inds].astype(np.intp)
+
+    def _rerank(
+        self,
+        indices: NDArray[np.integer[Any]],
+    ) -> NDArray[np.intp]:
+        """
+        Re-rank samples according to the re-rank policy, self._policy. Values
+        from the 'indices' and optional 'scores' and 'class_label' variables are
+        assumed to correspond by index, i.e. indices[i], scores[i], and
+        class_label[i] should all refer to the same instance in the dataset.
+
+        Note: indices are assumed to be sorted with easy/prototypical samples
+        first, which is increasing order for most prioritization scoring methods.
+
+        Parameters
+        ----------
+        indices : NDArray[np.intp]
+            Indices that sort samples by increasing prioritization score, where
+            low scores indicate high prototypicality ('easy') and high scores
+            indicate challenging samples near the decision boundary ('hard').
+        """
+
+        if self._policy == "easy_first":
+            return indices.astype(np.intp)
+        if self._policy == "stratified":
+            if self._sorter.scores is None:
+                raise ValueError(
+                    "Prioritization scores are necessary in order to use "
+                    "stratified re-rank. Use 'knn' or 'kmeans_distance' "
+                    "methods to populate scores."
+                )
+            return self._stratified_rerank(self._sorter.scores[::-1], indices[::-1])
+        if self._policy == "class_balance":
+            if self.class_label is None:
+                raise ValueError("Class labels are necessary in order to use class_balance re-rank")
+            indices_reversed = self._select_ordered_by_label(self.class_label[indices[::-1]]).astype(np.int32)
+            n = len(indices_reversed)
+            return (n - 1 - indices_reversed).astype(np.intp)
+        # self._policy == "hard_first" (default)
+        return indices[::-1].astype(np.intp)
+
     def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
         emb: NDArray[Any] = embeddings.to_numpy(selection)
         emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
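The round-robin behavior of _select_ordered_by_label is easiest to see on a toy input. A standalone re-implementation sketch (not the package API) that interleaves labels while preserving each label's internal priority order:

    import numpy as np

    # Labels arrive sorted by decreasing priority; cycle through the labels,
    # taking each label's highest-priority remaining sample in turn.
    labels = np.array([0, 0, 1, 0, 2, 1])
    selected = np.zeros(len(labels), dtype=bool)
    u_lab = labels[np.sort(np.unique(labels, return_index=True)[1])]  # [0, 1, 2]

    order = []
    cls_idx = 0
    while len(order) < len(labels):
        available = (~selected) & (labels == u_lab[cls_idx % len(u_lab)])
        if available.any():
            i0 = int(np.argmax(available))  # first remaining sample of this label
            order.append(i0)
            selected[i0] = True
        cls_idx += 1

    print(order)  # [0, 2, 4, 1, 5, 3] -> labels interleave as 0, 1, 2, 0, 1, 0

The stratified policy reuses the same routine with np.digitize bin IDs in place of class labels.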
@@ -290,4 +509,5 @@
         emb = self._to_normalized_ndarray(embeddings, dataset._selection)
         ref = None if self._reference is None else self._to_normalized_ndarray(self._reference)
         # Sort indices
-        dataset._selection = self._sorter._sort(emb, ref).tolist()
+        indices = self._sorter._sort(emb, ref)
+        dataset._selection = indices[self._rerank(indices)].astype(int).tolist()
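Putting the new policy argument and rerank stage together, a hedged usage sketch (embeddings and dataset are placeholders for a precomputed dataeval Embeddings object and its source dataset):

    from dataeval.data import Select
    from dataeval.data.selections import Prioritize

    # Score with k-nearest-neighbor distances, then place hard samples
    # (those near the decision boundary) first.
    hard_first = Prioritize.using("knn", "hard_first", k=16, embeddings=embeddings)

    # The stratified policy samples uniformly over binned scores and needs a
    # score-producing method ("knn" or "kmeans_distance").
    stratified = Prioritize.using("kmeans_distance", "stratified", c=32, embeddings=embeddings)

    prioritized = Select(dataset, selections=[hard_first])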
dataeval/data/selections/_reverse.py CHANGED
@@ -9,7 +9,7 @@ from dataeval.data._selection import Select, Selection, SelectionStage
 
 class Reverse(Selection[Any]):
     """
-    Reverse the selection order of the dataset.
+    Select dataset indices in reverse order.
     """
 
     stage = SelectionStage.ORDER
dataeval/data/selections/_shuffle.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import numpy as np
 from numpy.random import BitGenerator, Generator, SeedSequence
@@ -15,12 +16,12 @@ from dataeval.utils._array import as_numpy
 
 class Shuffle(Selection[Any]):
     """
-    Shuffle the dataset using a seed.
+    Select dataset indices in a random order.
 
     Parameters
     ----------
     seed : int, ArrayLike, SeedSequence, BitGenerator, Generator or None, default None
-        Seed for the random number generator.
+        Seed for the random number generator. If None, results are not reproducible.
 
     See Also
     --------
@@ -33,7 +34,7 @@ class Shuffle(Selection[Any]):
     def __init__(
         self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None
     ) -> None:
-        self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed
+        self.seed = as_numpy(seed) if isinstance(seed, Sequence | Array) else seed
 
     def __call__(self, dataset: Select[Any]) -> None:
         rng = np.random.default_rng(self.seed)
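The updated seed documentation mirrors np.random.default_rng semantics; a brief sketch of the reproducibility trade-off it calls out (dataset is a placeholder):

    from dataeval.data import Select
    from dataeval.data.selections import Shuffle

    shuffled = Select(dataset, selections=[Shuffle(seed=42)])  # same order every run
    unseeded = Select(dataset, selections=[Shuffle()])         # order varies run to run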
dataeval/detectors/drift/_base.py CHANGED
@@ -12,8 +12,9 @@ __all__ = []
 
 import math
 from abc import abstractmethod
+from collections.abc import Callable
 from functools import wraps
-from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
+from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
 
 import numpy as np
 from numpy.typing import NDArray