dataeval 0.69.4__tar.gz → 0.70.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dataeval-0.69.4 → dataeval-0.70.1}/PKG-INFO +12 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/README.md +9 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/pyproject.toml +19 -6
- dataeval-0.70.1/src/dataeval/__init__.py +22 -0
- dataeval-0.70.1/src/dataeval/_internal/datasets.py +404 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/clusterer.py +2 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/base.py +7 -8
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/mmd.py +4 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/duplicates.py +64 -45
- dataeval-0.70.1/src/dataeval/_internal/detectors/merged_stats.py +47 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/ae.py +8 -6
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/aegmm.py +6 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/base.py +12 -7
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/llr.py +6 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vae.py +5 -3
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/vaegmm.py +6 -4
- dataeval-0.70.1/src/dataeval/_internal/detectors/outliers.py +271 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/interop.py +11 -7
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/balance.py +13 -11
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/ber.py +5 -3
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/coverage.py +4 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/divergence.py +9 -5
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/diversity.py +14 -12
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/parity.py +32 -22
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/base.py +231 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/boxratiostats.py +159 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/datasetstats.py +99 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/dimensionstats.py +113 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/hashstats.py +75 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/labelstats.py +125 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/pixelstats.py +119 -0
- dataeval-0.70.1/src/dataeval/_internal/metrics/stats/visualstats.py +124 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/uap.py +8 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/utils.py +30 -15
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/autoencoder.py +5 -5
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/output.py +3 -18
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/utils.py +11 -16
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/workflows/sufficiency.py +152 -151
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/__init__.py +4 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/__init__.py +8 -3
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/kernels/__init__.py +4 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/drift/updates/__init__.py +4 -0
- dataeval-0.70.1/src/dataeval/detectors/linters/__init__.py +16 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/detectors/ood/__init__.py +14 -2
- dataeval-0.70.1/src/dataeval/metrics/__init__.py +8 -0
- dataeval-0.70.1/src/dataeval/metrics/bias/__init__.py +21 -0
- dataeval-0.70.1/src/dataeval/metrics/estimators/__init__.py +9 -0
- dataeval-0.70.1/src/dataeval/metrics/stats/__init__.py +28 -0
- dataeval-0.70.1/src/dataeval/utils/__init__.py +19 -0
- dataeval-0.70.1/src/dataeval/utils/tensorflow/__init__.py +11 -0
- dataeval-0.70.1/src/dataeval/utils/torch/__init__.py +12 -0
- dataeval-0.70.1/src/dataeval/utils/torch/datasets/__init__.py +7 -0
- dataeval-0.70.1/src/dataeval/workflows/__init__.py +10 -0
- dataeval-0.69.4/src/dataeval/__init__.py +0 -22
- dataeval-0.69.4/src/dataeval/_internal/datasets.py +0 -300
- dataeval-0.69.4/src/dataeval/_internal/detectors/merged_stats.py +0 -78
- dataeval-0.69.4/src/dataeval/_internal/detectors/outliers.py +0 -197
- dataeval-0.69.4/src/dataeval/_internal/flags.py +0 -77
- dataeval-0.69.4/src/dataeval/_internal/metrics/stats.py +0 -397
- dataeval-0.69.4/src/dataeval/detectors/linters/__init__.py +0 -5
- dataeval-0.69.4/src/dataeval/flags/__init__.py +0 -3
- dataeval-0.69.4/src/dataeval/metrics/__init__.py +0 -3
- dataeval-0.69.4/src/dataeval/metrics/bias/__init__.py +0 -12
- dataeval-0.69.4/src/dataeval/metrics/estimators/__init__.py +0 -9
- dataeval-0.69.4/src/dataeval/metrics/stats/__init__.py +0 -6
- dataeval-0.69.4/src/dataeval/tensorflow/__init__.py +0 -3
- dataeval-0.69.4/src/dataeval/torch/__init__.py +0 -3
- dataeval-0.69.4/src/dataeval/utils/__init__.py +0 -6
- dataeval-0.69.4/src/dataeval/workflows/__init__.py +0 -6
- {dataeval-0.69.4 → dataeval-0.70.1}/LICENSE.txt +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/cvm.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/ks.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/torch.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/drift/uncertainty.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/detectors/ood/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/metrics/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/blocks.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/pytorch/utils.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/autoencoder.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/gmm.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/losses.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/trainer.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/models/tensorflow/utils.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/_internal/workflows/__init__.py +0 -0
- {dataeval-0.69.4 → dataeval-0.70.1}/src/dataeval/py.typed +0 -0
- {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/loss/__init__.py +0 -0
- {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/models/__init__.py +0 -0
- {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/tensorflow/recon/__init__.py +0 -0
- {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/models/__init__.py +0 -0
- {dataeval-0.69.4/src/dataeval → dataeval-0.70.1/src/dataeval/utils}/torch/trainer/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.70.1
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -30,10 +30,9 @@ Requires-Dist: pillow (>=10.3.0)
|
|
30
30
|
Requires-Dist: scikit-learn (>=1.5.0)
|
31
31
|
Requires-Dist: scipy (>=1.10)
|
32
32
|
Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "all"
|
33
|
-
Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
|
34
33
|
Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
|
35
|
-
Requires-Dist: torch (>=2.
|
36
|
-
Requires-Dist: torchvision (>=0.
|
34
|
+
Requires-Dist: torch (>=2.2.0) ; extra == "torch" or extra == "all"
|
35
|
+
Requires-Dist: torchvision (>=0.17.0) ; extra == "torch" or extra == "all"
|
37
36
|
Requires-Dist: xxhash (>=3.3)
|
38
37
|
Project-URL: Documentation, https://dataeval.readthedocs.io/
|
39
38
|
Project-URL: Repository, https://github.com/aria-ml/dataeval/
|
@@ -75,6 +74,15 @@ You can install DataEval directly from pypi.org using the following command. Th
|
|
75
74
|
pip install dataeval[all]
|
76
75
|
```
|
77
76
|
|
77
|
+
### Installing DataEval in Conda/Mamba
|
78
|
+
|
79
|
+
DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
|
80
|
+
are installed from the `pytorch` channel, that channel is passed explicitly in the example below.
|
81
|
+
|
82
|
+
```
|
83
|
+
micromamba create -f environment/environment.yaml -c pytorch
|
84
|
+
```
|
85
|
+
|
78
86
|
### Installing DataEval from GitHub
|
79
87
|
|
80
88
|
To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
|
@@ -34,6 +34,15 @@ You can install DataEval directly from pypi.org using the following command. Th
|
|
34
34
|
pip install dataeval[all]
|
35
35
|
```
|
36
36
|
|
37
|
+
### Installing DataEval in Conda/Mamba
|
38
|
+
|
39
|
+
DataEval can be installed in a Conda/Mamba environment using the provided `environment.yaml` file. As some dependencies
|
40
|
+
are installed from the `pytorch` channel, that channel is passed explicitly in the example below.
|
41
|
+
|
42
|
+
```
|
43
|
+
micromamba create -f environment/environment.yaml -c pytorch
|
44
|
+
```
|
45
|
+
|
37
46
|
### Installing DataEval from GitHub
|
38
47
|
|
39
48
|
To install DataEval from source locally on Ubuntu, you will need `git-lfs` to download larger, binary source files and `poetry` for project dependency management.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[tool.poetry]
|
2
2
|
name = "dataeval"
|
3
|
-
version = "0.
|
3
|
+
version = "0.70.1" # dynamic
|
4
4
|
description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
|
5
5
|
license = "MIT"
|
6
6
|
readme = "README.md"
|
@@ -52,15 +52,14 @@ xxhash = {version = ">=3.3"}
|
|
52
52
|
matplotlib = {version = "*", optional = true}
|
53
53
|
nvidia-cudnn-cu11 = {version = ">=8.6.0.163", optional = true}
|
54
54
|
tensorflow = {version = ">=2.14.1, <2.16", optional = true}
|
55
|
-
tensorflow-io-gcs-filesystem = {version = ">=0.35.0, <0.37", optional = true}
|
56
55
|
tensorflow_probability = {version = ">=0.22.1, <0.24", optional = true}
|
57
|
-
torch = {version = ">=2.
|
58
|
-
torchvision = {version = ">=0.
|
56
|
+
torch = {version = ">=2.2.0", source = "pytorch", optional = true}
|
57
|
+
torchvision = {version = ">=0.17.0", source = "pytorch", optional = true}
|
59
58
|
|
60
59
|
[tool.poetry.extras]
|
61
|
-
tensorflow = ["tensorflow", "
|
60
|
+
tensorflow = ["tensorflow", "tensorflow_probability", "nvidia-cudnn-cu11"]
|
62
61
|
torch = ["torch", "torchvision", "matplotlib", "nvidia-cudnn-cu11"]
|
63
|
-
all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "
|
62
|
+
all = ["matplotlib", "nvidia-cudnn-cu11", "tensorflow", "tensorflow_probability", "torch", "torchvision"]
|
64
63
|
|
65
64
|
[tool.poetry.group.dev]
|
66
65
|
optional = true
|
@@ -71,6 +70,7 @@ tox-uv = {version = "*"}
|
|
71
70
|
uv = {version = "*"}
|
72
71
|
poetry = {version = "*"}
|
73
72
|
poetry-lock-groups-plugin = {version = "*"}
|
73
|
+
poetry2conda = {version = "*"}
|
74
74
|
# lint
|
75
75
|
ruff = {version = "*"}
|
76
76
|
codespell = {version = "*", extras = ["toml"]}
|
@@ -113,6 +113,18 @@ pattern = "v(?P<base>\\d+\\.\\d+\\.\\d+)$"
|
|
113
113
|
[tool.poetry-dynamic-versioning.substitution]
|
114
114
|
files = ["src/dataeval/__init__.py"]
|
115
115
|
|
116
|
+
[tool.poetry2conda]
|
117
|
+
name = "dataeval"
|
118
|
+
|
119
|
+
[tool.poetry2conda.dependencies]
|
120
|
+
nvidia-cudnn-cu11 = { name = "cudnn" }
|
121
|
+
pillow = { channel = "pip" }
|
122
|
+
tensorflow = { channel = "pip" }
|
123
|
+
tensorflow_probability = { channel = "pip" }
|
124
|
+
torch = { name = "pytorch", channel = "pytorch" }
|
125
|
+
torchvision = { channel = "pytorch" }
|
126
|
+
xxhash = { name = "python-xxhash" }
|
127
|
+
|
116
128
|
[tool.pyright]
|
117
129
|
reportMissingImports = false
|
118
130
|
|
@@ -173,6 +185,7 @@ docstring-code-line-length = "dynamic"
|
|
173
185
|
|
174
186
|
[tool.codespell]
|
175
187
|
skip = './*env*,./prototype,./output,./docs/_build,./docs/.jupyter_cache,CHANGELOG.md,poetry.lock,*.html'
|
188
|
+
ignore-words-list = ["Hart"]
|
176
189
|
|
177
190
|
[build-system]
|
178
191
|
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
|
@@ -0,0 +1,22 @@
|
|
1
|
+
__version__ = "0.70.1"

from importlib.util import find_spec as _find_spec

# Probe for the optional backends without importing them; both TensorFlow packages are required.
_IS_TORCH_AVAILABLE = _find_spec("torch") is not None
_IS_TENSORFLOW_AVAILABLE = all(_find_spec(pkg) is not None for pkg in ("tensorflow", "tensorflow_probability"))

# Keep the helper out of the public package namespace.
del _find_spec

from . import detectors, metrics  # noqa: E402

__all__ = ["detectors", "metrics"]

# Workflows depend on PyTorch.
if _IS_TORCH_AVAILABLE:  # pragma: no cover
    from . import workflows  # noqa: E402

    __all__.append("workflows")

# Utilities are exposed when at least one backend is installed.
if _IS_TORCH_AVAILABLE or _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
    from . import utils  # noqa: E402

    __all__.append("utils")
|
@@ -0,0 +1,404 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import hashlib
|
4
|
+
import os
|
5
|
+
import zipfile
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Literal, TypeVar
|
8
|
+
from warnings import warn
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import requests
|
12
|
+
from numpy.typing import NDArray
|
13
|
+
from torch.utils.data import Dataset
|
14
|
+
from torchvision.datasets import CIFAR10, VOCDetection # noqa: F401
|
15
|
+
|
16
|
+
ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
17
|
+
TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
|
18
|
+
CorruptionStringMap = Literal[
|
19
|
+
"identity",
|
20
|
+
"shot_noise",
|
21
|
+
"impulse_noise",
|
22
|
+
"glass_blur",
|
23
|
+
"motion_blur",
|
24
|
+
"shear",
|
25
|
+
"scale",
|
26
|
+
"rotate",
|
27
|
+
"brightness",
|
28
|
+
"translate",
|
29
|
+
"stripe",
|
30
|
+
"fog",
|
31
|
+
"spatter",
|
32
|
+
"dotted_line",
|
33
|
+
"zigzag",
|
34
|
+
"canny_edges",
|
35
|
+
]
|
36
|
+
|
37
|
+
|
38
|
+
def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
    """Check whether the digest of the file at ``fpath`` equals ``file_md5``.

    Despite its name, ``file_md5`` is compared against a SHA-256 hex digest unless
    ``md5`` is True, in which case an MD5 digest is computed instead. The file is
    read in binary chunks of ``chunk_size`` bytes so large files are never fully
    loaded into memory.

    Returns True when the computed hex digest matches, False otherwise.
    """
    digest = hashlib.sha256() if not md5 else hashlib.md5()
    with open(fpath, "rb") as stream:
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    computed = digest.hexdigest()
    return computed == file_md5
|
44
|
+
|
45
|
+
|
46
|
+
def _get_file(
    root: str | Path,
    fname: str,
    origin: str,
    file_hash: str | None = None,
    verbose: bool = True,
    md5: bool = False,
):
    """Download ``origin`` to ``root/fname`` unless a verified copy is already present.

    Parameters
    ----------
    root : str | Path
        Directory the file is written to.
    fname : str
        File name inside ``root``.
    origin : str
        URL the file is fetched from.
    file_hash : str | None, default None
        Expected hex digest. When None, an existing file is never trusted (it is
        re-downloaded) and the result is not verified.
    verbose : bool, default True
        If True, print status messages.
    md5 : bool, default False
        If True, ``file_hash`` is an MD5 digest; otherwise it is SHA-256.

    Returns
    -------
    str
        Path of the local file.

    Raises
    ------
    Exception
        If the HTTP request fails; the message names the URL, the status code (or
        "Unknown error") and the reason/detail.
    ValueError
        If the file on disk does not match ``file_hash``.
    """
    fpath = os.path.join(root, fname)
    download = True
    if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
        download = False
        if verbose:
            print("File already downloaded and verified.")
            if md5:
                print("Extracting zip file...")

    if download:
        # Three placeholders: URL, status code (or error category), reason/detail.
        # Every call must supply all three arguments or str.format raises IndexError.
        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                with requests.get(origin, stream=True, timeout=60) as r:
                    r.raise_for_status()
                    with open(fpath, "wb") as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)
            except requests.exceptions.HTTPError as e:
                raise Exception(error_msg.format(origin, e.response.status_code, e.response.reason)) from e
            except requests.exceptions.RequestException as e:
                raise Exception(error_msg.format(origin, "Unknown error", str(e))) from e
        except (Exception, KeyboardInterrupt):
            # Never leave a partially written file behind, then propagate the error.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
        raise ValueError(
            "Incomplete or corrupted file detected. "
            f"The file hash does not match the provided value "
            f"of {file_hash}.",
        )

    return fpath
|
90
|
+
|
91
|
+
|
92
|
+
def check_exists(
    folder: str | Path,
    url: str,
    root: str | Path,
    fname: str,
    file_hash: str,
    download: bool = True,
    verbose: bool = True,
    md5: bool = False,
):
    """Determine if the dataset has already been downloaded.

    Only the existence of ``folder`` is checked; its contents are not re-hashed.
    If it is missing and ``download`` is True, the dataset is fetched with
    ``download_dataset`` and the location that function reports is returned.
    Otherwise ``str(folder)`` is returned.

    Raises
    ------
    RuntimeError
        If ``folder`` does not exist and ``download`` is False.
    """
    if os.path.exists(folder):
        if verbose:
            print("Files already downloaded and verified")
        return str(folder)

    if not download:
        raise RuntimeError("Dataset not found. You can use download=True to download it")

    return download_dataset(url, root, fname, file_hash, verbose, md5)
|
113
|
+
|
114
|
+
|
115
|
+
def download_dataset(
    url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
) -> str:
    """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py

    The file ``url + fname`` is stored in ``root/<fname without extension>/``.
    When ``md5`` is True the file is treated as a ZIP archive: it is extracted into
    ``root``, deleted afterwards, and the extraction directory is returned.
    Otherwise the download directory is returned.
    """
    stem = os.path.splitext(fname)[0]
    target_dir = os.path.join(root, stem)
    os.makedirs(target_dir, exist_ok=True)

    downloaded = _get_file(
        target_dir,
        fname,
        origin=url + fname,
        file_hash=file_hash,
        verbose=verbose,
        md5=md5,
    )
    if not md5:
        return target_dir
    return extract_archive(downloaded, root, remove_finished=True)
|
136
|
+
|
137
|
+
|
138
|
+
def extract_archive(
    from_path: str | Path,
    to_path: str | Path | None = None,
    remove_finished: bool = False,
) -> str:
    """Extract a ZIP archive and return the absolute extraction directory.

    Only ZIP archives are handled; the archive type is not inferred from the name.

    Parameters
    ----------
    from_path : str | Path
        Archive to extract; relative paths are resolved against the working directory.
    to_path : str | Path | None, default None
        Destination directory. If None or not an existing path, the archive's own
        directory is used instead.
    remove_finished : bool, default False
        If True, delete the archive after extraction.
    """
    archive = Path(from_path)
    if not archive.is_absolute():
        archive = archive.resolve()

    # A missing destination silently falls back to the archive's directory.
    use_given = to_path is not None and os.path.exists(to_path)
    destination = Path(to_path) if use_given else Path(os.path.dirname(archive))
    if not destination.is_absolute():
        destination = destination.resolve()

    with zipfile.ZipFile(archive, "r", compression=zipfile.ZIP_STORED) as zf:
        zf.extractall(destination)

    if remove_finished:
        os.remove(archive)
    return str(destination)
|
164
|
+
|
165
|
+
|
166
|
+
def subselect(arr: NDArray, count: int, from_back: bool = False):
    """Return the first ``count`` elements of ``arr``, or the last ``count`` if ``from_back``.

    A ``count`` of 0 yields an empty selection in both directions. Without the
    guard, ``arr[-0:]`` would return the whole array.
    """
    if from_back:
        if count == 0:
            return arr[:0]
        return arr[-count:]
    return arr[:count]
|
170
|
+
|
171
|
+
|
172
|
+
class MNIST(Dataset):
    """MNIST Dataset and Corruptions.

    Args:
        root : str | ``pathlib.Path``
            Root directory of dataset where the ``mnist_c/`` folder exists.
        train : bool, default True
            If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``.
        download : bool, default False
            If True, downloads the dataset from the internet and puts it in root
            directory. If dataset is already downloaded, it is not downloaded again.
        size : int, default -1
            Limit the dataset size. Values less than or equal to 0 mean no limit.
        unit_interval : bool, default False
            Shift the data values to the unit interval [0-1].
        dtype : type | None, default None
            Change the numpy dtype - data is loaded as np.uint8
        channels : Literal['channels_first' | 'channels_last'] | None, default None
            Location of channel axis if desired, default has no channels (N, 28, 28)
        flatten : bool, default False
            Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
            channels takes priority over flatten.
        normalize : tuple[mean, std] | None, default None
            Normalize images according to provided mean and standard deviation
        corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
            'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
            'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
            The desired corruption style or None.
        classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
            | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
            "eight", "nine"]] | None, default None
            Option to select specific classes from dataset. Integers outside 0-9 are ignored.
            The selection applies whether or not ``balance`` is set.
        balance : bool, default True
            If True, returns equal number of samples for each class.
        randomize : bool, default False
            If True, shuffles the data prior to selection - uses a set seed for reproducibility.
        slice_back : bool, default False
            If True and size has a value greater than 0, then grabs selection starting at the last image.
        verbose : bool, default True
            If True, outputs print statements.
    """

    mirror = [
        "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
        "https://zenodo.org/record/3239543/files/",
    ]

    # (file name, expected digest): SHA-256 for the plain archive, MD5 for the corruptions zip.
    resources = [
        ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
        ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
    ]

    class_dict = {
        "zero": 0,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
    }

    def __init__(
        self,
        root: str | Path,
        train: bool = True,
        download: bool = False,
        size: int = -1,
        unit_interval: bool = False,
        dtype: type | None = None,
        channels: Literal["channels_first", "channels_last"] | None = None,
        flatten: bool = False,
        normalize: tuple[float, float] | None = None,
        corruption: CorruptionStringMap | None = None,
        classes: TClassMap | None = None,
        balance: bool = True,
        randomize: bool = False,
        slice_back: bool = False,
        verbose: bool = True,
    ) -> None:
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root  # location of stored dataset
        self.train = train  # training set or test set
        self.size = size
        self.unit_interval = unit_interval
        self.dtype = dtype
        self.channels = channels
        self.flatten = flatten
        self.normalize = normalize
        self.corruption = corruption
        self.balance = balance
        self.randomize = randomize
        self.from_back = slice_back
        self.verbose = verbose

        # Translate the ``classes`` argument (name, int, or list of either) into a set of label ints.
        self.class_set = []
        if classes is not None:
            if not isinstance(classes, list):
                classes = [classes]  # type: ignore

            for val in classes:  # type: ignore
                if isinstance(val, int) and 0 <= val < 10:
                    self.class_set.append(val)
                elif isinstance(val, str):
                    self.class_set.append(self.class_dict[val])
        self.class_set = set(self.class_set)

        # No (valid) selection means all ten classes.
        if not self.class_set:
            self.class_set = set(self.class_dict.values())

        self.num_classes = len(self.class_set)

        if self.corruption is None:
            file_resource = self.resources[0]
            mirror = self.mirror[0]
            md5 = False
        else:
            if self.corruption == "identity" and verbose:
                print("Identity is not a corrupted dataset but the original MNIST dataset.")
            file_resource = self.resources[1]
            mirror = self.mirror[1]
            md5 = True
        check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)

        self.data, self.targets = self._load_data()

        self._augmentations()

    def _load_data(self):
        """Read images and labels for the selected split from the plain or corrupted files."""
        if self.corruption is None:
            image_file = self.resources[0][0]
            data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
        else:
            image_file = f"{'train' if self.train else 'test'}_images.npy"
            data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
            data = data.squeeze()

            label_file = f"{'train' if self.train else 'test'}_labels.npy"
            targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))

        return data, targets

    def _augmentations(self):
        """Apply size limits, shuffling, class selection and value/shape transforms in place."""
        n_available = self.targets.shape[0]
        if self.size > n_available:
            if self.verbose:
                warn(
                    f"Asked for more samples, {self.size}, than the raw dataset contains, {n_available}. "
                    "Adjusting down to raw dataset size."
                )
            # Clamp regardless of verbosity; only the warning depends on ``verbose``.
            self.size = -1

        if self.randomize:
            rdm_seed = np.random.default_rng(2023)
            shuffled_indices = rdm_seed.permutation(self.data.shape[0])
            self.data = self.data[shuffled_indices]
            self.targets = self.targets[shuffled_indices]

        if not self.balance and self.num_classes > self.size:
            # Unbalanced selection still honours ``classes``: drop samples outside
            # ``class_set`` before taking the first (or last) ``size`` samples.
            keep = np.isin(self.targets, list(self.class_set))
            if not keep.all():
                self.data = self.data[keep]
                self.targets = self.targets[keep]
            if self.size > 0:
                self.data = subselect(self.data, self.size, self.from_back)
                self.targets = subselect(self.targets, self.size, self.from_back)
        else:
            label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
            min_label_count = min(len(indices) for indices in label_dict.values())

            self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count

            if self.per_class_count > min_label_count:
                self.per_class_count = min_label_count
                if not self.balance and self.verbose:
                    warn(
                        f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
                        f"will be returned, instead of the desired {self.size}."
                    )

            # Row i holds the selected indices of the i-th class. Transposing before
            # indexing interleaves the classes in the stacked result.
            all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
            for i, label in enumerate(self.class_set):
                all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
            self.data = np.vstack(self.data[all_indices.T])  # type: ignore
            self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore

        if self.unit_interval:
            self.data = self.data / 255

        if self.normalize:
            self.data = (self.data - self.normalize[0]) / self.normalize[1]

        if self.dtype:
            self.data = self.data.astype(self.dtype)

        if self.channels == "channels_first":
            self.data = self.data[:, np.newaxis, :, :]
        elif self.channels == "channels_last":
            self.data = self.data[:, :, :, np.newaxis]

        if self.flatten and self.channels is None:
            self.data = self.data.reshape(self.data.shape[0], -1)

    def __getitem__(self, index: int) -> tuple[NDArray, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def mnist_folder(self) -> str:
        """Directory holding the plain MNIST archive or the chosen corruption's files."""
        if self.corruption is None:
            return os.path.join(self.root, "mnist")
        return os.path.join(self.root, "mnist_c", self.corruption)

    def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
        # The archive is read only for plain numeric arrays, so pickle support is
        # unnecessary. Keeping it disabled prevents code execution from a tampered
        # file; check_exists does not re-hash an existing folder.
        with np.load(path, allow_pickle=False) as f:
            if self.train:
                x, y = f["x_train"], f["y_train"]
            else:
                x, y = f["x_test"], f["y_test"]
        return x, y

    def _read_corrupt_file(self, path: str) -> NDArray:
        x = np.load(path, allow_pickle=False)
        return x
|
@@ -16,14 +16,14 @@ from typing import Callable, Literal
|
|
16
16
|
import numpy as np
|
17
17
|
from numpy.typing import ArrayLike, NDArray
|
18
18
|
|
19
|
-
from dataeval._internal.interop import to_numpy
|
19
|
+
from dataeval._internal.interop import as_numpy, to_numpy
|
20
20
|
from dataeval._internal.output import OutputMetadata, set_metadata
|
21
21
|
|
22
22
|
|
23
23
|
@dataclass(frozen=True)
|
24
24
|
class DriftBaseOutput(OutputMetadata):
|
25
25
|
"""
|
26
|
-
|
26
|
+
Base output class for Drift detector classes
|
27
27
|
|
28
28
|
Attributes
|
29
29
|
----------
|
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
|
|
42
42
|
@dataclass(frozen=True)
|
43
43
|
class DriftOutput(DriftBaseOutput):
|
44
44
|
"""
|
45
|
-
Output class for DriftCVM and
|
45
|
+
Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors
|
46
46
|
|
47
47
|
Attributes
|
48
48
|
----------
|
@@ -234,7 +234,7 @@ class BaseDrift:
|
|
234
234
|
if correction not in ["bonferroni", "fdr"]:
|
235
235
|
raise ValueError("`correction` must be `bonferroni` or `fdr`.")
|
236
236
|
|
237
|
-
self._x_ref = x_ref
|
237
|
+
self._x_ref = to_numpy(x_ref)
|
238
238
|
self.x_ref_preprocessed = x_ref_preprocessed
|
239
239
|
|
240
240
|
# Other attributes
|
@@ -242,7 +242,7 @@ class BaseDrift:
|
|
242
242
|
self.update_x_ref = update_x_ref
|
243
243
|
self.preprocess_fn = preprocess_fn
|
244
244
|
self.correction = correction
|
245
|
-
self.n = len(self._x_ref)
|
245
|
+
self.n = len(self._x_ref)
|
246
246
|
|
247
247
|
# Ref counter for preprocessed x
|
248
248
|
self._x_refcount = 0
|
@@ -260,9 +260,8 @@ class BaseDrift:
|
|
260
260
|
if not self.x_ref_preprocessed:
|
261
261
|
self.x_ref_preprocessed = True
|
262
262
|
if self.preprocess_fn is not None:
|
263
|
-
self._x_ref = self.preprocess_fn(self._x_ref)
|
263
|
+
self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))
|
264
264
|
|
265
|
-
self._x_ref = to_numpy(self._x_ref)
|
266
265
|
return self._x_ref
|
267
266
|
|
268
267
|
def _preprocess(self, x: ArrayLike) -> ArrayLike:
|
@@ -380,7 +379,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
380
379
|
self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
|
381
380
|
else:
|
382
381
|
# infer number of features after applying preprocessing step
|
383
|
-
x =
|
382
|
+
x = as_numpy(self.preprocess_fn(self._x_ref[0:1])) # type: ignore
|
384
383
|
self._n_features = x.reshape(x.shape[0], -1).shape[-1]
|
385
384
|
|
386
385
|
return self._n_features
|
@@ -14,7 +14,7 @@ from typing import Callable
|
|
14
14
|
import torch
|
15
15
|
from numpy.typing import ArrayLike
|
16
16
|
|
17
|
-
from dataeval._internal.interop import
|
17
|
+
from dataeval._internal.interop import as_numpy
|
18
18
|
from dataeval._internal.output import set_metadata
|
19
19
|
|
20
20
|
from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
|
@@ -24,7 +24,7 @@ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
|
|
24
24
|
@dataclass(frozen=True)
|
25
25
|
class DriftMMDOutput(DriftBaseOutput):
|
26
26
|
"""
|
27
|
-
Output class for DriftMMD
|
27
|
+
Output class for :class:`DriftMMD` drift detector
|
28
28
|
|
29
29
|
Attributes
|
30
30
|
----------
|
@@ -110,7 +110,7 @@ class DriftMMD(BaseDrift):
|
|
110
110
|
self.device = get_device(device)
|
111
111
|
|
112
112
|
# initialize kernel
|
113
|
-
sigma_tensor = torch.from_numpy(
|
113
|
+
sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
|
114
114
|
self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel
|
115
115
|
|
116
116
|
# compute kernel matrix for the reference data
|
@@ -147,7 +147,7 @@ class DriftMMD(BaseDrift):
|
|
147
147
|
p-value obtained from the permutation test, MMD^2 between the reference and test set,
|
148
148
|
and MMD^2 threshold above which drift is flagged
|
149
149
|
"""
|
150
|
-
x =
|
150
|
+
x = as_numpy(x)
|
151
151
|
x_ref = torch.from_numpy(self.x_ref).to(self.device)
|
152
152
|
n = x.shape[0]
|
153
153
|
kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
|