dataeval 0.69.4__py3-none-any.whl → 0.70.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +8 -8
  2. dataeval/_internal/datasets.py +235 -131
  3. dataeval/_internal/detectors/clusterer.py +2 -0
  4. dataeval/_internal/detectors/drift/base.py +7 -8
  5. dataeval/_internal/detectors/drift/mmd.py +4 -4
  6. dataeval/_internal/detectors/duplicates.py +64 -45
  7. dataeval/_internal/detectors/merged_stats.py +23 -54
  8. dataeval/_internal/detectors/ood/ae.py +8 -6
  9. dataeval/_internal/detectors/ood/aegmm.py +6 -4
  10. dataeval/_internal/detectors/ood/base.py +12 -7
  11. dataeval/_internal/detectors/ood/llr.py +6 -4
  12. dataeval/_internal/detectors/ood/vae.py +5 -3
  13. dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  14. dataeval/_internal/detectors/outliers.py +137 -63
  15. dataeval/_internal/interop.py +11 -7
  16. dataeval/_internal/metrics/balance.py +13 -11
  17. dataeval/_internal/metrics/ber.py +5 -3
  18. dataeval/_internal/metrics/coverage.py +4 -0
  19. dataeval/_internal/metrics/divergence.py +9 -5
  20. dataeval/_internal/metrics/diversity.py +14 -12
  21. dataeval/_internal/metrics/parity.py +32 -22
  22. dataeval/_internal/metrics/stats/base.py +231 -0
  23. dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
  24. dataeval/_internal/metrics/stats/datasetstats.py +99 -0
  25. dataeval/_internal/metrics/stats/dimensionstats.py +113 -0
  26. dataeval/_internal/metrics/stats/hashstats.py +75 -0
  27. dataeval/_internal/metrics/stats/labelstats.py +125 -0
  28. dataeval/_internal/metrics/stats/pixelstats.py +119 -0
  29. dataeval/_internal/metrics/stats/visualstats.py +124 -0
  30. dataeval/_internal/metrics/uap.py +8 -4
  31. dataeval/_internal/metrics/utils.py +30 -15
  32. dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  33. dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  34. dataeval/_internal/output.py +3 -18
  35. dataeval/_internal/utils.py +11 -16
  36. dataeval/_internal/workflows/sufficiency.py +152 -151
  37. dataeval/detectors/__init__.py +4 -0
  38. dataeval/detectors/drift/__init__.py +8 -3
  39. dataeval/detectors/drift/kernels/__init__.py +4 -0
  40. dataeval/detectors/drift/updates/__init__.py +4 -0
  41. dataeval/detectors/linters/__init__.py +15 -4
  42. dataeval/detectors/ood/__init__.py +14 -2
  43. dataeval/metrics/__init__.py +5 -0
  44. dataeval/metrics/bias/__init__.py +13 -4
  45. dataeval/metrics/estimators/__init__.py +8 -8
  46. dataeval/metrics/stats/__init__.py +25 -3
  47. dataeval/utils/__init__.py +16 -3
  48. dataeval/utils/tensorflow/__init__.py +11 -0
  49. dataeval/utils/torch/__init__.py +12 -0
  50. dataeval/utils/torch/datasets/__init__.py +7 -0
  51. dataeval/workflows/__init__.py +6 -2
  52. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/METADATA +12 -4
  53. dataeval-0.70.1.dist-info/RECORD +80 -0
  54. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/WHEEL +1 -1
  55. dataeval/_internal/flags.py +0 -77
  56. dataeval/_internal/metrics/stats.py +0 -397
  57. dataeval/flags/__init__.py +0 -3
  58. dataeval/tensorflow/__init__.py +0 -3
  59. dataeval/torch/__init__.py +0 -3
  60. dataeval-0.69.4.dist-info/RECORD +0 -74
  61. /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
  62. /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
  63. /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
  64. /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
  65. /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
  66. {dataeval-0.69.4.dist-info → dataeval-0.70.1.dist-info}/LICENSE.txt +0 -0
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.69.4"
+ __version__ = "0.70.1"

  from importlib.util import find_spec

@@ -7,16 +7,16 @@ _IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("te

  del find_spec

- from . import detectors, flags, metrics  # noqa: E402
+ from . import detectors, metrics  # noqa: E402

- __all__ = ["detectors", "flags", "metrics"]
+ __all__ = ["detectors", "metrics"]

  if _IS_TORCH_AVAILABLE:  # pragma: no cover
-     from . import torch, utils, workflows
+     from . import workflows

-     __all__ += ["torch", "utils", "workflows"]
+     __all__ += ["workflows"]

- if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
-     from . import tensorflow
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:  # pragma: no cover
+     from . import utils

-     __all__ += ["tensorflow"]
+     __all__ += ["utils"]
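Together with the file moves in the listing above (dataeval/{torch,tensorflow} → dataeval/utils/...), the __init__.py change shifts the public import layout as sketched below. This is an illustration inferred from this diff, not an excerpt from the package documentation; availability still depends on which frameworks are installed.

    # Sketch of the 0.70.1 import layout, inferred from this diff -- illustrative only.
    from dataeval import detectors, metrics       # always importable
    from dataeval import workflows                 # requires torch (per __init__.py above)
    from dataeval import utils                     # requires torch and/or tensorflow

    # framework-specific helpers moved from dataeval.torch / dataeval.tensorflow:
    import dataeval.utils.torch.models              # formerly dataeval.torch.models
    import dataeval.utils.tensorflow.models         # formerly dataeval.tensorflow.models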
dataeval/_internal/datasets.py CHANGED
@@ -4,18 +4,39 @@ import hashlib
  import os
  import zipfile
  from pathlib import Path
- from typing import Literal
- from urllib.error import HTTPError, URLError
- from urllib.request import urlretrieve
+ from typing import Literal, TypeVar
+ from warnings import warn

  import numpy as np
+ import requests
  from numpy.typing import NDArray
  from torch.utils.data import Dataset
  from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401

-
- def _validate_file(fpath, file_md5, chunk_size=65535):
-     hasher = hashlib.md5()
+ ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+ TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
+ CorruptionStringMap = Literal[
+     "identity",
+     "shot_noise",
+     "impulse_noise",
+     "glass_blur",
+     "motion_blur",
+     "shear",
+     "scale",
+     "rotate",
+     "brightness",
+     "translate",
+     "stripe",
+     "fog",
+     "spatter",
+     "dotted_line",
+     "zigzag",
+     "canny_edges",
+ ]
+
+
+ def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
+     hasher = hashlib.md5() if md5 else hashlib.sha256()
      with open(fpath, "rb") as fpath_file:
          while chunk := fpath_file.read(chunk_size):
              hasher.update(chunk)
@@ -26,44 +47,74 @@ def _get_file(
      root: str | Path,
      fname: str,
      origin: str,
-     file_md5: str | None = None,
+     file_hash: str | None = None,
+     verbose: bool = True,
+     md5: bool = False,
  ):
-     fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
      fpath = os.path.join(root, fname)
-
-     download = False
-     if os.path.exists(fpath):
-         if file_md5 is not None and not _validate_file(fpath, file_md5):
-             download = True
-         else:
-             print("Files already downloaded and verified")
-     else:
-         download = True
+     download = True
+     if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
+         download = False
+         if verbose:
+             print("File already downloaded and verified.")
+             if md5:
+                 print("Extracting zip file...")

      if download:
          try:
              error_msg = "URL fetch failure on {}: {} -- {}"
              try:
-                 urlretrieve(origin, fpath)
-             except HTTPError as e:
-                 raise Exception(error_msg.format(origin, e.code, e.msg)) from e
-             except URLError as e:
-                 raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
+                 with requests.get(origin, stream=True, timeout=60) as r:
+                     r.raise_for_status()
+                     with open(fpath, "wb") as f:
+                         for chunk in r.iter_content(chunk_size=8192):
+                             if chunk:
+                                 f.write(chunk)
+             except requests.exceptions.HTTPError as e:
+                 raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+             except requests.exceptions.RequestException as e:
+                 raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
          except (Exception, KeyboardInterrupt):
              if os.path.exists(fpath):
                  os.remove(fpath)
              raise

-     if os.path.exists(fpath) and file_md5 is not None and not _validate_file(fpath, file_md5):
+     if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
          raise ValueError(
              "Incomplete or corrupted file detected. "
-             f"The md5 file hash does not match the provided value "
-             f"of {file_md5}.",
+             f"The file hash does not match the provided value "
+             f"of {file_hash}.",
          )
+
      return fpath


- def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
+ def check_exists(
+     folder: str | Path,
+     url: str,
+     root: str | Path,
+     fname: str,
+     file_hash: str,
+     download: bool = True,
+     verbose: bool = True,
+     md5: bool = False,
+ ):
+     """Determine if the dataset has already been downloaded."""
+     location = str(folder)
+     if not os.path.exists(folder):
+         if download:
+             location = download_dataset(url, root, fname, file_hash, verbose, md5)
+         else:
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+     else:
+         if verbose:
+             print("Files already downloaded and verified")
+     return location
+
+
+ def download_dataset(
+     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
+ ) -> str:
      """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
      https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
      """
@@ -71,21 +122,24 @@ def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
      folder = os.path.join(root, name)
      os.makedirs(folder, exist_ok=True)

-     path = _get_file(
-         root,
+     fpath = _get_file(
+         folder,
          fname,
          origin=url + fname,
-         file_md5=md5,
+         file_hash=file_hash,
+         verbose=verbose,
+         md5=md5,
      )
-     extract_archive(path, remove_finished=True)
-     return path
+     if md5:
+         folder = extract_archive(fpath, root, remove_finished=True)
+     return folder


  def extract_archive(
      from_path: str | Path,
      to_path: str | Path | None = None,
      remove_finished: bool = False,
- ):
+ ) -> str:
      """Extract an archive.

      The archive type and a possible compression is automatically detected from the file name.
@@ -94,8 +148,11 @@ def extract_archive(
      if not from_path.is_absolute():
          from_path = from_path.resolve()

-     if to_path is None:
+     if to_path is None or not os.path.exists(to_path):
          to_path = os.path.dirname(from_path)
+     to_path = Path(to_path)
+     if not to_path.is_absolute():
+         to_path = to_path.resolve()

      # Extracting zip
      with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
@@ -103,6 +160,13 @@

      if remove_finished:
          os.remove(from_path)
+     return str(to_path)
+
+
+ def subselect(arr: NDArray, count: int, from_back: bool = False):
+     if from_back:
+         return arr[-count:]
+     return arr[:count]


  class MNIST(Dataset):
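Taken together, the helpers above form the download path used by the dataset classes that follow: check_exists short-circuits when the target folder already exists, otherwise download_dataset fetches the file through _get_file and, for the md5/zip case, unpacks it with extract_archive. A rough usage sketch (paths are hypothetical; the URL and hash are the mnist_c values from the MNIST class further down):

    # Illustrative only -- not part of the released code; paths are made up.
    import numpy as np

    folder = check_exists(
        folder="./data/mnist_c/fog",                     # hypothetical target directory
        url="https://zenodo.org/record/3239543/files/",  # corruption mirror (MNIST.mirror[1] below)
        root="./data",
        fname="mnist_c.zip",
        file_hash="4b34b33045869ee6d424616cd3a65da3",    # md5 listed in MNIST.resources below
        download=True,
        md5=True,                                        # zip resource: md5 hashing plus extraction
    )

    # subselect() slices a fixed number of samples from the front or back of an array.
    subselect(np.arange(10), 3)                  # -> array([0, 1, 2])
    subselect(np.arange(10), 3, from_back=True)  # -> array([7, 8, 9])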
@@ -133,40 +197,42 @@ class MNIST(Dataset):
          'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
          'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
          The desired corruption style or None.
+     classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+         | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
+         "eight", "nine"]] | None, default None
+         Option to select specific classes from dataset.
+     balance : bool, default True
+         If True, returns equal number of samples for each class.
+     randomize : bool, default False
+         If True, shuffles the data prior to selection - uses a set seed for reproducibility.
+     slice_back : bool, default False
+         If True and size has a value greater than 0, then grabs selection starting at the last image.
+     verbose : bool, default True
+         If True, outputs print statements.
      """

-     mirror = "https://zenodo.org/record/3239543/files/"
-
-     resources = ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3")
-
-     classes = [
-         "0 - zero",
-         "1 - one",
-         "2 - two",
-         "3 - three",
-         "4 - four",
-         "5 - five",
-         "6 - six",
-         "7 - seven",
-         "8 - eight",
-         "9 - nine",
+     mirror = [
+         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
+         "https://zenodo.org/record/3239543/files/",
      ]

-     @property
-     def train_labels(self):
-         return self.targets
-
-     @property
-     def test_labels(self):
-         return self.targets
-
-     @property
-     def train_data(self):
-         return self.data
+     resources = [
+         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
+         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
+     ]

-     @property
-     def test_data(self):
-         return self.data
+     class_dict = {
+         "zero": 0,
+         "one": 1,
+         "two": 2,
+         "three": 3,
+         "four": 4,
+         "five": 5,
+         "six": 6,
+         "seven": 7,
+         "eight": 8,
+         "nine": 9,
+     }

      def __init__(
          self,
@@ -179,25 +245,12 @@ class MNIST(Dataset):
          channels: Literal["channels_first", "channels_last"] | None = None,
          flatten: bool = False,
          normalize: tuple[float, float] | None = None,
-         corruption: Literal[
-             "identity",
-             "shot_noise",
-             "impulse_noise",
-             "glass_blur",
-             "motion_blur",
-             "shear",
-             "scale",
-             "rotate",
-             "brightness",
-             "translate",
-             "stripe",
-             "fog",
-             "spatter",
-             "dotted_line",
-             "zigzag",
-             "canny_edges",
-         ]
-         | None = None,
+         corruption: CorruptionStringMap | None = None,
+         classes: TClassMap | None = None,
+         balance: bool = True,
+         randomize: bool = False,
+         slice_back: bool = False,
+         verbose: bool = True,
      ) -> None:
          if isinstance(root, str):
              root = os.path.expanduser(root)
@@ -209,64 +262,113 @@ class MNIST(Dataset):
          self.channels = channels
          self.flatten = flatten
          self.normalize = normalize
-
-         if corruption is None:
-             corruption = "identity"
-         elif corruption == "identity":
-             print("Identity is not a corrupted dataset but the original MNIST dataset")
          self.corruption = corruption
-
-         if os.path.exists(self.mnist_folder):
-             print("Files already downloaded and verified")
-         elif download:
-             download_dataset(self.mirror, self.root, self.resources[0], self.resources[1])
+         self.balance = balance
+         self.randomize = randomize
+         self.from_back = slice_back
+         self.verbose = verbose
+
+         self.class_set = []
+         if classes is not None:
+             if not isinstance(classes, list):
+                 classes = [classes]  # type: ignore
+
+             for val in classes:  # type: ignore
+                 if isinstance(val, int) and 0 <= val < 10:
+                     self.class_set.append(val)
+                 elif isinstance(val, str):
+                     self.class_set.append(self.class_dict[val])
+             self.class_set = set(self.class_set)
+
+         if not self.class_set:
+             self.class_set = set(self.class_dict.values())
+
+         self.num_classes = len(self.class_set)
+
+         if self.corruption is None:
+             file_resource = self.resources[0]
+             mirror = self.mirror[0]
+             md5 = False
          else:
-             raise RuntimeError("Dataset not found. You can use download=True to download it")
+             if self.corruption == "identity" and verbose:
+                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
+             file_resource = self.resources[1]
+             mirror = self.mirror[1]
+             md5 = True
+         check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)

          self.data, self.targets = self._load_data()

+         self._augmentations()
+
      def _load_data(self):
-         image_file = f"{'train' if self.train else 'test'}_images.npy"
-         data = self._read_image_file(os.path.join(self.mnist_folder, image_file))
-
-         label_file = f"{'train' if self.train else 'test'}_labels.npy"
-         targets = self._read_label_file(os.path.join(self.mnist_folder, label_file))
-
-         if self.size >= 1 and self.size >= len(self.classes):
-             final_data = []
-             final_targets = []
-             for label in range(len(self.classes)):
-                 indices = np.where(targets == label)[0]
-                 selected_indices = indices[: int(self.size / len(self.classes))]
-                 final_data.append(data[selected_indices])
-                 final_targets.append(targets[selected_indices])
-             data = np.concatenate(final_data)
-             targets = np.concatenate(final_targets)
-             shuffled_indices = np.random.permutation(data.shape[0])
-             data = data[shuffled_indices]
-             targets = targets[shuffled_indices]
-         elif self.size >= 1:
-             data = data[: self.size]
-             targets = targets[: self.size]
+         if self.corruption is None:
+             image_file = self.resources[0][0]
+             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
+         else:
+             image_file = f"{'train' if self.train else 'test'}_images.npy"
+             data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
+             data = data.squeeze()
+
+             label_file = f"{'train' if self.train else 'test'}_labels.npy"
+             targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
+
+         return data, targets
+
+     def _augmentations(self):
+         if self.size > self.targets.shape[0] and self.verbose:
+             warn(
+                 f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
+                 "Adjusting down to raw dataset size."
+             )
+             self.size = -1
+
+         if self.randomize:
+             rdm_seed = np.random.default_rng(2023)
+             shuffled_indices = rdm_seed.permutation(self.data.shape[0])
+             self.data = self.data[shuffled_indices]
+             self.targets = self.targets[shuffled_indices]
+
+         if not self.balance and self.num_classes > self.size:
+             if self.size > 0:
+                 self.data = subselect(self.data, self.size, self.from_back)
+                 self.targets = subselect(self.targets, self.size, self.from_back)
+         else:
+             label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+             min_label_count = min(len(indices) for indices in label_dict.values())
+
+             self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+
+             if self.per_class_count > min_label_count:
+                 self.per_class_count = min_label_count
+                 if not self.balance and self.verbose:
+                     warn(
+                         f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                         f"will be returned, instead of the desired {self.size}."
+                     )
+
+             all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
+             for i, label in enumerate(self.class_set):
+                 all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
+             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore

          if self.unit_interval:
-             data = data / 255
+             self.data = self.data / 255

          if self.normalize:
-             data = (data - self.normalize[0]) / self.normalize[1]
+             self.data = (self.data - self.normalize[0]) / self.normalize[1]

          if self.dtype:
-             data = data.astype(self.dtype)
+             self.data = self.data.astype(self.dtype)

          if self.channels == "channels_first":
-             data = np.moveaxis(data, -1, 1)
-         elif self.channels is None:
-             data = data[:, :, :, 0]
+             self.data = self.data[:, np.newaxis, :, :]
+         elif self.channels == "channels_last":
+             self.data = self.data[:, :, :, np.newaxis]

          if self.flatten and self.channels is None:
-             data = data.reshape(data.shape[0], -1)
-
-         return data, targets
+             self.data = self.data.reshape(self.data.shape[0], -1)

      def __getitem__(self, index: int) -> tuple[NDArray, int]:
          """
@@ -285,16 +387,18 @@ class MNIST(Dataset):

      @property
      def mnist_folder(self) -> str:
+         if self.corruption is None:
+             return os.path.join(self.root, "mnist")
          return os.path.join(self.root, "mnist_c", self.corruption)

-     @property
-     def class_to_idx(self) -> dict[str, int]:
-         return {_class: i for i, _class in enumerate(self.classes)}
-
-     def _read_label_file(self, path: str) -> NDArray:
-         x = np.load(path, allow_pickle=False)
-         return x
+     def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
+         with np.load(path, allow_pickle=True) as f:
+             if self.train:
+                 x, y = f["x_train"], f["y_train"]
+             else:
+                 x, y = f["x_test"], f["y_test"]
+             return x, y

-     def _read_image_file(self, path: str) -> NDArray:
+     def _read_corrupt_file(self, path: str) -> NDArray:
          x = np.load(path, allow_pickle=False)
          return x
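To make the reworked MNIST surface concrete, a hypothetical construction follows. Parameters not shown in this diff (root/train/download/size and friends) are assumed unchanged from 0.69.4, and the closing comment walks through the balanced-selection arithmetic from _augmentations above; the new dataeval/utils/torch/datasets package in the file listing suggests a public re-export, but that file is not shown here.

    # Illustrative only -- values are hypothetical; API per the hunks above.
    from dataeval._internal.datasets import MNIST

    ds = MNIST(
        root="./data",
        train=True,
        download=True,
        size=100,                   # total number of samples requested
        classes=["zero", "one"],    # restrict to two classes (strings or ints accepted)
        balance=True,               # equal samples per class
        randomize=True,             # seeded shuffle (np.random.default_rng(2023)) before selection
        corruption=None,            # None -> clean mnist.npz from mirror[0]
        channels="channels_first",  # insert a channel axis: (N, 1, H, W)
    )

    # Balanced selection: per_class_count = ceil(100 / 2) = 50 per class,
    # capped at the rarest selected class's count if that is smaller.
    image, label = ds[0]            # __getitem__ returns (NDArray, int)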
dataeval/_internal/detectors/clusterer.py CHANGED
@@ -16,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
  @dataclass(frozen=True)
  class ClustererOutput(OutputMetadata):
      """
+     Output class for :class:`Clusterer` lint detector
+
      Attributes
      ----------
      outliers : List[int]
dataeval/_internal/detectors/drift/base.py CHANGED
@@ -16,14 +16,14 @@ from typing import Callable, Literal
  import numpy as np
  from numpy.typing import ArrayLike, NDArray

- from dataeval._internal.interop import to_numpy
+ from dataeval._internal.interop import as_numpy, to_numpy
  from dataeval._internal.output import OutputMetadata, set_metadata


  @dataclass(frozen=True)
  class DriftBaseOutput(OutputMetadata):
      """
-     Output class for Drift
+     Base output class for Drift detector classes

      Attributes
      ----------
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
  @dataclass(frozen=True)
  class DriftOutput(DriftBaseOutput):
      """
-     Output class for DriftCVM and DriftKS
+     Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors

      Attributes
      ----------
@@ -234,7 +234,7 @@ class BaseDrift:
          if correction not in ["bonferroni", "fdr"]:
              raise ValueError("`correction` must be `bonferroni` or `fdr`.")

-         self._x_ref = x_ref
+         self._x_ref = to_numpy(x_ref)
          self.x_ref_preprocessed = x_ref_preprocessed

          # Other attributes
@@ -242,7 +242,7 @@
          self.update_x_ref = update_x_ref
          self.preprocess_fn = preprocess_fn
          self.correction = correction
-         self.n = len(self._x_ref)  # type: ignore
+         self.n = len(self._x_ref)

          # Ref counter for preprocessed x
          self._x_refcount = 0
@@ -260,9 +260,8 @@
          if not self.x_ref_preprocessed:
              self.x_ref_preprocessed = True
              if self.preprocess_fn is not None:
-                 self._x_ref = self.preprocess_fn(self._x_ref)
+                 self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))

-             self._x_ref = to_numpy(self._x_ref)
          return self._x_ref

      def _preprocess(self, x: ArrayLike) -> ArrayLike:
@@ -380,7 +379,7 @@ class BaseDriftUnivariate(BaseDrift):
              self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
          else:
              # infer number of features after applying preprocessing step
-             x = to_numpy(self.preprocess_fn(self._x_ref[0:1]))  # type: ignore
+             x = as_numpy(self.preprocess_fn(self._x_ref[0:1]))  # type: ignore
              self._n_features = x.reshape(x.shape[0], -1).shape[-1]

          return self._n_features
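Aside: the drift/mmd.py hunks below swap to_numpy for as_numpy in DriftMMD's kernel setup and scoring. For orientation, the statistic DriftMMD is built around can be sketched independently in plain NumPy; this is an illustration of the unbiased MMD^2 estimate with an RBF kernel, not the package's torch implementation (GaussianRBF, mmd2_from_kernel_matrix).

    # Standalone sketch of squared maximum mean discrepancy (MMD^2) -- illustrative only.
    import numpy as np

    def rbf_kernel(a, b, sigma=1.0):
        # pairwise squared distances between rows, then Gaussian RBF
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * sigma**2))

    def mmd2_unbiased(x_ref, x, sigma=1.0):
        k_xx = rbf_kernel(x_ref, x_ref, sigma)
        k_yy = rbf_kernel(x, x, sigma)
        k_xy = rbf_kernel(x_ref, x, sigma)
        n, m = len(x_ref), len(x)
        return (
            (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
            + (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
            - 2 * k_xy.mean()
        )

    # Drift is flagged when MMD^2 between reference and test features exceeds a
    # threshold estimated from a permutation test over the pooled samples.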
dataeval/_internal/detectors/drift/mmd.py CHANGED
@@ -14,7 +14,7 @@ from typing import Callable
  import torch
  from numpy.typing import ArrayLike

- from dataeval._internal.interop import to_numpy
+ from dataeval._internal.interop import as_numpy
  from dataeval._internal.output import set_metadata

  from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
@@ -24,7 +24,7 @@ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
  @dataclass(frozen=True)
  class DriftMMDOutput(DriftBaseOutput):
      """
-     Output class for DriftMMD
+     Output class for :class:`DriftMMD` drift detector

      Attributes
      ----------
@@ -110,7 +110,7 @@ class DriftMMD(BaseDrift):
          self.device = get_device(device)

          # initialize kernel
-         sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
+         sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
          self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel

          # compute kernel matrix for the reference data
@@ -147,7 +147,7 @@ class DriftMMD(BaseDrift):
              p-value obtained from the permutation test, MMD^2 between the reference and test set,
              and MMD^2 threshold above which drift is flagged
          """
-         x = to_numpy(x)
+         x = as_numpy(x)
          x_ref = torch.from_numpy(self.x_ref).to(self.device)
          n = x.shape[0]
          kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))