dataeval 0.70.0__py3-none-any.whl → 0.71.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (59)
  1. dataeval/__init__.py +6 -6
  2. dataeval/_internal/datasets.py +235 -131
  3. dataeval/_internal/detectors/clusterer.py +2 -0
  4. dataeval/_internal/detectors/drift/base.py +2 -2
  5. dataeval/_internal/detectors/drift/mmd.py +1 -1
  6. dataeval/_internal/detectors/duplicates.py +2 -0
  7. dataeval/_internal/detectors/ood/ae.py +5 -3
  8. dataeval/_internal/detectors/ood/aegmm.py +6 -4
  9. dataeval/_internal/detectors/ood/base.py +12 -7
  10. dataeval/_internal/detectors/ood/llr.py +6 -4
  11. dataeval/_internal/detectors/ood/vae.py +5 -3
  12. dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  13. dataeval/_internal/detectors/outliers.py +6 -9
  14. dataeval/_internal/metrics/balance.py +4 -2
  15. dataeval/_internal/metrics/ber.py +2 -0
  16. dataeval/_internal/metrics/coverage.py +4 -0
  17. dataeval/_internal/metrics/divergence.py +6 -2
  18. dataeval/_internal/metrics/diversity.py +8 -6
  19. dataeval/_internal/metrics/parity.py +8 -6
  20. dataeval/_internal/metrics/stats/base.py +105 -46
  21. dataeval/_internal/metrics/stats/datasetstats.py +96 -22
  22. dataeval/_internal/metrics/stats/dimensionstats.py +22 -20
  23. dataeval/_internal/metrics/stats/hashstats.py +11 -9
  24. dataeval/_internal/metrics/stats/labelstats.py +1 -1
  25. dataeval/_internal/metrics/stats/pixelstats.py +28 -26
  26. dataeval/_internal/metrics/stats/visualstats.py +37 -35
  27. dataeval/_internal/metrics/uap.py +6 -2
  28. dataeval/_internal/metrics/utils.py +2 -2
  29. dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  31. dataeval/_internal/utils.py +11 -16
  32. dataeval/_internal/workflows/sufficiency.py +44 -33
  33. dataeval/detectors/__init__.py +4 -0
  34. dataeval/detectors/drift/__init__.py +8 -3
  35. dataeval/detectors/drift/kernels/__init__.py +4 -0
  36. dataeval/detectors/drift/updates/__init__.py +4 -0
  37. dataeval/detectors/linters/__init__.py +15 -4
  38. dataeval/detectors/ood/__init__.py +14 -2
  39. dataeval/metrics/__init__.py +5 -0
  40. dataeval/metrics/bias/__init__.py +13 -4
  41. dataeval/metrics/estimators/__init__.py +8 -8
  42. dataeval/metrics/stats/__init__.py +24 -6
  43. dataeval/utils/__init__.py +16 -3
  44. dataeval/utils/tensorflow/__init__.py +11 -0
  45. dataeval/utils/torch/__init__.py +12 -0
  46. dataeval/utils/torch/datasets/__init__.py +7 -0
  47. dataeval/workflows/__init__.py +4 -0
  48. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/METADATA +11 -2
  49. dataeval-0.71.0.dist-info/RECORD +80 -0
  50. dataeval/tensorflow/__init__.py +0 -3
  51. dataeval/torch/__init__.py +0 -3
  52. dataeval-0.70.0.dist-info/RECORD +0 -79
  53. /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
  54. /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
  55. /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
  56. /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
  57. /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
  58. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.70.0"
+ __version__ = "0.71.0"
  
  from importlib.util import find_spec
  
@@ -12,11 +12,11 @@ from . import detectors, metrics # noqa: E402
  __all__ = ["detectors", "metrics"]
  
  if _IS_TORCH_AVAILABLE:  # pragma: no cover
-     from . import torch, utils, workflows
+     from . import workflows
  
-     __all__ += ["torch", "utils", "workflows"]
+     __all__ += ["workflows"]
  
- if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
-     from . import tensorflow
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:  # pragma: no cover
+     from . import utils
  
-     __all__ += ["tensorflow"]
+     __all__ += ["utils"]
dataeval/_internal/datasets.py CHANGED
@@ -4,18 +4,39 @@ import hashlib
  import os
  import zipfile
  from pathlib import Path
- from typing import Literal
- from urllib.error import HTTPError, URLError
- from urllib.request import urlretrieve
+ from typing import Literal, TypeVar
+ from warnings import warn
  
  import numpy as np
+ import requests
  from numpy.typing import NDArray
  from torch.utils.data import Dataset
  from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
  
- 
- def _validate_file(fpath, file_md5, chunk_size=65535):
-     hasher = hashlib.md5()
+ ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+ TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
+ CorruptionStringMap = Literal[
+     "identity",
+     "shot_noise",
+     "impulse_noise",
+     "glass_blur",
+     "motion_blur",
+     "shear",
+     "scale",
+     "rotate",
+     "brightness",
+     "translate",
+     "stripe",
+     "fog",
+     "spatter",
+     "dotted_line",
+     "zigzag",
+     "canny_edges",
+ ]
+ 
+ 
+ def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
+     hasher = hashlib.md5() if md5 else hashlib.sha256()
      with open(fpath, "rb") as fpath_file:
          while chunk := fpath_file.read(chunk_size):
              hasher.update(chunk)
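
The reworked `_validate_file` now defaults to SHA-256 and only falls back to MD5 when `md5=True` (used for the corruption zip later in this file). A standalone, hedged mirror of that hashing logic; the file paths are hypothetical, and the expected digests are copied from `MNIST.resources` further down in this diff:

```python
import hashlib


def file_digest(path: str, md5: bool = False, chunk_size: int = 65535) -> str:
    """Mirror of the hashing loop above: SHA-256 by default, MD5 on request."""
    hasher = hashlib.md5() if md5 else hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()


# mnist.npz is checked against a SHA-256 digest, mnist_c.zip against an MD5 digest.
assert file_digest("data/mnist/mnist.npz") == (
    "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
)
assert file_digest("data/mnist_c/mnist_c.zip", md5=True) == "4b34b33045869ee6d424616cd3a65da3"
```
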
@@ -26,44 +47,74 @@ def _get_file(
      root: str | Path,
      fname: str,
      origin: str,
-     file_md5: str | None = None,
+     file_hash: str | None = None,
+     verbose: bool = True,
+     md5: bool = False,
  ):
-     fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
      fpath = os.path.join(root, fname)
- 
-     download = False
-     if os.path.exists(fpath):
-         if file_md5 is not None and not _validate_file(fpath, file_md5):
-             download = True
-         else:
-             print("Files already downloaded and verified")
-     else:
-         download = True
+     download = True
+     if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
+         download = False
+         if verbose:
+             print("File already downloaded and verified.")
+             if md5:
+                 print("Extracting zip file...")
  
      if download:
          try:
              error_msg = "URL fetch failure on {}: {} -- {}"
              try:
-                 urlretrieve(origin, fpath)
-             except HTTPError as e:
-                 raise Exception(error_msg.format(origin, e.code, e.msg)) from e
-             except URLError as e:
-                 raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
+                 with requests.get(origin, stream=True, timeout=60) as r:
+                     r.raise_for_status()
+                     with open(fpath, "wb") as f:
+                         for chunk in r.iter_content(chunk_size=8192):
+                             if chunk:
+                                 f.write(chunk)
+             except requests.exceptions.HTTPError as e:
+                 raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+             except requests.exceptions.RequestException as e:
+                 raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
          except (Exception, KeyboardInterrupt):
              if os.path.exists(fpath):
                  os.remove(fpath)
              raise
  
-     if os.path.exists(fpath) and file_md5 is not None and not _validate_file(fpath, file_md5):
+     if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
          raise ValueError(
              "Incomplete or corrupted file detected. "
-             f"The md5 file hash does not match the provided value "
-             f"of {file_md5}.",
+             f"The file hash does not match the provided value "
+             f"of {file_hash}.",
          )
+ 
      return fpath
  
  
- def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
+ def check_exists(
+     folder: str | Path,
+     url: str,
+     root: str | Path,
+     fname: str,
+     file_hash: str,
+     download: bool = True,
+     verbose: bool = True,
+     md5: bool = False,
+ ):
+     """Determine if the dataset has already been downloaded."""
+     location = str(folder)
+     if not os.path.exists(folder):
+         if download:
+             location = download_dataset(url, root, fname, file_hash, verbose, md5)
+         else:
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+     else:
+         if verbose:
+             print("Files already downloaded and verified")
+     return location
+ 
+ 
+ def download_dataset(
+     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
+ ) -> str:
      """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
      https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
      """
@@ -71,21 +122,24 @@ def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
      folder = os.path.join(root, name)
      os.makedirs(folder, exist_ok=True)
  
-     path = _get_file(
-         root,
+     fpath = _get_file(
+         folder,
          fname,
          origin=url + fname,
-         file_md5=md5,
+         file_hash=file_hash,
+         verbose=verbose,
+         md5=md5,
      )
-     extract_archive(path, remove_finished=True)
-     return path
+     if md5:
+         folder = extract_archive(fpath, root, remove_finished=True)
+     return folder
  
  
  def extract_archive(
      from_path: str | Path,
      to_path: str | Path | None = None,
      remove_finished: bool = False,
- ):
+ ) -> str:
      """Extract an archive.
  
      The archive type and a possible compression is automatically detected from the file name.
@@ -94,8 +148,11 @@ def extract_archive(
      if not from_path.is_absolute():
          from_path = from_path.resolve()
  
-     if to_path is None:
+     if to_path is None or not os.path.exists(to_path):
          to_path = os.path.dirname(from_path)
+     to_path = Path(to_path)
+     if not to_path.is_absolute():
+         to_path = to_path.resolve()
  
      # Extracting zip
      with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
@@ -103,6 +160,13 @@
  
      if remove_finished:
          os.remove(from_path)
+     return str(to_path)
+ 
+ 
+ def subselect(arr: NDArray, count: int, from_back: bool = False):
+     if from_back:
+         return arr[-count:]
+     return arr[:count]
  
  
  class MNIST(Dataset):
@@ -133,40 +197,42 @@ class MNIST(Dataset):
          'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
          'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
          The desired corruption style or None.
+     classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+         | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
+         "eight", "nine"]] | None, default None
+         Option to select specific classes from dataset.
+     balance : bool, default True
+         If True, returns equal number of samples for each class.
+     randomize : bool, default False
+         If True, shuffles the data prior to selection - uses a set seed for reproducibility.
+     slice_back : bool, default False
+         If True and size has a value greater than 0, then grabs selection starting at the last image.
+     verbose : bool, default True
+         If True, outputs print statements.
      """
  
-     mirror = "https://zenodo.org/record/3239543/files/"
- 
-     resources = ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3")
- 
-     classes = [
-         "0 - zero",
-         "1 - one",
-         "2 - two",
-         "3 - three",
-         "4 - four",
-         "5 - five",
-         "6 - six",
-         "7 - seven",
-         "8 - eight",
-         "9 - nine",
+     mirror = [
+         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
+         "https://zenodo.org/record/3239543/files/",
      ]
  
-     @property
-     def train_labels(self):
-         return self.targets
- 
-     @property
-     def test_labels(self):
-         return self.targets
- 
-     @property
-     def train_data(self):
-         return self.data
+     resources = [
+         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
+         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
+     ]
  
-     @property
-     def test_data(self):
-         return self.data
+     class_dict = {
+         "zero": 0,
+         "one": 1,
+         "two": 2,
+         "three": 3,
+         "four": 4,
+         "five": 5,
+         "six": 6,
+         "seven": 7,
+         "eight": 8,
+         "nine": 9,
+     }
  
      def __init__(
          self,
@@ -179,25 +245,12 @@ class MNIST(Dataset):
          channels: Literal["channels_first", "channels_last"] | None = None,
          flatten: bool = False,
          normalize: tuple[float, float] | None = None,
-         corruption: Literal[
-             "identity",
-             "shot_noise",
-             "impulse_noise",
-             "glass_blur",
-             "motion_blur",
-             "shear",
-             "scale",
-             "rotate",
-             "brightness",
-             "translate",
-             "stripe",
-             "fog",
-             "spatter",
-             "dotted_line",
-             "zigzag",
-             "canny_edges",
-         ]
-         | None = None,
+         corruption: CorruptionStringMap | None = None,
+         classes: TClassMap | None = None,
+         balance: bool = True,
+         randomize: bool = False,
+         slice_back: bool = False,
+         verbose: bool = True,
      ) -> None:
          if isinstance(root, str):
              root = os.path.expanduser(root)
@@ -209,64 +262,113 @@ class MNIST(Dataset):
          self.channels = channels
          self.flatten = flatten
          self.normalize = normalize
- 
-         if corruption is None:
-             corruption = "identity"
-         elif corruption == "identity":
-             print("Identity is not a corrupted dataset but the original MNIST dataset")
          self.corruption = corruption
- 
-         if os.path.exists(self.mnist_folder):
-             print("Files already downloaded and verified")
-         elif download:
-             download_dataset(self.mirror, self.root, self.resources[0], self.resources[1])
+         self.balance = balance
+         self.randomize = randomize
+         self.from_back = slice_back
+         self.verbose = verbose
+ 
+         self.class_set = []
+         if classes is not None:
+             if not isinstance(classes, list):
+                 classes = [classes]  # type: ignore
+ 
+             for val in classes:  # type: ignore
+                 if isinstance(val, int) and 0 <= val < 10:
+                     self.class_set.append(val)
+                 elif isinstance(val, str):
+                     self.class_set.append(self.class_dict[val])
+             self.class_set = set(self.class_set)
+ 
+         if not self.class_set:
+             self.class_set = set(self.class_dict.values())
+ 
+         self.num_classes = len(self.class_set)
+ 
+         if self.corruption is None:
+             file_resource = self.resources[0]
+             mirror = self.mirror[0]
+             md5 = False
          else:
-             raise RuntimeError("Dataset not found. You can use download=True to download it")
+             if self.corruption == "identity" and verbose:
+                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
+             file_resource = self.resources[1]
+             mirror = self.mirror[1]
+             md5 = True
+         check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
  
          self.data, self.targets = self._load_data()
  
+         self._augmentations()
+ 
      def _load_data(self):
-         image_file = f"{'train' if self.train else 'test'}_images.npy"
-         data = self._read_image_file(os.path.join(self.mnist_folder, image_file))
- 
-         label_file = f"{'train' if self.train else 'test'}_labels.npy"
-         targets = self._read_label_file(os.path.join(self.mnist_folder, label_file))
- 
-         if self.size >= 1 and self.size >= len(self.classes):
-             final_data = []
-             final_targets = []
-             for label in range(len(self.classes)):
-                 indices = np.where(targets == label)[0]
-                 selected_indices = indices[: int(self.size / len(self.classes))]
-                 final_data.append(data[selected_indices])
-                 final_targets.append(targets[selected_indices])
-             data = np.concatenate(final_data)
-             targets = np.concatenate(final_targets)
-             shuffled_indices = np.random.permutation(data.shape[0])
-             data = data[shuffled_indices]
-             targets = targets[shuffled_indices]
-         elif self.size >= 1:
-             data = data[: self.size]
-             targets = targets[: self.size]
+         if self.corruption is None:
+             image_file = self.resources[0][0]
+             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
+         else:
+             image_file = f"{'train' if self.train else 'test'}_images.npy"
+             data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
+             data = data.squeeze()
+ 
+             label_file = f"{'train' if self.train else 'test'}_labels.npy"
+             targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
+ 
+         return data, targets
+ 
+     def _augmentations(self):
+         if self.size > self.targets.shape[0] and self.verbose:
+             warn(
+                 f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
+                 "Adjusting down to raw dataset size."
+             )
+             self.size = -1
+ 
+         if self.randomize:
+             rdm_seed = np.random.default_rng(2023)
+             shuffled_indices = rdm_seed.permutation(self.data.shape[0])
+             self.data = self.data[shuffled_indices]
+             self.targets = self.targets[shuffled_indices]
+ 
+         if not self.balance and self.num_classes > self.size:
+             if self.size > 0:
+                 self.data = subselect(self.data, self.size, self.from_back)
+                 self.targets = subselect(self.targets, self.size, self.from_back)
+         else:
+             label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+             min_label_count = min(len(indices) for indices in label_dict.values())
+ 
+             self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+ 
+             if self.per_class_count > min_label_count:
+                 self.per_class_count = min_label_count
+                 if not self.balance and self.verbose:
+                     warn(
+                         f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                         f"will be returned, instead of the desired {self.size}."
+                     )
+ 
+             all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
+             for i, label in enumerate(self.class_set):
+                 all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
+             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
  
          if self.unit_interval:
-             data = data / 255
+             self.data = self.data / 255
  
          if self.normalize:
-             data = (data - self.normalize[0]) / self.normalize[1]
+             self.data = (self.data - self.normalize[0]) / self.normalize[1]
  
          if self.dtype:
-             data = data.astype(self.dtype)
+             self.data = self.data.astype(self.dtype)
  
          if self.channels == "channels_first":
-             data = np.moveaxis(data, -1, 1)
-         elif self.channels is None:
-             data = data[:, :, :, 0]
+             self.data = self.data[:, np.newaxis, :, :]
+         elif self.channels == "channels_last":
+             self.data = self.data[:, :, :, np.newaxis]
  
          if self.flatten and self.channels is None:
-             data = data.reshape(data.shape[0], -1)
- 
-         return data, targets
+             self.data = self.data.reshape(self.data.shape[0], -1)
  
      def __getitem__(self, index: int) -> tuple[NDArray, int]:
          """
@@ -285,16 +387,18 @@ class MNIST(Dataset):
  
      @property
      def mnist_folder(self) -> str:
+         if self.corruption is None:
+             return os.path.join(self.root, "mnist")
          return os.path.join(self.root, "mnist_c", self.corruption)
  
-     @property
-     def class_to_idx(self) -> dict[str, int]:
-         return {_class: i for i, _class in enumerate(self.classes)}
- 
-     def _read_label_file(self, path: str) -> NDArray:
-         x = np.load(path, allow_pickle=False)
-         return x
+     def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
+         with np.load(path, allow_pickle=True) as f:
+             if self.train:
+                 x, y = f["x_train"], f["y_train"]
+             else:
+                 x, y = f["x_test"], f["y_test"]
+             return x, y
  
-     def _read_image_file(self, path: str) -> NDArray:
+     def _read_corrupt_file(self, path: str) -> NDArray:
          x = np.load(path, allow_pickle=False)
          return x
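
Taken together, the datasets.py changes add uncorrupted MNIST support (mnist.npz from the first mirror, SHA-256 verified) alongside the MD5-verified corruption zip, plus class filtering, balancing, and seeded shuffling. A hedged usage sketch built only from parameters visible in this diff; argument names not shown in these hunks (`root`, `train`, `size`) are assumed from the surrounding attribute references, and the new public re-export under `dataeval.utils.torch.datasets` listed above may be the preferred import path:

```python
from dataeval._internal.datasets import MNIST

# Uncorrupted MNIST restricted to two classes, balanced and seeded-shuffled.
digits = MNIST(
    root="./data",            # assumed parameter, referenced as self.root in the diff
    train=True,               # assumed parameter, referenced as self.train in the diff
    download=True,
    size=1000,                # assumed parameter, referenced as self.size in the diff
    channels="channels_first",
    classes=["one", "seven"], # new in 0.71.0: ints or spelled-out class names
    balance=True,             # new: equal samples per selected class
    randomize=True,           # new: shuffle with a fixed seed before selection
    verbose=False,
)

# Corrupted variant still comes from the zenodo mirror (MD5-verified zip).
fogged = MNIST(root="./data", train=False, download=True, corruption="fog")

image, label = digits[0]  # __getitem__ returns tuple[NDArray, int]
```
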
dataeval/_internal/detectors/clusterer.py CHANGED
@@ -16,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
  @dataclass(frozen=True)
  class ClustererOutput(OutputMetadata):
      """
+     Output class for :class:`Clusterer` lint detector
+ 
      Attributes
      ----------
      outliers : List[int]
dataeval/_internal/detectors/drift/base.py CHANGED
@@ -23,7 +23,7 @@ from dataeval._internal.output import OutputMetadata, set_metadata
  @dataclass(frozen=True)
  class DriftBaseOutput(OutputMetadata):
      """
-     Output class for Drift
+     Base output class for Drift detector classes
  
      Attributes
      ----------
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
  @dataclass(frozen=True)
  class DriftOutput(DriftBaseOutput):
      """
-     Output class for DriftCVM and DriftKS
+     Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors
  
      Attributes
      ----------
dataeval/_internal/detectors/drift/mmd.py CHANGED
@@ -24,7 +24,7 @@ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
  @dataclass(frozen=True)
  class DriftMMDOutput(DriftBaseOutput):
      """
-     Output class for DriftMMD
+     Output class for :class:`DriftMMD` drift detector
  
      Attributes
      ----------
dataeval/_internal/detectors/duplicates.py CHANGED
@@ -17,6 +17,8 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
  @dataclass(frozen=True)
  class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
      """
+     Output class for :class:`Duplicates` lint detector
+ 
      Attributes
      ----------
      exact : list[list[int] | dict[int, list[int]]]
dataeval/_internal/detectors/ood/ae.py CHANGED
@@ -15,10 +15,11 @@ import numpy as np
  import tensorflow as tf
  from numpy.typing import ArrayLike
  
- from dataeval._internal.detectors.ood.base import OODBase, OODScore
+ from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
  from dataeval._internal.interop import as_numpy
  from dataeval._internal.models.tensorflow.autoencoder import AE
  from dataeval._internal.models.tensorflow.utils import predict_batch
+ from dataeval._internal.output import set_metadata
  
  
  class OOD_AE(OODBase):
@@ -48,7 +49,8 @@ class OOD_AE(OODBase):
          loss_fn = keras.losses.MeanSquaredError()
          super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
  
-     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+     @set_metadata("dataeval.detectors")
+     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
          self._validate(X := as_numpy(X))
  
          # reconstruct instances
@@ -62,4 +64,4 @@ class OOD_AE(OODBase):
          sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
          iscore = np.mean(sorted_fscore_perc, axis=1)
  
-         return OODScore(iscore, fscore)
+         return OODScoreOutput(iscore, fscore)
dataeval/_internal/detectors/ood/aegmm.py CHANGED
@@ -14,12 +14,13 @@ import keras
  import tensorflow as tf
  from numpy.typing import ArrayLike
  
- from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
+ from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
  from dataeval._internal.interop import to_numpy
  from dataeval._internal.models.tensorflow.autoencoder import AEGMM
  from dataeval._internal.models.tensorflow.gmm import gmm_energy
  from dataeval._internal.models.tensorflow.losses import LossGMM
  from dataeval._internal.models.tensorflow.utils import predict_batch
+ from dataeval._internal.output import set_metadata
  
  
  class OOD_AEGMM(OODGMMBase):
@@ -49,7 +50,8 @@ class OOD_AEGMM(OODGMMBase):
          loss_fn = LossGMM()
          super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
  
-     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+     @set_metadata("dataeval.detectors")
+     def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
          """
          Compute the out-of-distribution (OOD) score for a given dataset.
  
@@ -63,7 +65,7 @@ class OOD_AEGMM(OODGMMBase):
  
          Returns
          -------
-         OODScore
+         OODScoreOutput
              An object containing the instance-level OOD score.
  
          Note
@@ -73,4 +75,4 @@ class OOD_AEGMM(OODGMMBase):
          self._validate(X := to_numpy(X))
          _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
          energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
-         return OODScore(energy.numpy())  # type: ignore
+         return OODScoreOutput(energy.numpy())  # type: ignore
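
Across the OOD detectors, `OODScore` is renamed to `OODScoreOutput` and `score` is now wrapped with `@set_metadata("dataeval.detectors")`, matching the other output dataclasses in this release. A hedged sketch of the updated call pattern for `OOD_AE`; the `AE` constructor arguments are an assumption not shown in this diff, and the `fit` keyword names are taken from the `super().fit(...)` call in the hunk above:

```python
import numpy as np

from dataeval._internal.detectors.ood.ae import OOD_AE
from dataeval._internal.models.tensorflow.autoencoder import AE

# Assumption: AE is constructed for 28x28x1 inputs; its real signature is not
# shown in this diff, so treat this line as a placeholder.
detector = OOD_AE(AE(input_shape=(28, 28, 1)))

x_ref = np.random.rand(256, 28, 28, 1).astype(np.float32)
x_test = np.random.rand(16, 28, 28, 1).astype(np.float32)

# Keyword names follow the super().fit(...) pass-through visible above.
detector.fit(x_ref, threshold_perc=95.0, epochs=5, batch_size=64, verbose=False)

score = detector.score(x_test)  # 0.71.0: returns OODScoreOutput instead of OODScore
print(type(score).__name__)     # -> "OODScoreOutput"
```
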