dataeval 0.69.4__tar.gz → 0.70.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. {dataeval-0.69.4 → dataeval-0.70.1}/PKG-INFO +12 -4
  2. {dataeval-0.69.4 → dataeval-0.70.1}/README.md +9 -0
  3. {dataeval-0.69.4 → dataeval-0.70.1}/pyproject.toml +19 -6
  4. dataeval-0.70.1/src/dataeval/__init__.py +22 -0
  5. dataeval-0.70.1/src/dataeval/_internal/datasets.py +404 -0
  6. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/clusterer.py +2 -0
  7. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/base.py +7 -8
  8. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/mmd.py +4 -4
  9. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/duplicates.py +64 -45
  10. dataeval-0.70.1/src/dataeval/_internal/detectors/merged_stats.py +47 -0
  11. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/ae.py +8 -6
  12. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/aegmm.py +6 -4
  13. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/base.py +12 -7
  14. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/llr.py +6 -4
  15. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vae.py +5 -3
  16. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  17. dataeval-0.70.1/src/dataeval/_internal/detectors/outliers.py +271 -0
  18. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/interop.py +11 -7
  19. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/balance.py +13 -11
  20. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/ber.py +5 -3
  21. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/coverage.py +4 -0
  22. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/divergence.py +9 -5
  23. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/diversity.py +14 -12
  24. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/parity.py +32 -22
  25. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/base.py +231 -0
  26. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
  27. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/datasetstats.py +99 -0
  28. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/dimensionstats.py +113 -0
  29. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/hashstats.py +75 -0
  30. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/labelstats.py +125 -0
  31. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/pixelstats.py +119 -0
  32. dataeval-0.70.1/src/dataeval/_internal/metrics/stats/visualstats.py +124 -0
  33. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/uap.py +8 -4
  34. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/utils.py +30 -15
  35. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  36. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  37. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/output.py +3 -18
  38. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/utils.py +11 -16
  39. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/workflows/sufficiency.py +152 -151
  40. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/__init__.py +4 -0
  41. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/__init__.py +8 -3
  42. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/kernels/__init__.py +4 -0
  43. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/updates/__init__.py +4 -0
  44. dataeval-0.70.1/src/dataeval/detectors/linters/__init__.py +16 -0
  45. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/ood/__init__.py +14 -2
  46. dataeval-0.70.1/src/dataeval/metrics/__init__.py +8 -0
  47. dataeval-0.70.1/src/dataeval/metrics/bias/__init__.py +21 -0
  48. dataeval-0.70.1/src/dataeval/metrics/estimators/__init__.py +9 -0
  49. dataeval-0.70.1/src/dataeval/metrics/stats/__init__.py +28 -0
  50. dataeval-0.70.1/src/dataeval/utils/__init__.py +19 -0
  51. dataeval-0.70.1/src/dataeval/utils/tensorflow/__init__.py +11 -0
  52. dataeval-0.70.1/src/dataeval/utils/torch/__init__.py +12 -0
  53. dataeval-0.70.1/src/dataeval/utils/torch/datasets/__init__.py +7 -0
  54. dataeval-0.70.1/src/dataeval/workflows/__init__.py +10 -0
  55. dataeval-0.69.4/src/dataeval/__init__.py +0 -22
  56. dataeval-0.69.4/src/dataeval/_internal/datasets.py +0 -300
  57. dataeval-0.69.4/src/dataeval/_internal/detectors/merged_stats.py +0 -78
  58. dataeval-0.69.4/src/dataeval/_internal/detectors/outliers.py +0 -197
  59. dataeval-0.69.4/src/dataeval/_internal/flags.py +0 -77
  60. dataeval-0.69.4/src/dataeval/_internal/metrics/stats.py +0 -397
  61. dataeval-0.69.4/src/dataeval/detectors/linters/__init__.py +0 -5
  62. dataeval-0.69.4/src/dataeval/flags/__init__.py +0 -3
  63. dataeval-0.69.4/src/dataeval/metrics/__init__.py +0 -3
  64. dataeval-0.69.4/src/dataeval/metrics/bias/__init__.py +0 -12
  65. dataeval-0.69.4/src/dataeval/metrics/estimators/__init__.py +0 -9
  66. dataeval-0.69.4/src/dataeval/metrics/stats/__init__.py +0 -6
  67. dataeval-0.69.4/src/dataeval/tensorflow/__init__.py +0 -3
  68. dataeval-0.69.4/src/dataeval/torch/__init__.py +0 -3
  69. dataeval-0.69.4/src/dataeval/utils/__init__.py +0 -6
  70. dataeval-0.69.4/src/dataeval/workflows/__init__.py +0 -6
  71. {dataeval-0.69.4 → dataeval-0.70.1}/LICENSE.txt +0 -0
  72. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/__init__.py +0 -0
  73. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/__init__.py +0 -0
  74. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/cvm.py +0 -0
  75. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/ks.py +0 -0
  76. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/torch.py +0 -0
  77. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/uncertainty.py +0 -0
  78. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/__init__.py +0 -0
  79. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/__init__.py +0 -0
  80. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/__init__.py +0 -0
  81. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/__init__.py +0 -0
  82. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/blocks.py +0 -0
  83. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/utils.py +0 -0
  84. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/__init__.py +0 -0
  85. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/autoencoder.py +0 -0
  86. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/gmm.py +0 -0
  87. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/losses.py +0 -0
  88. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/trainer.py +0 -0
  89. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/utils.py +0 -0
  90. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/workflows/__init__.py +0 -0
  91. {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/py.typed +0 -0
  92. {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/loss/__init__.py +0 -0
  93. {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/models/__init__.py +0 -0
  94. {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/recon/__init__.py +0 -0
  95. {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/models/__init__.py +0 -0
  96. {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/trainer/__init__.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.69.4
+ Version: 0.70.1
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -30,10 +30,9 @@ Requires-Dist: pillow (>=10.3.0)
  Requires-Dist: scikit-learn (>=1.5.0)
  Requires-Dist: scipy (>=1.10)
  Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "all"
- Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
  Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
- Requires-Dist: torch (>=2.0.1,!=2.2.0) ; extra == "torch" or extra == "all"
- Requires-Dist: torchvision (>=0.16.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
  Requires-Dist: xxhash (>=3.3)
  Project-URL: Documentation, https://dataeval.readthedocs.io/
  Project-URL: Repository, https://github.com/aria-ml/dataeval/
@@ -75,6 +74,15 @@ You can install DataEval directly from pypi.org using the following command. Th
  pip install dataeval[all]
  ```

+ ### Installing DataEval in Conda/Mamba
+
+ DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
+ are installed from the `pytorch` channel, the channel is specified in the below example.
+
+ ```
+ micromamba create -f environment\environment.yaml -c pytorch
+ ```
+
  ### Installing DataEval from GitHub

  To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
@@ -34,6 +34,15 @@ You can install DataEval directly from pypi.org using the following command. Th
  pip install dataeval[all]
  ```

+ ### Installing DataEval in Conda/Mamba
+
+ DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
+ are installed from the `pytorch` channel, the channel is specified in the below example.
+
+ ```
+ micromamba create -f environment\environment.yaml -c pytorch
+ ```
+
  ### Installing DataEval from GitHub

  To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "dataeval"
- version = "0.69.4" # dynamic
+ version = "0.70.1" # dynamic
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
  license = "MIT"
  readme = "README.md"
@@ -52,15 +52,14 @@ xxhash = {version = ">=3.3"}
  matplotlib = {version = "*", optional = true}
  nvidia-cudnn-cu11 = {version = ">=8.6.0.163", optional = true}
  tensorflow = {version = ">=2.14.1, <2.16", optional = true}
- tensorflow-io-gcs-filesystem = {version = ">=0.35.0, <0.37", optional = true}
  tensorflow_probability = {version = ">=0.22.1, <0.24", optional = true}
- torch = {version = ">=2.0.1, !=2.2.0", source = "pytorch", optional = true}
- torchvision = {version = ">=0.16.0", source = "pytorch", optional = true}
+ torch = {version = ">=2.2.0", source = "pytorch", optional = true}
+ torchvision = {version = ">=0.17.0", source = "pytorch", optional = true}

  [tool.poetry.extras]
- tensorflow = ["tensorflow", "tensorflow-io-gcs-filesystem", "tensorflow_probability", "nvidia-cudnn-cu11"]
+ tensorflow = ["tensorflow", "tensorflow_probability", "nvidia-cudnn-cu11"]
  torch = ["torch", "torchvision", "matplotlib", "nvidia-cudnn-cu11"]
- all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "tensorflow-io-gcs-filesystem", "tensorflow_probability", "torch", "torchvision"]
+ all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "tensorflow_probability", "torch", "torchvision"]

  [tool.poetry.group.dev]
  optional = true
@@ -71,6 +70,7 @@ tox-uv = {version = "*"}
  uv = {version = "*"}
  poetry = {version = "*"}
  poetry-lock-groups-plugin = {version = "*"}
+ poetry2conda = {version = "*"}
  # lint
  ruff = {version = "*"}
  codespell = {version = "*", extras = ["toml"]}
@@ -113,6 +113,18 @@ pattern = "v(?P<base>\\d+\\.\\d+\\.\\d+)$"
  [tool.poetry-dynamic-versioning.substitution]
  files = ["src/dataeval/__init__.py"]

+ [tool.poetry2conda]
+ name = "dataeval"
+
+ [tool.poetry2conda.dependencies]
+ nvidia-cudnn-cu11 = { name = "cudnn" }
+ pillow = { channel = "pip" }
+ tensorflow = { channel = "pip" }
+ tensorflow_probability = { channel = "pip" }
+ torch = { name = "pytorch", channel = "pytorch" }
+ torchvision = { channel = "pytorch" }
+ xxhash = { name = "python-xxhash" }
+
  [tool.pyright]
  reportMissingImports = false

@@ -173,6 +185,7 @@ docstring-code-line-length = "dynamic"

  [tool.codespell]
  skip = './*env*,./prototype,./output,./docs/_build,./docs/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html'
+ ignore-words-list = ["Hart"]

  [build-system]
  requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
@@ -0,0 +1,22 @@
+ __version__ = "0.70.1"
+
+ from importlib.util import find_spec
+
+ _IS_TORCH_AVAILABLE = find_spec("torch") is not None
+ _IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
+
+ del find_spec
+
+ from . import detectors, metrics  # noqa: E402
+
+ __all__ = ["detectors", "metrics"]
+
+ if _IS_TORCH_AVAILABLE:  # pragma: no cover
+     from . import workflows
+
+     __all__ += ["workflows"]
+
+ if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:  # pragma: no cover
+     from . import utils
+
+     __all__ += ["utils"]
@@ -0,0 +1,404 @@
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ import zipfile
+ from pathlib import Path
+ from typing import Literal, TypeVar
+ from warnings import warn
+
+ import numpy as np
+ import requests
+ from numpy.typing import NDArray
+ from torch.utils.data import Dataset
+ from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
+
+ ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+ TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
+ CorruptionStringMap = Literal[
+     "identity",
+     "shot_noise",
+     "impulse_noise",
+     "glass_blur",
+     "motion_blur",
+     "shear",
+     "scale",
+     "rotate",
+     "brightness",
+     "translate",
+     "stripe",
+     "fog",
+     "spatter",
+     "dotted_line",
+     "zigzag",
+     "canny_edges",
+ ]
+
+
+ def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
+     hasher = hashlib.md5() if md5 else hashlib.sha256()
+     with open(fpath, "rb") as fpath_file:
+         while chunk := fpath_file.read(chunk_size):
+             hasher.update(chunk)
+     return hasher.hexdigest() == file_md5
+
+
+ def _get_file(
+     root: str | Path,
+     fname: str,
+     origin: str,
+     file_hash: str | None = None,
+     verbose: bool = True,
+     md5: bool = False,
+ ):
+     fpath = os.path.join(root, fname)
+     download = True
+     if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
+         download = False
+         if verbose:
+             print("File already downloaded and verified.")
+             if md5:
+                 print("Extracting zip file...")
+
+     if download:
+         try:
+             error_msg = "URL fetch failure on {}: {} -- {}"
+             try:
+                 with requests.get(origin, stream=True, timeout=60) as r:
+                     r.raise_for_status()
+                     with open(fpath, "wb") as f:
+                         for chunk in r.iter_content(chunk_size=8192):
+                             if chunk:
+                                 f.write(chunk)
+             except requests.exceptions.HTTPError as e:
+                 raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+             except requests.exceptions.RequestException as e:
+                 raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
+         except (Exception, KeyboardInterrupt):
+             if os.path.exists(fpath):
+                 os.remove(fpath)
+             raise
+
+     if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
+         raise ValueError(
+             "Incomplete or corrupted file detected. "
+             f"The file hash does not match the provided value "
+             f"of {file_hash}.",
+         )
+
+     return fpath
+
+
+ def check_exists(
+     folder: str | Path,
+     url: str,
+     root: str | Path,
+     fname: str,
+     file_hash: str,
+     download: bool = True,
+     verbose: bool = True,
+     md5: bool = False,
+ ):
+     """Determine if the dataset has already been downloaded."""
+     location = str(folder)
+     if not os.path.exists(folder):
+         if download:
+             location = download_dataset(url, root, fname, file_hash, verbose, md5)
+         else:
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+     else:
+         if verbose:
+             print("Files already downloaded and verified")
+     return location
+
+
+ def download_dataset(
+     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
+ ) -> str:
+     """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
+     https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
+     """
+     name, _ = os.path.splitext(fname)
+     folder = os.path.join(root, name)
+     os.makedirs(folder, exist_ok=True)
+
+     fpath = _get_file(
+         folder,
+         fname,
+         origin=url + fname,
+         file_hash=file_hash,
+         verbose=verbose,
+         md5=md5,
+     )
+     if md5:
+         folder = extract_archive(fpath, root, remove_finished=True)
+     return folder
+
+
+ def extract_archive(
+     from_path: str | Path,
+     to_path: str | Path | None = None,
+     remove_finished: bool = False,
+ ) -> str:
+     """Extract an archive.
+
+     The archive type and a possible compression is automatically detected from the file name.
+     """
+     from_path = Path(from_path)
+     if not from_path.is_absolute():
+         from_path = from_path.resolve()
+
+     if to_path is None or not os.path.exists(to_path):
+         to_path = os.path.dirname(from_path)
+     to_path = Path(to_path)
+     if not to_path.is_absolute():
+         to_path = to_path.resolve()
+
+     # Extracting zip
+     with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
+         zzip.extractall(to_path)
+
+     if remove_finished:
+         os.remove(from_path)
+     return str(to_path)
+
+
+ def subselect(arr: NDArray, count: int, from_back: bool = False):
+     if from_back:
+         return arr[-count:]
+     return arr[:count]
+
+
+ class MNIST(Dataset):
+     """MNIST Dataset and Corruptions.
+
+     Args:
+         root : str | ``pathlib.Path``
+             Root directory of dataset where the ``mnist_c/`` folder exists.
+         train : bool, default True
+             If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``.
+         download : bool, default False
+             If True, downloads the dataset from the internet and puts it in root
+             directory. If dataset is already downloaded, it is not downloaded again.
+         size : int, default -1
+             Limit the dataset size, must be a value greater than 0.
+         unit_interval : bool, default False
+             Shift the data values to the unit interval [0-1].
+         dtype : type | None, default None
+             Change the numpy dtype - data is loaded as np.uint8
+         channels : Literal['channels_first' | 'channels_last'] | None, default None
+             Location of channel axis if desired, default has no channels (N, 28, 28)
+         flatten : bool, default False
+             Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
+             channels takes priority over flatten.
+         normalize : tuple[mean, std] | None, default None
+             Normalize images acorrding to provided mean and standard deviation
+         corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
+             'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
+             'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
+             The desired corruption style or None.
+         classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+             | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
+             "eight", "nine"]] | None, default None
+             Option to select specific classes from dataset.
+         balance : bool, default True
+             If True, returns equal number of samples for each class.
+         randomize : bool, default False
+             If True, shuffles the data prior to selection - uses a set seed for reproducibility.
+         slice_back : bool, default False
+             If True and size has a value greater than 0, then grabs selection starting at the last image.
+         verbose : bool, default True
+             If True, outputs print statements.
+     """
+
+     mirror = [
+         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
+         "https://zenodo.org/record/3239543/files/",
+     ]
+
+     resources = [
+         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
+         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
+     ]
+
+     class_dict = {
+         "zero": 0,
+         "one": 1,
+         "two": 2,
+         "three": 3,
+         "four": 4,
+         "five": 5,
+         "six": 6,
+         "seven": 7,
+         "eight": 8,
+         "nine": 9,
+     }
+
+     def __init__(
+         self,
+         root: str | Path,
+         train: bool = True,
+         download: bool = False,
+         size: int = -1,
+         unit_interval: bool = False,
+         dtype: type | None = None,
+         channels: Literal["channels_first", "channels_last"] | None = None,
+         flatten: bool = False,
+         normalize: tuple[float, float] | None = None,
+         corruption: CorruptionStringMap | None = None,
+         classes: TClassMap | None = None,
+         balance: bool = True,
+         randomize: bool = False,
+         slice_back: bool = False,
+         verbose: bool = True,
+     ) -> None:
+         if isinstance(root, str):
+             root = os.path.expanduser(root)
+         self.root = root  # location of stored dataset
+         self.train = train  # training set or test set
+         self.size = size
+         self.unit_interval = unit_interval
+         self.dtype = dtype
+         self.channels = channels
+         self.flatten = flatten
+         self.normalize = normalize
+         self.corruption = corruption
+         self.balance = balance
+         self.randomize = randomize
+         self.from_back = slice_back
+         self.verbose = verbose
+
+         self.class_set = []
+         if classes is not None:
+             if not isinstance(classes, list):
+                 classes = [classes]  # type: ignore
+
+             for val in classes:  # type: ignore
+                 if isinstance(val, int) and 0 <= val < 10:
+                     self.class_set.append(val)
+                 elif isinstance(val, str):
+                     self.class_set.append(self.class_dict[val])
+             self.class_set = set(self.class_set)
+
+         if not self.class_set:
+             self.class_set = set(self.class_dict.values())
+
+         self.num_classes = len(self.class_set)
+
+         if self.corruption is None:
+             file_resource = self.resources[0]
+             mirror = self.mirror[0]
+             md5 = False
+         else:
+             if self.corruption == "identity" and verbose:
+                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
+             file_resource = self.resources[1]
+             mirror = self.mirror[1]
+             md5 = True
+         check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
+
+         self.data, self.targets = self._load_data()
+
+         self._augmentations()
+
+     def _load_data(self):
+         if self.corruption is None:
+             image_file = self.resources[0][0]
+             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
+         else:
+             image_file = f"{'train' if self.train else 'test'}_images.npy"
+             data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
+             data = data.squeeze()
+
+             label_file = f"{'train' if self.train else 'test'}_labels.npy"
+             targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
+
+         return data, targets
+
+     def _augmentations(self):
+         if self.size > self.targets.shape[0] and self.verbose:
+             warn(
+                 f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
+                 "Adjusting down to raw dataset size."
+             )
+             self.size = -1
+
+         if self.randomize:
+             rdm_seed = np.random.default_rng(2023)
+             shuffled_indices = rdm_seed.permutation(self.data.shape[0])
+             self.data = self.data[shuffled_indices]
+             self.targets = self.targets[shuffled_indices]
+
+         if not self.balance and self.num_classes > self.size:
+             if self.size > 0:
+                 self.data = subselect(self.data, self.size, self.from_back)
+                 self.targets = subselect(self.targets, self.size, self.from_back)
+         else:
+             label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+             min_label_count = min(len(indices) for indices in label_dict.values())
+
+             self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+
+             if self.per_class_count > min_label_count:
+                 self.per_class_count = min_label_count
+                 if not self.balance and self.verbose:
+                     warn(
+                         f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                         f"will be returned, instead of the desired {self.size}."
+                     )
+
+             all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
+             for i, label in enumerate(self.class_set):
+                 all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
+             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
+
+         if self.unit_interval:
+             self.data = self.data / 255
+
+         if self.normalize:
+             self.data = (self.data - self.normalize[0]) / self.normalize[1]
+
+         if self.dtype:
+             self.data = self.data.astype(self.dtype)
+
+         if self.channels == "channels_first":
+             self.data = self.data[:, np.newaxis, :, :]
+         elif self.channels == "channels_last":
+             self.data = self.data[:, :, :, np.newaxis]
+
+         if self.flatten and self.channels is None:
+             self.data = self.data.reshape(self.data.shape[0], -1)
+
+     def __getitem__(self, index: int) -> tuple[NDArray, int]:
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (image, target) where target is index of the target class.
+         """
+         img, target = self.data[index], int(self.targets[index])
+
+         return img, target
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     @property
+     def mnist_folder(self) -> str:
+         if self.corruption is None:
+             return os.path.join(self.root, "mnist")
+         return os.path.join(self.root, "mnist_c", self.corruption)
+
+     def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
+         with np.load(path, allow_pickle=True) as f:
+             if self.train:
+                 x, y = f["x_train"], f["y_train"]
+             else:
+                 x, y = f["x_test"], f["y_test"]
+         return x, y
+
+     def _read_corrupt_file(self, path: str) -> NDArray:
+         x = np.load(path, allow_pickle=False)
+         return x
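A hedged usage sketch of the new `MNIST` wrapper, based only on the signature shown above. The import uses the internal module path from this diff; the public re-export (presumably under `dataeval.utils.torch.datasets`) may differ, and the `root` path and corruption choice are arbitrary examples:

```python
# Sketch: load a corrupted MNIST training split with channel-first images.
from dataeval._internal.datasets import MNIST

ds = MNIST(
    root="./data",              # arbitrary example path
    train=True,
    download=True,              # fetches mnist_c.zip on first use
    corruption="fog",           # one of the CorruptionStringMap literals
    channels="channels_first",  # images become (N, 1, 28, 28)
    size=1000,
)
img, label = ds[0]
print(len(ds), img.shape, label)
```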
@@ -16,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
  @dataclass(frozen=True)
  class ClustererOutput(OutputMetadata):
      """
+     Output class for :class:`Clusterer` lint detector
+
      Attributes
      ----------
      outliers : List[int]
@@ -16,14 +16,14 @@ from typing import Callable, Literal
  import numpy as np
  from numpy.typing import ArrayLike, NDArray

- from dataeval._internal.interop import to_numpy
+ from dataeval._internal.interop import as_numpy, to_numpy
  from dataeval._internal.output import OutputMetadata, set_metadata


  @dataclass(frozen=True)
  class DriftBaseOutput(OutputMetadata):
      """
-     Output class for Drift
+     Base output class for Drift detector classes

      Attributes
      ----------
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
  @dataclass(frozen=True)
  class DriftOutput(DriftBaseOutput):
      """
-     Output class for DriftCVM and DriftKS
+     Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors

      Attributes
      ----------
@@ -234,7 +234,7 @@ class BaseDrift:
          if correction not in ["bonferroni", "fdr"]:
              raise ValueError("`correction` must be `bonferroni` or `fdr`.")

-         self._x_ref = x_ref
+         self._x_ref = to_numpy(x_ref)
          self.x_ref_preprocessed = x_ref_preprocessed

          # Other attributes
@@ -242,7 +242,7 @@
          self.update_x_ref = update_x_ref
          self.preprocess_fn = preprocess_fn
          self.correction = correction
-         self.n = len(self._x_ref)  # type: ignore
+         self.n = len(self._x_ref)

          # Ref counter for preprocessed x
          self._x_refcount = 0
@@ -260,9 +260,8 @@
          if not self.x_ref_preprocessed:
              self.x_ref_preprocessed = True
              if self.preprocess_fn is not None:
-                 self._x_ref = self.preprocess_fn(self._x_ref)
+                 self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))

-         self._x_ref = to_numpy(self._x_ref)
          return self._x_ref

      def _preprocess(self, x: ArrayLike) -> ArrayLike:
@@ -380,7 +379,7 @@ class BaseDriftUnivariate(BaseDrift):
              self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
          else:
              # infer number of features after applying preprocessing step
-             x = to_numpy(self.preprocess_fn(self._x_ref[0:1]))  # type: ignore
+             x = as_numpy(self.preprocess_fn(self._x_ref[0:1]))  # type: ignore
              self._n_features = x.reshape(x.shape[0], -1).shape[-1]

          return self._n_features
@@ -14,7 +14,7 @@ from typing import Callable
  import torch
  from numpy.typing import ArrayLike

- from dataeval._internal.interop import to_numpy
+ from dataeval._internal.interop import as_numpy
  from dataeval._internal.output import set_metadata

  from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
@@ -24,7 +24,7 @@ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
  @dataclass(frozen=True)
  class DriftMMDOutput(DriftBaseOutput):
      """
-     Output class for DriftMMD
+     Output class for :class:`DriftMMD` drift detector

      Attributes
      ----------
@@ -110,7 +110,7 @@ class DriftMMD(BaseDrift):
          self.device = get_device(device)

          # initialize kernel
-         sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
+         sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
          self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel

          # compute kernel matrix for the reference data
@@ -147,7 +147,7 @@
              p-value obtained from the permutation test, MMD^2 between the reference and test set,
              and MMD^2 threshold above which drift is flagged
          """
-         x = to_numpy(x)
+         x = as_numpy(x)
          x_ref = torch.from_numpy(self.x_ref).to(self.device)
          n = x.shape[0]
          kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))