dataeval 0.85.0__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +65 -42
  7. dataeval/data/_selection.py +2 -3
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +6 -8
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/__init__.py +4 -1
  14. dataeval/detectors/drift/_base.py +4 -5
  15. dataeval/detectors/drift/_mmd.py +3 -6
  16. dataeval/detectors/drift/_mvdc.py +92 -0
  17. dataeval/detectors/drift/_nml/__init__.py +6 -0
  18. dataeval/detectors/drift/_nml/_base.py +70 -0
  19. dataeval/detectors/drift/_nml/_chunk.py +396 -0
  20. dataeval/detectors/drift/_nml/_domainclassifier.py +181 -0
  21. dataeval/detectors/drift/_nml/_result.py +97 -0
  22. dataeval/detectors/drift/_nml/_thresholds.py +269 -0
  23. dataeval/detectors/linters/outliers.py +7 -7
  24. dataeval/metrics/bias/_parity.py +10 -13
  25. dataeval/metrics/estimators/_divergence.py +2 -4
  26. dataeval/metrics/stats/_base.py +103 -42
  27. dataeval/metrics/stats/_boxratiostats.py +21 -19
  28. dataeval/metrics/stats/_dimensionstats.py +14 -10
  29. dataeval/metrics/stats/_hashstats.py +1 -1
  30. dataeval/metrics/stats/_pixelstats.py +6 -6
  31. dataeval/metrics/stats/_visualstats.py +3 -3
  32. dataeval/outputs/__init__.py +2 -1
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +27 -31
  35. dataeval/outputs/_drift.py +60 -0
  36. dataeval/outputs/_linters.py +12 -17
  37. dataeval/outputs/_stats.py +83 -29
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +32 -20
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +19 -11
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +3 -2
  62. dataeval-0.86.1.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.85.0.dist-info/RECORD +0 -107
  65. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.85.0"
11
+ __version__ = "0.86.1"
12
12
 
13
13
  import logging
14
14
 
dataeval/_log.py CHANGED
@@ -8,7 +8,7 @@ class LogMessage:
8
8
  Deferred message callback for logging expensive messages.
9
9
  """
10
10
 
11
- def __init__(self, fn: Callable[..., str]):
11
+ def __init__(self, fn: Callable[..., str]) -> None:
12
12
  self._fn = fn
13
13
  self._str = None
14
14
 
dataeval/config.py CHANGED
@@ -4,10 +4,10 @@ Global configuration settings for DataEval.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "use_max_processes", "DeviceLike"]
8
8
 
9
9
  import sys
10
- from typing import Union
10
+ from typing import Any, Union
11
11
 
12
12
  if sys.version_info >= (3, 10):
13
13
  from typing import TypeAlias
@@ -78,8 +78,7 @@ def get_device(override: DeviceLike | None = None) -> torch.device:
78
78
  if override is None:
79
79
  global _device
80
80
  return torch.get_default_device() if _device is None else _device
81
- else:
82
- return _todevice(override)
81
+ return _todevice(override)
83
82
 
84
83
 
85
84
  def set_max_processes(processes: int | None) -> None:
@@ -112,6 +111,24 @@ def get_max_processes() -> int | None:
112
111
  return _processes
113
112
 
114
113
 
114
+ class MaxProcessesContextManager:
115
+ def __init__(self, processes: int) -> None:
116
+ self._processes = processes
117
+
118
+ def __enter__(self) -> None:
119
+ global _processes
120
+ self._old = _processes
121
+ set_max_processes(self._processes)
122
+
123
+ def __exit__(self, *args: tuple[Any, ...]) -> None:
124
+ global _processes
125
+ _processes = self._old
126
+
127
+
128
+ def use_max_processes(processes: int) -> MaxProcessesContextManager:
129
+ return MaxProcessesContextManager(processes)
130
+
131
+
115
132
  def set_seed(seed: int | None, all_generators: bool = False) -> None:
116
133
  """
117
134
  Sets the seed for use by classes that allow for a random state or seed.
dataeval/data/_embeddings.py CHANGED
@@ -144,8 +144,7 @@ class Embeddings:
144
144
  """
145
145
  if indices is not None:
146
146
  return torch.vstack(list(self._batch(indices))).to(self.device)
147
- else:
148
- return self[:]
147
+ return self[:]
149
148
 
150
149
  def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
151
150
  """
@@ -248,6 +247,7 @@ class Embeddings:
248
247
  _logger.log(logging.DEBUG, f"Saved embeddings cache from {path}")
249
248
  except Exception as e:
250
249
  _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
250
+ raise e
251
251
 
252
252
  @classmethod
253
253
  def load(cls, path: Path | str) -> Embeddings:
dataeval/data/_images.py CHANGED
@@ -73,15 +73,14 @@ class Images(Generic[T]):
73
73
  def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
74
74
  if isinstance(key, slice):
75
75
  return [self._get_image(k) for k in range(len(self._dataset))[key]]
76
- elif hasattr(key, "__int__"):
76
+ if hasattr(key, "__int__"):
77
77
  return self._get_image(int(key))
78
78
  raise TypeError(f"Key must be integers or slices, not {type(key)}")
79
79
 
80
80
  def _get_image(self, index: int) -> T:
81
81
  if self._is_tuple_datum:
82
82
  return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
83
- else:
84
- return cast(Dataset[T], self._dataset)[index]
83
+ return cast(Dataset[T], self._dataset)[index]
85
84
 
86
85
  def __iter__(self) -> Iterator[T]:
87
86
  for i in range(len(self._dataset)):
dataeval/data/_metadata.py CHANGED
@@ -191,7 +191,12 @@ class Metadata:
191
191
  self._process()
192
192
  return self._image_indices
193
193
 
194
- def _collate(self, force: bool = False):
194
+ @property
195
+ def image_count(self) -> int:
196
+ self._process()
197
+ return int(self._image_indices.max() + 1)
198
+
199
+ def _collate(self, force: bool = False) -> None:
195
200
  if self._collated and not force:
196
201
  return
197
202
 
@@ -238,7 +243,7 @@ class Metadata:
238
243
  self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
239
244
  self._collated = True
240
245
 
241
- def _merge(self, force: bool = False):
246
+ def _merge(self, force: bool = False) -> None:
242
247
  if self._merged is not None and not force:
243
248
  return
244
249
 
@@ -261,48 +266,26 @@ class Metadata:
261
266
  "Metadata dictionary needs to be a single dictionary whose values "
262
267
  "are arraylike containing the metadata on a per image or per object basis."
263
268
  )
264
- else:
265
- check_length = len(v) if check_length is None else check_length
266
- if check_length != len(v):
267
- raise ValueError(
268
- "The lists/arrays in the metadata dict have varying lengths. "
269
- "Metadata requires them to be uniform in length."
270
- )
269
+ check_length = len(v) if check_length is None else check_length
270
+ if check_length != len(v):
271
+ raise ValueError(
272
+ "The lists/arrays in the metadata dict have varying lengths. "
273
+ "Metadata requires them to be uniform in length."
274
+ )
271
275
  if len(self._class_labels) != check_length:
272
276
  raise ValueError(
273
277
  f"The length of the label array {len(self._class_labels)} is not the same as "
274
278
  f"the length of the metadata arrays {check_length}."
275
279
  )
276
280
 
277
- def _process(self, force: bool = False) -> None:
278
- if self._processed and not force:
279
- return
280
-
281
- # Create image indices from targets
282
- self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
283
-
284
- # Validate the metadata dimensions
285
- self._validate()
286
-
287
- # Include specified metadata keys
288
- if self.include:
289
- metadata = {i: self.merged[i] for i in self.include if i in self.merged}
290
- continuous_factor_bins = (
291
- {i: self.continuous_factor_bins[i] for i in self.include if i in self.continuous_factor_bins}
292
- if self.continuous_factor_bins
293
- else {}
294
- )
295
- else:
296
- metadata = self.merged
297
- continuous_factor_bins = dict(self.continuous_factor_bins) if self.continuous_factor_bins else {}
298
- for k in self.exclude:
299
- metadata.pop(k, None)
300
- continuous_factor_bins.pop(k, None)
301
-
302
- # Remove generated "_image_index" if present
303
- if "_image_index" in metadata:
304
- metadata.pop("_image_index", None)
281
+ def _filter(self, d: Mapping[str, Any]) -> dict[str, Any]:
282
+ return (
283
+ {k: d[k] for k in self.include if k in d} if self.include else {k: d[k] for k in d if k not in self.exclude}
284
+ )
305
285
 
286
+ def _split_continuous_discrete(
287
+ self, metadata: dict[str, NDArray[Any]], continuous_factor_bins: dict[str, int | Sequence[float]]
288
+ ) -> tuple[dict[str, NDArray[Any]], dict[str, NDArray[np.int64]]]:
306
289
  # Bin according to user supplied bins
307
290
  continuous_metadata = {}
308
291
  discrete_metadata = {}
@@ -341,6 +324,28 @@ class Metadata:
341
324
  else:
342
325
  _, discrete_metadata[key] = np.unique(data, return_inverse=True)
343
326
 
327
+ return continuous_metadata, discrete_metadata
328
+
329
+ def _process(self, force: bool = False) -> None:
330
+ if self._processed and not force:
331
+ return
332
+
333
+ # Create image indices from targets
334
+ self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
335
+
336
+ # Validate the metadata dimensions
337
+ self._validate()
338
+
339
+ # Filter the merged metadata and continuous factor bins
340
+ metadata = self._filter(self.merged)
341
+ continuous_factor_bins = self._filter(self.continuous_factor_bins)
342
+
343
+ # Remove generated "_image_index" if present
344
+ metadata.pop("_image_index", None)
345
+
346
+ # Split the metadata into continuous and discrete
347
+ continuous_metadata, discrete_metadata = self._split_continuous_discrete(metadata, continuous_factor_bins)
348
+
344
349
  # Split out the dictionaries into the keys and values
345
350
  self._discrete_factor_names = list(discrete_metadata.keys())
346
351
  self._discrete_data = (
@@ -358,13 +363,31 @@ class Metadata:
358
363
  self._processed = True
359
364
 
360
365
  def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
366
+ """
367
+ Add additional factors to the metadata.
368
+
369
+ The number of measures per factor must match the number of images
370
+ in the dataset or the number of detections in the dataset.
371
+
372
+ Parameters
373
+ ----------
374
+ factors : Mapping[str, ArrayLike]
375
+ Dictionary of factors to add to the metadata.
376
+ """
361
377
  self._merge()
362
- self._processed = False
363
- target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
364
- if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
378
+
379
+ targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
380
+ images = self.image_count
381
+ lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
382
+ targets_match = all(f == targets for f in lengths.values())
383
+ images_match = targets_match if images == targets else all(f == images for f in lengths.values())
384
+ if not targets_match and not images_match:
365
385
  raise ValueError(
366
386
  "The lists/arrays in the provided factors have a different length than the current metadata factors."
367
387
  )
368
- merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
388
+ merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
369
389
  for k, v in factors.items():
370
- merged[k] = v
390
+ v = as_numpy(v)
391
+ merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
392
+
393
+ self._processed = False
dataeval/data/_selection.py CHANGED
@@ -110,8 +110,7 @@ class Select(AnnotatedDataset[_TDatum]):
110
110
  grouped: dict[int, list[Selection[_TDatum]]] = {}
111
111
  for selection in selections_list:
112
112
  grouped.setdefault(selection.stage, []).append(selection)
113
- selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
114
- return selection_list
113
+ return [selection for category in sorted(grouped) for selection in grouped[category]]
115
114
 
116
115
  def _apply_selections(self) -> None:
117
116
  for selection in self._selections:
@@ -120,7 +119,7 @@ class Select(AnnotatedDataset[_TDatum]):
120
119
 
121
120
  def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
122
121
  for subselection, indices in self._subselections:
123
- datum = subselection(datum) if index in indices else datum
122
+ datum = subselection(datum) if self._selection[index] in indices else datum
124
123
  return datum
125
124
 
126
125
  def __getitem__(self, index: int) -> _TDatum:
dataeval/data/_split.py CHANGED
@@ -23,7 +23,7 @@ _logger = logging.getLogger(__name__)
23
23
  class KFoldSplitter(Protocol):
24
24
  """Protocol covering sklearn KFold variant splitters"""
25
25
 
26
- def __init__(self, n_splits: int): ...
26
+ def __init__(self, n_splits: int) -> None: ...
27
27
  def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...
28
28
 
29
29
 
@@ -209,8 +209,7 @@ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np
209
209
  split_set = set(split_on)
210
210
  indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
211
211
  binned_features = metadata.discrete_data[:, indices]
212
- group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
213
- return group_ids
212
+ return np.unique(binned_features, axis=0, return_inverse=True)[1]
214
213
 
215
214
 
216
215
  def make_splits(
dataeval/data/_targets.py CHANGED
@@ -24,11 +24,13 @@ class Targets:
24
24
  labels : NDArray[np.intp]
25
25
  Labels (N,) for N images or objects
26
26
  scores : NDArray[np.float32]
27
- Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
27
+ Probability scores (N, M) for N images of M classes or confidence score (N,) of objects
28
28
  bboxes : NDArray[np.float32] | None
29
- Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
29
+ Bounding boxes (N, 4) for N objects in (x0, y0, x1, y1) format
30
30
  source : NDArray[np.intp] | None
31
31
  Source image index (N,) for N objects
32
+ size : int
33
+ Count of objects
32
34
  """
33
35
 
34
36
  labels: NDArray[np.intp]
@@ -55,13 +57,16 @@ class Targets:
55
57
  )
56
58
 
57
59
  if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
58
- raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
60
+ raise ValueError("Bounding boxes must be in (x0, y0, x1, y1) format.")
61
+
62
+ @property
63
+ def size(self) -> int:
64
+ return len(self.labels)
59
65
 
60
66
  def __len__(self) -> int:
61
67
  if self.source is None:
62
68
  return len(self.labels)
63
- else:
64
- return len(np.unique(self.source))
69
+ return len(np.unique(self.source))
65
70
 
66
71
  def __getitem__(self, idx: int, /) -> Targets:
67
72
  if self.source is None or self.bboxes is None:
@@ -71,14 +76,13 @@ class Targets:
71
76
  None,
72
77
  None,
73
78
  )
74
- else:
75
- mask = np.where(self.source == idx, True, False)
76
- return Targets(
77
- np.atleast_1d(self.labels[mask]),
78
- np.atleast_1d(self.scores[mask]),
79
- np.atleast_2d(self.bboxes[mask]),
80
- np.atleast_1d(self.source[mask]),
81
- )
79
+ mask = np.where(self.source == idx, True, False)
80
+ return Targets(
81
+ np.atleast_1d(self.labels[mask]),
82
+ np.atleast_1d(self.scores[mask]),
83
+ np.atleast_2d(self.bboxes[mask]),
84
+ np.atleast_1d(self.source[mask]),
85
+ )
82
86
 
83
87
  def __iter__(self) -> Iterator[Targets]:
84
88
  for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
dataeval/data/selections/_classfilter.py CHANGED
@@ -10,7 +10,6 @@ from numpy.typing import NDArray
10
10
  from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
11
11
  from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
12
12
  from dataeval.utils._array import as_numpy
13
- from dataeval.utils.data.metadata import flatten
14
13
 
15
14
 
16
15
  class ClassFilter(Selection[Any]):
@@ -69,11 +68,8 @@ _TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)
69
68
 
70
69
 
71
70
  def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
72
- if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
73
- if isinstance(obj, Array):
74
- return obj[mask]
75
- elif isinstance(obj, Sequence):
76
- return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
71
+ if not isinstance(obj, (str, bytes, bytearray)) and isinstance(obj, (Sequence, Array)) and len(obj) == len(mask):
72
+ return obj[mask] if isinstance(obj, Array) else cast(_T, [item for i, item in enumerate(obj) if mask[i]])
77
73
  return obj
78
74
 
79
75
 
@@ -96,13 +92,15 @@ class ClassFilterSubSelection(Subselection[Any]):
96
92
  def __init__(self, classes: Sequence[int]) -> None:
97
93
  self.classes = classes
98
94
 
95
+ def _filter(self, d: dict[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
96
+ return {k: self._filter(v, mask) if isinstance(v, dict) else _try_mask_object(v, mask) for k, v in d.items()}
97
+
99
98
  def __call__(self, datum: _TDatum) -> _TDatum:
100
99
  # build a mask for any arrays
101
100
  image, target, metadata = datum
102
101
 
103
102
  mask = np.isin(as_numpy(target.labels), self.classes)
104
- flattened_metadata = flatten(metadata)[0]
105
- filtered_metadata = {k: _try_mask_object(v, mask) for k, v in flattened_metadata.items()}
103
+ filtered_metadata = self._filter(metadata, mask)
106
104
 
107
105
  # return a masked datum
108
106
  filtered_datum = image, ClassFilterTarget(target, mask), filtered_metadata
dataeval/data/selections/_prioritize.py CHANGED
@@ -99,8 +99,7 @@ class _KNNSorter(_Sorter):
99
99
  np.fill_diagonal(dists, np.inf)
100
100
  else:
101
101
  dists = pairwise_distances(embeddings, reference)
102
- inds = np.argsort(np.sort(dists, axis=1)[:, self._k])
103
- return inds
102
+ return np.argsort(np.sort(dists, axis=1)[:, self._k])
104
103
 
105
104
 
106
105
  class _KMeansSorter(_Sorter):
@@ -124,15 +123,13 @@ class _KMeansSorter(_Sorter):
124
123
  class _KMeansDistanceSorter(_KMeansSorter):
125
124
  def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
126
125
  clst = self._get_clusters(embeddings if reference is None else reference)
127
- inds = np.argsort(clst._dist2center(embeddings))
128
- return inds
126
+ return np.argsort(clst._dist2center(embeddings))
129
127
 
130
128
 
131
129
  class _KMeansComplexitySorter(_KMeansSorter):
132
130
  def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
133
131
  clst = self._get_clusters(embeddings if reference is None else reference)
134
- inds = clst._sort_by_weights(embeddings)
135
- return inds
132
+ return clst._sort_by_weights(embeddings)
136
133
 
137
134
 
138
135
  class Prioritize(Selection[Any]):
@@ -266,10 +263,10 @@ class Prioritize(Selection[Any]):
266
263
  def _get_sorter(self, samples: int) -> _Sorter:
267
264
  if self._method == "knn":
268
265
  return _KNNSorter(samples, self._k)
269
- elif self._method == "kmeans_distance":
266
+ if self._method == "kmeans_distance":
270
267
  return _KMeansDistanceSorter(samples, self._c)
271
- else: # self._method == "kmeans_complexity"
272
- return _KMeansComplexitySorter(samples, self._c)
268
+ # self._method == "kmeans_complexity"
269
+ return _KMeansComplexitySorter(samples, self._c)
273
270
 
274
271
  def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
275
272
  emb: NDArray[Any] = embeddings.to_numpy(selection)
dataeval/data/selections/_shuffle.py CHANGED
@@ -30,7 +30,9 @@ class Shuffle(Selection[Any]):
30
30
  seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
31
31
  stage = SelectionStage.ORDER
32
32
 
33
- def __init__(self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None):
33
+ def __init__(
34
+ self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None
35
+ ) -> None:
34
36
  self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed
35
37
 
36
38
  def __call__(self, dataset: Select[Any]) -> None:
dataeval/detectors/drift/__init__.py CHANGED
@@ -7,6 +7,8 @@ __all__ = [
7
7
  "DriftKS",
8
8
  "DriftMMD",
9
9
  "DriftMMDOutput",
10
+ "DriftMVDC",
11
+ "DriftMVDCOutput",
10
12
  "DriftOutput",
11
13
  "DriftUncertainty",
12
14
  "UpdateStrategy",
@@ -18,5 +20,6 @@ from dataeval.detectors.drift._base import UpdateStrategy
18
20
  from dataeval.detectors.drift._cvm import DriftCVM
19
21
  from dataeval.detectors.drift._ks import DriftKS
20
22
  from dataeval.detectors.drift._mmd import DriftMMD
23
+ from dataeval.detectors.drift._mvdc import DriftMVDC
21
24
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
- from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
25
+ from dataeval.outputs._drift import DriftMMDOutput, DriftMVDCOutput, DriftOutput
dataeval/detectors/drift/_base.py CHANGED
@@ -13,7 +13,7 @@ __all__ = []
13
13
  import math
14
14
  from abc import abstractmethod
15
15
  from functools import wraps
16
- from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
16
+ from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
17
17
 
18
18
  import numpy as np
19
19
  from numpy.typing import NDArray
@@ -40,7 +40,7 @@ def update_strategy(fn: Callable[..., R]) -> Callable[..., R]:
40
40
  """Decorator to update x_ref with x using selected update methodology"""
41
41
 
42
42
  @wraps(fn)
43
- def _(self: BaseDrift, data: Embeddings | Array, *args, **kwargs) -> R:
43
+ def _(self: BaseDrift, data: Embeddings | Array, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> R:
44
44
  output = fn(self, data, *args, **kwargs)
45
45
 
46
46
  # update reference dataset
@@ -184,7 +184,7 @@ class BaseDriftUnivariate(BaseDrift):
184
184
  threshold = self.p_val / self.n_features
185
185
  drift_pred = bool((p_vals < threshold).any())
186
186
  return drift_pred, threshold
187
- elif self.correction == "fdr":
187
+ if self.correction == "fdr":
188
188
  n = p_vals.shape[0]
189
189
  i = np.arange(n) + np.int_(1)
190
190
  p_sorted = np.sort(p_vals)
@@ -195,8 +195,7 @@ class BaseDriftUnivariate(BaseDrift):
195
195
  except ValueError: # sorted p-values not below thresholds
196
196
  return bool(below_threshold.any()), q_threshold.min()
197
197
  return bool(below_threshold.any()), q_threshold[idx_threshold]
198
- else:
199
- raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
198
+ raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
200
199
 
201
200
  @set_metadata
202
201
  @update_strategy
dataeval/detectors/drift/_mmd.py CHANGED
@@ -95,8 +95,7 @@ class DriftMMD(BaseDrift):
95
95
  k_xy = self._kernel(x, y)
96
96
  k_xx = self._k_xx if self._k_xx is not None and self.update_strategy is None else self._kernel(x, x)
97
97
  k_yy = self._kernel(y, y)
98
- kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
99
- return kernel_mat
98
+ return torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
100
99
 
101
100
  def score(self, data: Embeddings | Array) -> tuple[float, float, float]:
102
101
  """
@@ -205,8 +204,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
205
204
  n = min(x.shape[0], y.shape[0])
206
205
  n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
207
206
  n_median = n + (torch.prod(torch.as_tensor(dist.shape)) - n) // 2 - 1
208
- sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
209
- return sigma
207
+ return (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
210
208
 
211
209
 
212
210
  class GaussianRBF(torch.nn.Module):
@@ -310,5 +308,4 @@ def mmd2_from_kernel_matrix(
310
308
  kernel_mat = kernel_mat[idx][:, idx]
311
309
  k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
312
310
  c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
313
- mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
314
- return mmd2
311
+ return c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
dataeval/detectors/drift/_mvdc.py ADDED
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from numpy.typing import ArrayLike
8
+
9
+ if TYPE_CHECKING:
10
+ from typing import Self
11
+ else:
12
+ from typing_extensions import Self
13
+
14
+ from dataeval.detectors.drift._nml._chunk import CountBasedChunker, SizeBasedChunker
15
+ from dataeval.detectors.drift._nml._domainclassifier import DomainClassifierCalculator
16
+ from dataeval.detectors.drift._nml._thresholds import ConstantThreshold
17
+ from dataeval.outputs._drift import DriftMVDCOutput
18
+ from dataeval.utils._array import flatten
19
+
20
+
21
+ class DriftMVDC:
22
+ """Multivariant Domain Classifier
23
+
24
+ Parameters
25
+ ----------
26
+ n_folds : int, default 5
27
+ Number of cross-validation (CV) folds.
28
+ chunk_size : int or None, default None
29
+ Number of samples in a chunk used in CV, will get one metric & prediction per chunk.
30
+ chunk_count : int or None, default None
31
+ Number of total chunks used in CV, will get one metric & prediction per chunk.
32
+ threshold : Tuple[float, float], default (0.45, 0.65)
33
+ (lower, upper) metric bounds on roc_auc for identifying :term:`drift<Drift>`.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ n_folds: int = 5,
39
+ chunk_size: int | None = None,
40
+ chunk_count: int | None = None,
41
+ threshold: tuple[float, float] = (0.45, 0.65),
42
+ ) -> None:
43
+ self.threshold: tuple[float, float] = max(0.0, min(threshold)), min(1.0, max(threshold))
44
+ chunker = (
45
+ CountBasedChunker(10 if chunk_count is None else chunk_count)
46
+ if chunk_size is None
47
+ else SizeBasedChunker(chunk_size)
48
+ )
49
+ self._calc = DomainClassifierCalculator(
50
+ cv_folds_num=n_folds,
51
+ chunker=chunker,
52
+ threshold=ConstantThreshold(lower=self.threshold[0], upper=self.threshold[1]),
53
+ )
54
+
55
+ def fit(self, x_ref: ArrayLike) -> Self:
56
+ """
57
+ Fit the domain classifier on the training dataframe
58
+
59
+ Parameters
60
+ ----------
61
+ x_ref : ArrayLike
62
+ Reference data with dim[n_samples, n_features].
63
+
64
+ Returns
65
+ -------
66
+ Self
67
+
68
+ """
69
+ # for 1D input, assume that is 1 sample: dim[1,n_features]
70
+ self.x_ref: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x_ref))))
71
+ self.n_features: int = self.x_ref.shape[-1]
72
+ self._calc.fit(self.x_ref)
73
+ return self
74
+
75
+ def predict(self, x: ArrayLike) -> DriftMVDCOutput:
76
+ """
77
+ Perform :term:`inference<Inference>` on the test dataframe
78
+
79
+ Parameters
80
+ ----------
81
+ x : ArrayLike
82
+ Test (analysis) data with dim[n_samples, n_features].
83
+
84
+ Returns
85
+ -------
86
+ DomainClassifierDriftResult
87
+ """
88
+ self.x_test: pd.DataFrame = pd.DataFrame(flatten(np.atleast_2d(np.asarray(x))))
89
+ if self.x_test.shape[-1] != self.n_features:
90
+ raise ValueError("Reference and test embeddings have different number of features")
91
+
92
+ return self._calc.calculate(self.x_test)
dataeval/detectors/drift/_nml/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ Source code derived from NannyML 0.13.0
3
+ https://github.com/NannyML/nannyml/
4
+
5
+ Licensed under Apache Software License (Apache 2.0)
6
+ """