dataeval 0.70.0__tar.gz → 0.70.1__tar.gz

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. {dataeval-0.70.0 → dataeval-0.70.1}/PKG-INFO +10 -2
  2. {dataeval-0.70.0 → dataeval-0.70.1}/README.md +9 -0
  3. {dataeval-0.70.0 → dataeval-0.70.1}/pyproject.toml +16 -5
  4. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/__init__.py +6 -6
  5. dataeval-0.70.1/src/dataeval/_internal/datasets.py +404 -0
  6. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/clusterer.py +2 -0
  7. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/base.py +2 -2
  8. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/mmd.py +1 -1
  9. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/duplicates.py +2 -0
  10. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/ae.py +5 -3
  11. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/aegmm.py +6 -4
  12. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/base.py +12 -7
  13. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/llr.py +6 -4
  14. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vae.py +5 -3
  15. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  16. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/outliers.py +4 -2
  17. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/balance.py +4 -2
  18. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/ber.py +2 -0
  19. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/coverage.py +4 -0
  20. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/divergence.py +6 -2
  21. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/diversity.py +8 -6
  22. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/parity.py +8 -6
  23. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/base.py +2 -2
  24. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/datasetstats.py +2 -0
  25. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/dimensionstats.py +2 -0
  26. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/hashstats.py +2 -0
  27. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/labelstats.py +1 -1
  28. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/pixelstats.py +4 -2
  29. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/visualstats.py +4 -2
  30. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/uap.py +6 -2
  31. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/utils.py +2 -2
  32. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  33. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  34. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/utils.py +11 -16
  35. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/workflows/sufficiency.py +44 -33
  36. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/detectors/__init__.py +4 -0
  37. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/detectors/drift/__init__.py +8 -3
  38. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/detectors/drift/kernels/__init__.py +4 -0
  39. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/detectors/drift/updates/__init__.py +4 -0
  40. dataeval-0.70.1/src/dataeval/detectors/linters/__init__.py +16 -0
  41. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/detectors/ood/__init__.py +14 -2
  42. dataeval-0.70.1/src/dataeval/metrics/__init__.py +8 -0
  43. dataeval-0.70.1/src/dataeval/metrics/bias/__init__.py +21 -0
  44. dataeval-0.70.1/src/dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval-0.70.1/src/dataeval/metrics/stats/__init__.py +28 -0
  46. dataeval-0.70.1/src/dataeval/utils/__init__.py +19 -0
  47. dataeval-0.70.1/src/dataeval/utils/tensorflow/__init__.py +11 -0
  48. dataeval-0.70.1/src/dataeval/utils/torch/__init__.py +12 -0
  49. dataeval-0.70.1/src/dataeval/utils/torch/datasets/__init__.py +7 -0
  50. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/workflows/__init__.py +4 -0
  51. dataeval-0.70.0/src/dataeval/_internal/datasets.py +0 -300
  52. dataeval-0.70.0/src/dataeval/detectors/linters/__init__.py +0 -5
  53. dataeval-0.70.0/src/dataeval/metrics/__init__.py +0 -3
  54. dataeval-0.70.0/src/dataeval/metrics/bias/__init__.py +0 -12
  55. dataeval-0.70.0/src/dataeval/metrics/estimators/__init__.py +0 -9
  56. dataeval-0.70.0/src/dataeval/metrics/stats/__init__.py +0 -17
  57. dataeval-0.70.0/src/dataeval/tensorflow/__init__.py +0 -3
  58. dataeval-0.70.0/src/dataeval/torch/__init__.py +0 -3
  59. dataeval-0.70.0/src/dataeval/utils/__init__.py +0 -6
  60. {dataeval-0.70.0 → dataeval-0.70.1}/LICENSE.txt +0 -0
  61. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/__init__.py +0 -0
  62. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/__init__.py +0 -0
  63. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/cvm.py +0 -0
  64. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/ks.py +0 -0
  65. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/torch.py +0 -0
  66. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/uncertainty.py +0 -0
  67. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/merged_stats.py +0 -0
  68. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/__init__.py +0 -0
  69. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/interop.py +0 -0
  70. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/__init__.py +0 -0
  71. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/metrics/stats/boxratiostats.py +0 -0
  72. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/__init__.py +0 -0
  73. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/__init__.py +0 -0
  74. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/blocks.py +0 -0
  75. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/utils.py +0 -0
  76. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/__init__.py +0 -0
  77. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/autoencoder.py +0 -0
  78. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/gmm.py +0 -0
  79. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/losses.py +0 -0
  80. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/trainer.py +0 -0
  81. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/utils.py +0 -0
  82. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/output.py +0 -0
  83. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/workflows/__init__.py +0 -0
  84. {dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/py.typed +0 -0
  85. {dataeval-0.70.0/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/loss/__init__.py +0 -0
  86. {dataeval-0.70.0/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/models/__init__.py +0 -0
  87. {dataeval-0.70.0/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/recon/__init__.py +0 -0
  88. {dataeval-0.70.0/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/models/__init__.py +0 -0
  89. {dataeval-0.70.0/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/trainer/__init__.py +0 -0
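The moves at the end of the file list show the optional framework namespaces being consolidated: `dataeval/torch/*` and `dataeval/tensorflow/*` now live under `dataeval/utils/`, and a new `dataeval/utils/torch/datasets` package is added. A minimal, hedged sketch of what that means for downstream imports follows; only the package paths from the file list are used, and no re-exported symbol names are assumed.

```python
# Import-path migration implied by the file moves above (paths taken from the
# 0.70.1 file list; the 0.70.0 forms are the removed top-level packages).
# Requires the corresponding "torch"/"tensorflow" extras to be installed.

# dataeval 0.70.0
# import dataeval.torch.models
# import dataeval.tensorflow.loss

# dataeval 0.70.1
import dataeval.utils.torch.models
import dataeval.utils.tensorflow.loss
import dataeval.utils.torch.datasets  # new package in 0.70.1
```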
{dataeval-0.70.0 → dataeval-0.70.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.70.0
+Version: 0.70.1
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -30,7 +30,6 @@ Requires-Dist: pillow (>=10.3.0)
 Requires-Dist: scikit-learn (>=1.5.0)
 Requires-Dist: scipy (>=1.10)
 Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "all"
-Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
 Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
 Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
 Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
@@ -75,6 +74,15 @@ You can install DataEval directly from pypi.org using the following command. Th
 pip install dataeval[all]
 ```

+### Installing DataEval in Conda/Mamba
+
+DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
+are installed from the `pytorch` channel, the channel is specified in the below example.
+
+```
+micromamba create -f environment\environment.yaml -c pytorch
+```
+
 ### Installing DataEval from GitHub

 To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
{dataeval-0.70.0 → dataeval-0.70.1}/README.md
@@ -34,6 +34,15 @@ You can install DataEval directly from pypi.org using the following command. Th
 pip install dataeval[all]
 ```

+### Installing DataEval in Conda/Mamba
+
+DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
+are installed from the `pytorch` channel, the channel is specified in the below example.
+
+```
+micromamba create -f environment\environment.yaml -c pytorch
+```
+
 ### Installing DataEval from GitHub

 To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
{dataeval-0.70.0 → dataeval-0.70.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.70.0" # dynamic
+version = "0.70.1" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"
@@ -52,15 +52,14 @@ xxhash = {version = ">=3.3"}
 matplotlib = {version = "*", optional = true}
 nvidia-cudnn-cu11 = {version = ">=8.6.0.163", optional = true}
 tensorflow = {version = ">=2.14.1, <2.16", optional = true}
-tensorflow-io-gcs-filesystem = {version = ">=0.35.0, <0.37", optional = true}
 tensorflow_probability = {version = ">=0.22.1, <0.24", optional = true}
 torch = {version = ">=2.2.0", source = "pytorch", optional = true}
 torchvision = {version = ">=0.17.0", source = "pytorch", optional = true}

 [tool.poetry.extras]
-tensorflow = ["tensorflow", "tensorflow-io-gcs-filesystem", "tensorflow_probability", "nvidia-cudnn-cu11"]
+tensorflow = ["tensorflow", "tensorflow_probability", "nvidia-cudnn-cu11"]
 torch = ["torch", "torchvision", "matplotlib", "nvidia-cudnn-cu11"]
-all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "tensorflow-io-gcs-filesystem", "tensorflow_probability", "torch", "torchvision"]
+all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "tensorflow_probability", "torch", "torchvision"]

 [tool.poetry.group.dev]
 optional = true
@@ -71,6 +70,7 @@ tox-uv = {version = "*"}
 uv = {version = "*"}
 poetry = {version = "*"}
 poetry-lock-groups-plugin = {version = "*"}
+poetry2conda = {version = "*"}
 # lint
 ruff = {version = "*"}
 codespell = {version = "*", extras = ["toml"]}
@@ -113,6 +113,18 @@ pattern = "v(?P<base>\\d+\\.\\d+\\.\\d+)$"
 [tool.poetry-dynamic-versioning.substitution]
 files = ["src/dataeval/__init__.py"]

+[tool.poetry2conda]
+name = "dataeval"
+
+[tool.poetry2conda.dependencies]
+nvidia-cudnn-cu11 = { name = "cudnn" }
+pillow = { channel = "pip" }
+tensorflow = { channel = "pip" }
+tensorflow_probability = { channel = "pip" }
+torch = { name = "pytorch", channel = "pytorch" }
+torchvision = { channel = "pytorch" }
+xxhash = { name = "python-xxhash" }
+
 [tool.pyright]
 reportMissingImports = false

@@ -131,7 +143,6 @@ omit = [
     "*/_internal/models/pytorch/blocks.py",
     "*/_internal/models/pytorch/utils.py",
     "*/_internal/models/tensorflow/pixelcnn.py",
-    "*/_internal/datasets.py",
 ]
 fail_under = 90

{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.70.0"
+__version__ = "0.70.1"

 from importlib.util import find_spec

@@ -12,11 +12,11 @@ from . import detectors, metrics  # noqa: E402
 __all__ = ["detectors", "metrics"]

 if _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from . import torch, utils, workflows
+    from . import workflows

-    __all__ += ["torch", "utils", "workflows"]
+    __all__ += ["workflows"]

-if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
-    from . import tensorflow
+if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:  # pragma: no cover
+    from . import utils

-    __all__ += ["tensorflow"]
+    __all__ += ["utils"]
dataeval-0.70.1/src/dataeval/_internal/datasets.py
@@ -0,0 +1,404 @@
+from __future__ import annotations
+
+import hashlib
+import os
+import zipfile
+from pathlib import Path
+from typing import Literal, TypeVar
+from warnings import warn
+
+import numpy as np
+import requests
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
+
+ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
+CorruptionStringMap = Literal[
+    "identity",
+    "shot_noise",
+    "impulse_noise",
+    "glass_blur",
+    "motion_blur",
+    "shear",
+    "scale",
+    "rotate",
+    "brightness",
+    "translate",
+    "stripe",
+    "fog",
+    "spatter",
+    "dotted_line",
+    "zigzag",
+    "canny_edges",
+]
+
+
+def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
+    hasher = hashlib.md5() if md5 else hashlib.sha256()
+    with open(fpath, "rb") as fpath_file:
+        while chunk := fpath_file.read(chunk_size):
+            hasher.update(chunk)
+    return hasher.hexdigest() == file_md5
+
+
+def _get_file(
+    root: str | Path,
+    fname: str,
+    origin: str,
+    file_hash: str | None = None,
+    verbose: bool = True,
+    md5: bool = False,
+):
+    fpath = os.path.join(root, fname)
+    download = True
+    if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
+        download = False
+        if verbose:
+            print("File already downloaded and verified.")
+            if md5:
+                print("Extracting zip file...")
+
+    if download:
+        try:
+            error_msg = "URL fetch failure on {}: {} -- {}"
+            try:
+                with requests.get(origin, stream=True, timeout=60) as r:
+                    r.raise_for_status()
+                    with open(fpath, "wb") as f:
+                        for chunk in r.iter_content(chunk_size=8192):
+                            if chunk:
+                                f.write(chunk)
+            except requests.exceptions.HTTPError as e:
+                raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+            except requests.exceptions.RequestException as e:
+                raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
+        except (Exception, KeyboardInterrupt):
+            if os.path.exists(fpath):
+                os.remove(fpath)
+            raise
+
+        if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
+            raise ValueError(
+                "Incomplete or corrupted file detected. "
+                f"The file hash does not match the provided value "
+                f"of {file_hash}.",
+            )
+
+    return fpath
+
+
+def check_exists(
+    folder: str | Path,
+    url: str,
+    root: str | Path,
+    fname: str,
+    file_hash: str,
+    download: bool = True,
+    verbose: bool = True,
+    md5: bool = False,
+):
+    """Determine if the dataset has already been downloaded."""
+    location = str(folder)
+    if not os.path.exists(folder):
+        if download:
+            location = download_dataset(url, root, fname, file_hash, verbose, md5)
+        else:
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+    else:
+        if verbose:
+            print("Files already downloaded and verified")
+    return location
+
+
+def download_dataset(
+    url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
+) -> str:
+    """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
+    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
+    """
+    name, _ = os.path.splitext(fname)
+    folder = os.path.join(root, name)
+    os.makedirs(folder, exist_ok=True)
+
+    fpath = _get_file(
+        folder,
+        fname,
+        origin=url + fname,
+        file_hash=file_hash,
+        verbose=verbose,
+        md5=md5,
+    )
+    if md5:
+        folder = extract_archive(fpath, root, remove_finished=True)
+    return folder
+
+
+def extract_archive(
+    from_path: str | Path,
+    to_path: str | Path | None = None,
+    remove_finished: bool = False,
+) -> str:
+    """Extract an archive.
+
+    The archive type and a possible compression is automatically detected from the file name.
+    """
+    from_path = Path(from_path)
+    if not from_path.is_absolute():
+        from_path = from_path.resolve()
+
+    if to_path is None or not os.path.exists(to_path):
+        to_path = os.path.dirname(from_path)
+    to_path = Path(to_path)
+    if not to_path.is_absolute():
+        to_path = to_path.resolve()
+
+    # Extracting zip
+    with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
+        zzip.extractall(to_path)
+
+    if remove_finished:
+        os.remove(from_path)
+    return str(to_path)
+
+
+def subselect(arr: NDArray, count: int, from_back: bool = False):
+    if from_back:
+        return arr[-count:]
+    return arr[:count]
+
+
+class MNIST(Dataset):
+    """MNIST Dataset and Corruptions.
+
+    Args:
+        root : str | ``pathlib.Path``
+            Root directory of dataset where the ``mnist_c/`` folder exists.
+        train : bool, default True
+            If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``.
+        download : bool, default False
+            If True, downloads the dataset from the internet and puts it in root
+            directory. If dataset is already downloaded, it is not downloaded again.
+        size : int, default -1
+            Limit the dataset size, must be a value greater than 0.
+        unit_interval : bool, default False
+            Shift the data values to the unit interval [0-1].
+        dtype : type | None, default None
+            Change the numpy dtype - data is loaded as np.uint8
+        channels : Literal['channels_first' | 'channels_last'] | None, default None
+            Location of channel axis if desired, default has no channels (N, 28, 28)
+        flatten : bool, default False
+            Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
+            channels takes priority over flatten.
+        normalize : tuple[mean, std] | None, default None
+            Normalize images acorrding to provided mean and standard deviation
+        corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
+            'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
+            'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
+            The desired corruption style or None.
+        classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+            | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
+            "eight", "nine"]] | None, default None
+            Option to select specific classes from dataset.
+        balance : bool, default True
+            If True, returns equal number of samples for each class.
+        randomize : bool, default False
+            If True, shuffles the data prior to selection - uses a set seed for reproducibility.
+        slice_back : bool, default False
+            If True and size has a value greater than 0, then grabs selection starting at the last image.
+        verbose : bool, default True
+            If True, outputs print statements.
+    """
+
+    mirror = [
+        "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
+        "https://zenodo.org/record/3239543/files/",
+    ]
+
+    resources = [
+        ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
+        ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
+    ]
+
+    class_dict = {
+        "zero": 0,
+        "one": 1,
+        "two": 2,
+        "three": 3,
+        "four": 4,
+        "five": 5,
+        "six": 6,
+        "seven": 7,
+        "eight": 8,
+        "nine": 9,
+    }
+
+    def __init__(
+        self,
+        root: str | Path,
+        train: bool = True,
+        download: bool = False,
+        size: int = -1,
+        unit_interval: bool = False,
+        dtype: type | None = None,
+        channels: Literal["channels_first", "channels_last"] | None = None,
+        flatten: bool = False,
+        normalize: tuple[float, float] | None = None,
+        corruption: CorruptionStringMap | None = None,
+        classes: TClassMap | None = None,
+        balance: bool = True,
+        randomize: bool = False,
+        slice_back: bool = False,
+        verbose: bool = True,
+    ) -> None:
+        if isinstance(root, str):
+            root = os.path.expanduser(root)
+        self.root = root  # location of stored dataset
+        self.train = train  # training set or test set
+        self.size = size
+        self.unit_interval = unit_interval
+        self.dtype = dtype
+        self.channels = channels
+        self.flatten = flatten
+        self.normalize = normalize
+        self.corruption = corruption
+        self.balance = balance
+        self.randomize = randomize
+        self.from_back = slice_back
+        self.verbose = verbose
+
+        self.class_set = []
+        if classes is not None:
+            if not isinstance(classes, list):
+                classes = [classes]  # type: ignore
+
+            for val in classes:  # type: ignore
+                if isinstance(val, int) and 0 <= val < 10:
+                    self.class_set.append(val)
+                elif isinstance(val, str):
+                    self.class_set.append(self.class_dict[val])
+            self.class_set = set(self.class_set)
+
+        if not self.class_set:
+            self.class_set = set(self.class_dict.values())
+
+        self.num_classes = len(self.class_set)
+
+        if self.corruption is None:
+            file_resource = self.resources[0]
+            mirror = self.mirror[0]
+            md5 = False
+        else:
+            if self.corruption == "identity" and verbose:
+                print("Identity is not a corrupted dataset but the original MNIST dataset.")
+            file_resource = self.resources[1]
+            mirror = self.mirror[1]
+            md5 = True
+        check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
+
+        self.data, self.targets = self._load_data()
+
+        self._augmentations()
+
+    def _load_data(self):
+        if self.corruption is None:
+            image_file = self.resources[0][0]
+            data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
+        else:
+            image_file = f"{'train' if self.train else 'test'}_images.npy"
+            data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
+            data = data.squeeze()
+
+            label_file = f"{'train' if self.train else 'test'}_labels.npy"
+            targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
+
+        return data, targets
+
+    def _augmentations(self):
+        if self.size > self.targets.shape[0] and self.verbose:
+            warn(
+                f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
+                "Adjusting down to raw dataset size."
+            )
+            self.size = -1
+
+        if self.randomize:
+            rdm_seed = np.random.default_rng(2023)
+            shuffled_indices = rdm_seed.permutation(self.data.shape[0])
+            self.data = self.data[shuffled_indices]
+            self.targets = self.targets[shuffled_indices]
+
+        if not self.balance and self.num_classes > self.size:
+            if self.size > 0:
+                self.data = subselect(self.data, self.size, self.from_back)
+                self.targets = subselect(self.targets, self.size, self.from_back)
+        else:
+            label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+            min_label_count = min(len(indices) for indices in label_dict.values())
+
+            self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+
+            if self.per_class_count > min_label_count:
+                self.per_class_count = min_label_count
+                if not self.balance and self.verbose:
+                    warn(
+                        f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                        f"will be returned, instead of the desired {self.size}."
+                    )
+
+            all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
+            for i, label in enumerate(self.class_set):
+                all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+            self.data = np.vstack(self.data[all_indices.T])  # type: ignore
+            self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
+
+        if self.unit_interval:
+            self.data = self.data / 255
+
+        if self.normalize:
+            self.data = (self.data - self.normalize[0]) / self.normalize[1]
+
+        if self.dtype:
+            self.data = self.data.astype(self.dtype)
+
+        if self.channels == "channels_first":
+            self.data = self.data[:, np.newaxis, :, :]
+        elif self.channels == "channels_last":
+            self.data = self.data[:, :, :, np.newaxis]
+
+        if self.flatten and self.channels is None:
+            self.data = self.data.reshape(self.data.shape[0], -1)
+
+    def __getitem__(self, index: int) -> tuple[NDArray, int]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def mnist_folder(self) -> str:
+        if self.corruption is None:
+            return os.path.join(self.root, "mnist")
+        return os.path.join(self.root, "mnist_c", self.corruption)
+
+    def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
+        with np.load(path, allow_pickle=True) as f:
+            if self.train:
+                x, y = f["x_train"], f["y_train"]
+            else:
+                x, y = f["x_test"], f["y_test"]
+            return x, y
+
+    def _read_corrupt_file(self, path: str) -> NDArray:
+        x = np.load(path, allow_pickle=False)
+        return x
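The new `datasets.py` module above provides a self-contained MNIST/MNIST-C loader. A hedged usage sketch based on the constructor shown above follows; the `root` location is an arbitrary choice, and the internal import path from this diff is used since the contents of the new public `dataeval.utils.torch.datasets` package are not shown here.

```python
# Usage sketch for the MNIST loader added in 0.70.1 (arguments taken from the
# __init__ signature above; "./data" is an arbitrary download location).
from dataeval._internal.datasets import MNIST

# "translate"-corrupted test split, scaled to [0, 1] with a channel-first axis.
mnist_c = MNIST(
    root="./data",
    train=False,
    download=True,
    corruption="translate",
    unit_interval=True,
    channels="channels_first",
)

img, label = mnist_c[0]
print(len(mnist_c), img.shape, label)  # per-item shape is (1, 28, 28) with channels_first
```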
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/clusterer.py
@@ -16,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class ClustererOutput(OutputMetadata):
     """
+    Output class for :class:`Clusterer` lint detector
+
     Attributes
     ----------
     outliers : List[int]
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/base.py
@@ -23,7 +23,7 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class DriftBaseOutput(OutputMetadata):
     """
-    Output class for Drift
+    Base output class for Drift detector classes

     Attributes
     ----------
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
 @dataclass(frozen=True)
 class DriftOutput(DriftBaseOutput):
     """
-    Output class for DriftCVM and DriftKS
+    Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors

     Attributes
     ----------
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/mmd.py
@@ -24,7 +24,7 @@ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
 @dataclass(frozen=True)
 class DriftMMDOutput(DriftBaseOutput):
     """
-    Output class for DriftMMD
+    Output class for :class:`DriftMMD` drift detector

     Attributes
     ----------
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/duplicates.py
@@ -17,6 +17,8 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
 @dataclass(frozen=True)
 class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
     """
+    Output class for :class:`Duplicates` lint detector
+
     Attributes
     ----------
     exact : list[list[int] | dict[int, list[int]]]
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/ae.py
@@ -15,10 +15,11 @@ import numpy as np
 import tensorflow as tf
 from numpy.typing import ArrayLike

-from dataeval._internal.detectors.ood.base import OODBase, OODScore
+from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
 from dataeval._internal.interop import as_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AE
 from dataeval._internal.models.tensorflow.utils import predict_batch
+from dataeval._internal.output import set_metadata


 class OOD_AE(OODBase):
@@ -48,7 +49,8 @@ class OOD_AE(OODBase):
         loss_fn = keras.losses.MeanSquaredError()
         super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)

-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+    @set_metadata("dataeval.detectors")
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         self._validate(X := as_numpy(X))

         # reconstruct instances
@@ -62,4 +64,4 @@ class OOD_AE(OODBase):
         sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
         iscore = np.mean(sorted_fscore_perc, axis=1)

-        return OODScore(iscore, fscore)
+        return OODScoreOutput(iscore, fscore)
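Both OOD hunks rename the score return type from `OODScore` to `OODScoreOutput` and wrap `score` with `@set_metadata`. A hedged sketch of calling the updated API follows; the toy encoder/decoder, the `AE(encoder, decoder)` call order, the keyword use of `fit`, and the random reference data are illustrative assumptions, not values taken from the package.

```python
# Hedged sketch (not from the package docs): exercises OOD_AE.score(), which in
# 0.70.1 returns an OODScoreOutput instead of OODScore.
import numpy as np
import tensorflow as tf

from dataeval._internal.detectors.ood.ae import OOD_AE
from dataeval._internal.models.tensorflow.autoencoder import AE

# Toy encoder/decoder networks; the real architecture is user-supplied and the
# AE(encoder, decoder) argument order is assumed here.
encoder = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(8)])
decoder = tf.keras.Sequential([tf.keras.layers.Dense(28 * 28), tf.keras.layers.Reshape((28, 28, 1))])

x_ref = np.random.rand(64, 28, 28, 1).astype(np.float32)  # placeholder reference images

detector = OOD_AE(AE(encoder, decoder))
detector.fit(x_ref, threshold_perc=95, epochs=1, verbose=False)  # parameter names as in the fit() call above

out = detector.score(x_ref[:8])
print(type(out).__name__)  # "OODScoreOutput" in 0.70.1 (was "OODScore" in 0.70.0)
```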
{dataeval-0.70.0 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/aegmm.py
@@ -14,12 +14,13 @@ import keras
 import tensorflow as tf
 from numpy.typing import ArrayLike

-from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
+from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import LossGMM
 from dataeval._internal.models.tensorflow.utils import predict_batch
+from dataeval._internal.output import set_metadata


 class OOD_AEGMM(OODGMMBase):
@@ -49,7 +50,8 @@ class OOD_AEGMM(OODGMMBase):
         loss_fn = LossGMM()
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)

-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+    @set_metadata("dataeval.detectors")
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         """
         Compute the out-of-distribution (OOD) score for a given dataset.

@@ -63,7 +65,7 @@ class OOD_AEGMM(OODGMMBase):

         Returns
         -------
-        OODScore
+        OODScoreOutput
             An object containing the instance-level OOD score.

         Note
@@ -73,4 +75,4 @@ class OOD_AEGMM(OODGMMBase):
         self._validate(X := to_numpy(X))
         _, z, _ = predict_batch(X, self.model, batch_size=batch_size)
         energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
-        return OODScore(energy.numpy())  # type: ignore
+        return OODScoreOutput(energy.numpy())  # type: ignore