dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +23 -14
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
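
Most of the renames above follow one pattern: public module files gain a leading underscore (for example `base.py → _base.py`), signaling that the concrete module paths are now private. A hedged migration sketch follows, assuming the public symbols are re-exported from the package namespaces in 0.81.0 (verify against the 0.81.0 API reference; `DriftKS` is used here only as an illustrative symbol):

```python
# 0.76.1: importing from the concrete module worked
# from dataeval.detectors.drift.ks import DriftKS

# 0.81.0: ks.py is now _ks.py, so import from the package namespace instead
from dataeval.detectors.drift import DriftKS  # assumes re-export; check the 0.81.0 API docs
```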
dataeval/interop.py
DELETED
@@ -1,69 +0,0 @@
-"""Utility functions for interoperability with different array types."""
-
-from __future__ import annotations
-
-__all__ = []
-
-import logging
-from importlib import import_module
-from types import ModuleType
-from typing import Any, Iterable, Iterator
-
-import numpy as np
-from numpy.typing import ArrayLike, NDArray
-
-from dataeval.log import LogMessage
-
-_logger = logging.getLogger(__name__)
-
-_MODULE_CACHE = {}
-
-
-def _try_import(module_name) -> ModuleType | None:
-    if module_name in _MODULE_CACHE:
-        return _MODULE_CACHE[module_name]
-
-    try:
-        module = import_module(module_name)
-    except ImportError:  # pragma: no cover - covered by test_mindeps.py
-        _logger.log(logging.INFO, f"Unable to import {module_name}.")
-        module = None
-
-    _MODULE_CACHE[module_name] = module
-    return module
-
-
-def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
-    """Converts an ArrayLike to Numpy array without copying (if possible)"""
-    return to_numpy(array, copy=False)
-
-
-def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
-    """Converts an ArrayLike to new Numpy array"""
-    if array is None:
-        return np.ndarray([])
-
-    if isinstance(array, np.ndarray):
-        return array.copy() if copy else array
-
-    if array.__class__.__module__.startswith("tensorflow"):  # pragma: no cover - removed tf from deps
-        tf = _try_import("tensorflow")
-        if tf and tf.is_tensor(array):
-            _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
-            return array.numpy().copy() if copy else array.numpy()  # type: ignore
-
-    if array.__class__.__module__.startswith("torch"):
-        torch = _try_import("torch")
-        if torch and isinstance(array, torch.Tensor):
-            _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
-            numpy = array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
-            _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(numpy)}"))
-            return numpy
-
-    return np.array(array) if copy else np.asarray(array)
-
-
-def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
-    """Yields an iterator of numpy arrays from an ArrayLike"""
-    for array in iterable:
-        yield to_numpy(array)
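
For reference, a minimal usage sketch of the removed helpers exactly as defined in the deleted module above. This runs only against 0.76.1; the 0.81.0 replacement presumably lives in the new dataeval/utils/_array.py added in this release, whose contents are not expanded in this diff:

```python
# Exercises the removed 0.76.1 interop helpers shown in the hunk above.
import numpy as np

from dataeval.interop import as_numpy, to_numpy, to_numpy_iter  # 0.76.1 only

x = to_numpy([1, 2, 3])                          # always returns a new ndarray
y = as_numpy(np.arange(3))                       # returns the same ndarray, no copy
batches = list(to_numpy_iter([[1, 2], [3, 4]]))  # one ndarray per element
print(x, y, [b.shape for b in batches])
```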
dataeval/utils/dataset/__init__.py
DELETED
@@ -1,7 +0,0 @@
-"""Provides utility functions for interacting with Computer Vision datasets."""
-
-__all__ = ["datasets", "read_dataset", "SplitDatasetOutput", "split_dataset"]
-
-from dataeval.utils.dataset import datasets
-from dataeval.utils.dataset.read import read_dataset
-from dataeval.utils.dataset.split import SplitDatasetOutput, split_dataset
dataeval/utils/dataset/datasets.py
DELETED
@@ -1,412 +0,0 @@
-"""Provides access to common Computer Vision datasets."""
-
-from __future__ import annotations
-
-__all__ = ["MNIST", "CIFAR10", "VOCDetection"]
-
-import hashlib
-import os
-import zipfile
-from pathlib import Path
-from typing import Literal, TypeVar
-from warnings import warn
-
-import numpy as np
-import requests
-from numpy.typing import NDArray
-from torch.utils.data import Dataset
-from torchvision.datasets import CIFAR10, VOCDetection
-
-ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
-TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
-CorruptionStringMap = Literal[
-    "identity",
-    "shot_noise",
-    "impulse_noise",
-    "glass_blur",
-    "motion_blur",
-    "shear",
-    "scale",
-    "rotate",
-    "brightness",
-    "translate",
-    "stripe",
-    "fog",
-    "spatter",
-    "dotted_line",
-    "zigzag",
-    "canny_edges",
-]
-
-
-def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
-    hasher = hashlib.md5() if md5 else hashlib.sha256()
-    with open(fpath, "rb") as fpath_file:
-        while chunk := fpath_file.read(chunk_size):
-            hasher.update(chunk)
-    return hasher.hexdigest() == file_md5
-
-
-def _get_file(
-    root: str | Path,
-    fname: str,
-    origin: str,
-    file_hash: str | None = None,
-    verbose: bool = True,
-    md5: bool = False,
-    timeout: int = 60,
-):
-    fpath = os.path.join(root, fname)
-    download = True
-    if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
-        download = False
-        if verbose:
-            print("File already downloaded and verified.")
-            if md5:
-                print("Extracting zip file...")
-
-    if download:
-        try:
-            error_msg = "URL fetch failure on {}: {} -- {}"
-            try:
-                with requests.get(origin, stream=True, timeout=timeout) as r:
-                    r.raise_for_status()
-                    with open(fpath, "wb") as f:
-                        for chunk in r.iter_content(chunk_size=8192):
-                            if chunk:
-                                f.write(chunk)
-            except requests.exceptions.HTTPError as e:
-                raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
-            except requests.exceptions.RequestException as e:
-                raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
-        except (Exception, KeyboardInterrupt):
-            if os.path.exists(fpath):
-                os.remove(fpath)
-            raise
-
-    if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
-        raise ValueError(
-            "Incomplete or corrupted file detected. "
-            f"The file hash does not match the provided value "
-            f"of {file_hash}.",
-        )
-
-    return fpath
-
-
-def _check_exists(
-    folder: str | Path,
-    url: str,
-    root: str | Path,
-    fname: str,
-    file_hash: str,
-    download: bool = True,
-    verbose: bool = True,
-    md5: bool = False,
-):
-    """Determine if the dataset has already been downloaded."""
-    location = str(folder)
-    if not os.path.exists(folder):
-        if download:
-            location = _download_dataset(url, root, fname, file_hash, verbose, md5)
-        else:
-            raise RuntimeError("Dataset not found. You can use download=True to download it")
-    else:
-        if verbose:
-            print("Files already downloaded and verified")
-    return location
-
-
-def _download_dataset(
-    url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
-) -> str:
-    """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
-    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
-    """
-    name, _ = os.path.splitext(fname)
-    folder = os.path.join(root, name)
-    os.makedirs(folder, exist_ok=True)
-
-    fpath = _get_file(
-        folder,
-        fname,
-        origin=url + fname,
-        file_hash=file_hash,
-        verbose=verbose,
-        md5=md5,
-    )
-    if md5:
-        folder = _extract_archive(fpath, root, remove_finished=True)
-    return folder
-
-
-def _extract_archive(
-    from_path: str | Path,
-    to_path: str | Path | None = None,
-    remove_finished: bool = False,
-) -> str:
-    """Extract an archive.
-
-    The archive type and a possible compression is automatically detected from the file name.
-    """
-    from_path = Path(from_path)
-    if not from_path.is_absolute():
-        from_path = from_path.resolve()
-
-    if to_path is None or not os.path.exists(to_path):
-        to_path = os.path.dirname(from_path)
-    to_path = Path(to_path)
-    if not to_path.is_absolute():
-        to_path = to_path.resolve()
-
-    # Extracting zip
-    with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
-        zzip.extractall(to_path)
-
-    if remove_finished:
-        os.remove(from_path)
-    return str(to_path)
-
-
-def _subselect(arr: NDArray, count: int, from_back: bool = False):
-    if from_back:
-        return arr[-count:]
-    return arr[:count]
-
-
-class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
-    """MNIST Dataset and Corruptions.
-
-    Args:
-        root : str | ``pathlib.Path``
-            Root directory of dataset where the ``mnist_c/`` folder exists.
-        train : bool, default True
-            If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``.
-        download : bool, default False
-            If True, downloads the dataset from the internet and puts it in root
-            directory. If dataset is already downloaded, it is not downloaded again.
-        size : int, default -1
-            Limit the dataset size, must be a value greater than 0.
-        unit_interval : bool, default False
-            Shift the data values to the unit interval [0-1].
-        dtype : type | None, default None
-            Change the :term:`NumPy` dtype - data is loaded as np.uint8
-        channels : Literal['channels_first' | 'channels_last'] | None, default None
-            Location of channel axis if desired, default has no channels (N, 28, 28)
-        flatten : bool, default False
-            Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
-            channels takes priority over flatten.
-        normalize : tuple[mean, std] | None, default None
-            Normalize images acorrding to provided mean and standard deviation
-        corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
-            'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
-            'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
-            The desired corruption style or None.
-        classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
-            | int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
-            "eight", "nine"]] | None, default None
-            Option to select specific classes from dataset.
-        balance : bool, default True
-            If True, returns equal number of samples for each class.
-        randomize : bool, default True
-            If True, shuffles the data prior to selection - uses a set seed for reproducibility.
-        slice_back : bool, default False
-            If True and size has a value greater than 0, then grabs selection starting at the last image.
-        verbose : bool, default True
-            If True, outputs print statements.
-    """
-
-    _mirrors: tuple[str, ...] = (
-        "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
-        "https://zenodo.org/record/3239543/files/",
-    )
-
-    _resources: tuple[tuple[str, str], ...] = (
-        ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
-        ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
-    )
-
-    class_dict: dict[str, int] = {
-        "zero": 0,
-        "one": 1,
-        "two": 2,
-        "three": 3,
-        "four": 4,
-        "five": 5,
-        "six": 6,
-        "seven": 7,
-        "eight": 8,
-        "nine": 9,
-    }
-
-    def __init__(
-        self,
-        root: str | Path,
-        train: bool = True,
-        download: bool = False,
-        size: int = -1,
-        unit_interval: bool = False,
-        dtype: type | None = None,
-        channels: Literal["channels_first", "channels_last"] | None = None,
-        flatten: bool = False,
-        normalize: tuple[float, float] | None = None,
-        corruption: CorruptionStringMap | None = None,
-        classes: TClassMap | None = None,
-        balance: bool = True,
-        randomize: bool = True,
-        slice_back: bool = False,
-        verbose: bool = True,
-    ) -> None:
-        if isinstance(root, str):
-            root = os.path.expanduser(root)
-        self.root = root  # location of stored dataset
-        self.train = train  # training set or test set
-        self.size = size
-        self.unit_interval = unit_interval
-        self.dtype = dtype
-        self.channels = channels
-        self.flatten = flatten
-        self.normalize = normalize
-        self.corruption = corruption
-        self.balance = balance
-        self.randomize = randomize
-        self.from_back = slice_back
-        self.verbose = verbose
-        self.data: NDArray[np.float64]
-        self.targets: NDArray[np.int_]
-        self.size: int
-
-        self._class_set = []
-        if classes is not None:
-            if not isinstance(classes, list):
-                classes = [classes]  # type: ignore
-
-            for val in classes:  # type: ignore
-                if isinstance(val, int) and 0 <= val < 10:
-                    self._class_set.append(val)
-                elif isinstance(val, str):
-                    self._class_set.append(self.class_dict[val])
-            self._class_set = set(self._class_set)
-
-        if not self._class_set:
-            self._class_set = set(self.class_dict.values())
-
-        self._num_classes = len(self._class_set)
-
-        if self.corruption is None:
-            file_resource = self._resources[0]
-            mirror = self._mirrors[0]
-            md5 = False
-        else:
-            if self.corruption == "identity" and verbose:
-                print("Identity is not a corrupted dataset but the original MNIST dataset.")
-            file_resource = self._resources[1]
-            mirror = self._mirrors[1]
-            md5 = True
-        _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
-
-        self.data, self.targets = self._load_data()
-
-        self._augmentations()
-
-    def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
-        if self.corruption is None:
-            image_file = self._resources[0][0]
-            data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
-        else:
-            image_file = f"{'train' if self.train else 'test'}_images.npy"
-            data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
-            data = data.squeeze()
-
-            label_file = f"{'train' if self.train else 'test'}_labels.npy"
-            targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
-
-        return data, targets
-
-    def _augmentations(self):
-        if self.size > self.targets.shape[0] and self.verbose:
-            warn(
-                f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
-                "Adjusting down to raw dataset size."
-            )
-            self.size = -1
-
-        if self.randomize:
-            rdm_seed = np.random.default_rng(2023)
-            shuffled_indices = rdm_seed.permutation(self.data.shape[0])
-            self.data = self.data[shuffled_indices]
-            self.targets = self.targets[shuffled_indices]
-
-        if not self.balance and self._num_classes > self.size:
-            if self.size > 0:
-                self.data = _subselect(self.data, self.size, self.from_back)
-                self.targets = _subselect(self.targets, self.size, self.from_back)
-        else:
-            label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
-            min_label_count = min(len(indices) for indices in label_dict.values())
-
-            self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
-
-            if self._per_class_count > min_label_count:
-                self._per_class_count = min_label_count
-                if not self.balance and self.verbose:
-                    warn(
-                        f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
-                        f"will be returned, instead of the desired {self.size}."
-                    )
-
-            all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
-            for i, label in enumerate(self._class_set):
-                all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
-            self.data = np.vstack(self.data[all_indices.T])  # type: ignore
-            self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
-
-        if self.unit_interval:
-            self.data = self.data / 255
-
-        if self.normalize:
-            self.data = (self.data - self.normalize[0]) / self.normalize[1]
-
-        if self.dtype:
-            self.data = self.data.astype(self.dtype)
-
-        if self.channels == "channels_first":
-            self.data = self.data[:, np.newaxis, :, :]
-        elif self.channels == "channels_last":
-            self.data = self.data[:, :, :, np.newaxis]
-
-        if self.flatten and self.channels is None:
-            self.data = self.data.reshape(self.data.shape[0], -1)
-
-    def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (image, target) where target is index of the target class.
-        """
-        img, target = self.data[index], int(self.targets[index])
-
-        return img, target
-
-    def __len__(self) -> int:
-        return len(self.data)
-
-    @property
-    def mnist_folder(self) -> str:
-        if self.corruption is None:
-            return os.path.join(self.root, "mnist")
-        return os.path.join(self.root, "mnist_c", self.corruption)
-
-    def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
-        with np.load(path, allow_pickle=True) as f:
-            if self.train:
-                x, y = f["x_train"], f["y_train"]
-            else:
-                x, y = f["x_test"], f["y_test"]
-        return x, y
-
-    def _read_corrupt_file(self, path: str) -> NDArray:
-        x = np.load(path, allow_pickle=False)
-        return x
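
For migrators, a minimal sketch of how the removed 0.76.1 MNIST wrapper above was typically constructed. The 0.81.0 replacement is the new dataeval/utils/data/datasets/_mnist.py added in this release, whose constructor is not shown in this diff, so this snippet exercises only the deleted class as defined above:

```python
# Exercises the removed 0.76.1 MNIST class exactly as defined in the hunk above.
from dataeval.utils.dataset.datasets import MNIST  # gone in 0.81.0

ds = MNIST(root="~/data", train=True, download=True, corruption="fog", size=1000)
img, target = ds[0]  # per __getitem__: a (28, 28) array and an int class index
print(len(ds), img.shape, target)
```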
dataeval/utils/dataset/read.py
DELETED
@@ -1,63 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from collections import defaultdict
-from typing import Any
-
-from torch.utils.data import Dataset
-
-
-def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
-    """
-    Extract information from a dataset at each index into individual lists of each information position.
-
-    Parameters
-    ----------
-    dataset : torch.utils.data.Dataset
-        Input dataset
-
-    Returns
-    -------
-    List[List[Any]]
-        All objects in individual lists based on return position from dataset
-
-    Warning
-    -------
-    No type checking is done between lists or data inside lists
-
-    See Also
-    --------
-    torch.utils.data.Dataset
-
-    Examples
-    --------
-    >>> import numpy as np
-    >>> data = np.ones((10, 1, 3, 3))
-    >>> labels = np.ones((10,))
-    >>> class ICDataset:
-    ...     def __init__(self, data, labels):
-    ...         self.data = data
-    ...         self.labels = labels
-    ...
-    ...     def __getitem__(self, idx):
-    ...         return self.data[idx], self.labels[idx]
-
-    >>> ds = ICDataset(data, labels)
-
-    >>> result = read_dataset(ds)
-    >>> len(result)  # images and labels
-    2
-    >>> np.asarray(result[0]).shape  # images
-    (10, 1, 3, 3)
-    >>> np.asarray(result[1]).shape  # labels
-    (10,)
-    """
-
-    ddict: dict[int, list[Any]] = defaultdict(list[Any])
-
-    for data in dataset:
-        for i, d in enumerate(data if isinstance(data, tuple) else (data,)):
-            ddict[i].append(d)
-
-    return list(ddict.values())
dataeval-0.76.1.dist-info/RECORD
DELETED
@@ -1,67 +0,0 @@
-dataeval/__init__.py,sha256=vqyenyxYGE0OXW3C8PC1YDZRak1uLFIYd45-vh9qafQ,1474
-dataeval/detectors/__init__.py,sha256=iifG-Z08mH5B4QhkKtAieDGJBKldKvmCXpDQJD9qVY8,206
-dataeval/detectors/drift/__init__.py,sha256=wO294Oz--l0GuZTAkBpyGwZphbQsot57HoiEX6kjNOc,652
-dataeval/detectors/drift/base.py,sha256=8zHUnUpmgpWMzDv5C-tUX61lbpDjhJ-eAIiNxaNvWP8,14469
-dataeval/detectors/drift/cvm.py,sha256=TATS6IOE0INO1pkyRkesgrhDawD_kITsRsOOGVRs420,4132
-dataeval/detectors/drift/ks.py,sha256=SAd2T9CdytXD7DegCzAX1pWYJdPuttyL97KAQYF4j7Y,4265
-dataeval/detectors/drift/mmd.py,sha256=z7JPFbW4fmHJhR-Qe1OQ4mM8kW6dNxnd3uHD9oXMETE,7599
-dataeval/detectors/drift/torch.py,sha256=ykD-Nggys5T9FTGXXbYYOi2WRKwEzEjXhL8ZueVmTxU,7659
-dataeval/detectors/drift/uncertainty.py,sha256=zkrqz5euJJtYFKsDiRqFfTnDjVOTbqpZWgZiGMrYxvI,5351
-dataeval/detectors/drift/updates.py,sha256=nKsF4xrMFZd2X84GJ5XnGylUuketX_RcH7UpcdlonIo,1781
-dataeval/detectors/linters/__init__.py,sha256=CZV5naeYQYL3sHXO_CXB26AXkyTeKHI-TMaewtEs8Ag,483
-dataeval/detectors/linters/clusterer.py,sha256=V-bNs4ut2E6SIqU4MR1Y96WBZcs4cavQhvXBB0vFZPw,20937
-dataeval/detectors/linters/duplicates.py,sha256=Ba-Nmbjqg_HDMlEBqlWW1aFO_BA-HSc-uWHc3cmI394,5620
-dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
-dataeval/detectors/linters/outliers.py,sha256=o0LtAHdazLfj5GM2HcVDjVY_AfSU5GpBUjxHPC9VfIc,13728
-dataeval/detectors/ood/__init__.py,sha256=Ws6_un4pFWNknki7Bp7qjrslZVB9pYNE-K72u2lF65k,291
-dataeval/detectors/ood/ae.py,sha256=SL8oKTERhMwaZTQWwDhQQ6H07UKj8ozXqEWO3TaOAos,2151
-dataeval/detectors/ood/base.py,sha256=-ApcC9lyZJAgk-joMpLXF20sJqtvlAugg-W18TcAsEw,3010
-dataeval/detectors/ood/metadata_ks_compare.py,sha256=-hEhDNXFC7X8wmFeoigO7A7Qn90vRLroN_nKDwNgjnE,5204
-dataeval/detectors/ood/metadata_least_likely.py,sha256=rb8GOgsrlrEzc6fxccdmyZQ5PC7HtTsTY8U97D-h5OU,5088
-dataeval/detectors/ood/metadata_ood_mi.py,sha256=7_Sdzf7-x1TlrIQvSyOIB98C8_UQhUwmwFQmZ9_q1Uc,4042
-dataeval/detectors/ood/mixin.py,sha256=Ia-rJF6rtGhE8uavijdbzOha3ueFk2CFfA0Ah_mnF40,4976
-dataeval/detectors/ood/output.py,sha256=yygnsjaIQB6v6sXh7glqX2aoqWdf3_YLINqx7BGKMtk,1710
-dataeval/interop.py,sha256=P9Kwe-vOVgbn1ng60y4giCnJYmHjIOpyGpccuIA7P1g,2322
-dataeval/log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
-dataeval/metrics/__init__.py,sha256=OMntcHmmrsOfIlRsJTZQQaF5qXEuP61Li-ElKy7Ysbk,240
-dataeval/metrics/bias/__init__.py,sha256=SIg4Qxza9BqXyKNQLIY0bpqoFvZfK5-GaejpTH6efVc,601
-dataeval/metrics/bias/balance.py,sha256=B1sPackyodiBct9Hs88BR4nJde_R61JyjwSBIG_CFug,9171
-dataeval/metrics/bias/coverage.py,sha256=igVDWJSrO2MvaTEiDUhVzVWPGNB1QOZvngCi8UF0RwA,5746
-dataeval/metrics/bias/diversity.py,sha256=nF1y2FaQIU0yHQtckoddjqoty2hsVVMqwaXWHRdGfqA,8521
-dataeval/metrics/bias/parity.py,sha256=2gSpXkg6ASnkywRTqqx3b3k1T5Qg1Jm-ihMKNZgEwys,12732
-dataeval/metrics/estimators/__init__.py,sha256=oY_9jX7V-Kg7-4KpvMNB4rUhsk8QTA0DIoM8d2VtVIg,380
-dataeval/metrics/estimators/ber.py,sha256=vcndXr0PNLRlYz7u7K74f-B5g3DnUkaTO_WigGdj0cg,5012
-dataeval/metrics/estimators/divergence.py,sha256=joqqlH0AQFibJkHCCb7i7dMJIGF28fmZIR-tGupQQJQ,4247
-dataeval/metrics/estimators/uap.py,sha256=ZAQUjJCbdulftWk6yjILCbnXGOE8RuDqEINZRtMW3tc,2143
-dataeval/metrics/stats/__init__.py,sha256=pUT84sOxDiCHW6xz6Ml1Mf1bFszQrtd3qPG0Ja3boxA,1088
-dataeval/metrics/stats/base.py,sha256=1ejjwlA0FmllcAw7J9Yv1r7GMmBYKvuGPzmDk9ktASM,12613
-dataeval/metrics/stats/boxratiostats.py,sha256=PS1wvWwhTCMJX56erfPW-BZymXrevvXnKl2PkE0qmLk,6315
-dataeval/metrics/stats/datasetstats.py,sha256=mt5t5WhlVI7mo56dmhqgnk1eH8oBV7dahgmqkFDcKo0,7387
-dataeval/metrics/stats/dimensionstats.py,sha256=AlPor23dUH718jFNiVNedHQVaQzN-6OKQEVDQbnGE50,4027
-dataeval/metrics/stats/hashstats.py,sha256=5nNSJ3Tl8gPqpYlWpxl7EHfW6pJd1BtbXYUiuGgH4Eo,5070
-dataeval/metrics/stats/labelstats.py,sha256=MW6kB7V8pdIc7yHdXzRwlD6xSl6SYZonNsLUPKAVILI,6992
-dataeval/metrics/stats/pixelstats.py,sha256=tfvu0tYPgDS0jCCSY2sZ2Ice5r1nNuKx-LYXxZQCw7s,4220
-dataeval/metrics/stats/visualstats.py,sha256=pEQnAPFg-zQ1U5orwF0-U7kfHuZGjMJDsdEMAoDZd4I,4634
-dataeval/output.py,sha256=Dyfv1xlrwSbCe7HdDyq8t-kiIRJbBeaMEmMROr1FrVQ,4034
-dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dataeval/utils/__init__.py,sha256=WW9e_1RbtkvLDRqu1NpDw3-V4su4mA8yJ_P3bgd_7Ho,283
-dataeval/utils/dataset/__init__.py,sha256=IvRauQaa0CzJ5nZrfTSjGoaaKelyJcQDe3OPRw0-NXs,332
-dataeval/utils/dataset/datasets.py,sha256=7tSqN3d8UncqmXh4eiEwarXgVxc4sMuIKPTqBCE0pN8,15080
-dataeval/utils/dataset/read.py,sha256=Q_RaNTFXhkMsx3PrgJEIySdHAA-QxGuih6eq6mnJv-4,1524
-dataeval/utils/dataset/split.py,sha256=1vNy5I1zZx-LIf8B0y57dUaO_UdVd1hyJggUANkwNtM,18958
-dataeval/utils/image.py,sha256=AQljELyMFkYsf2AoNOH5dZG8DYE4hPw0MCk85eIXqAw,1926
-dataeval/utils/metadata.py,sha256=tRcXgJsM1l7vt_naNJj8g8_EHD_AB5MGi1uWxqZsA6M,27431
-dataeval/utils/plot.py,sha256=YyFL1KoJgnl2Bip7m73WVBJa6zbsBnn5c1b3skFfUrA,7068
-dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
-dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
-dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
-dataeval/utils/torch/gmm.py,sha256=fQ8CBO4Bf6i9N1CZdeJ8VJP25fsPjgMextQkondwgvo,3693
-dataeval/utils/torch/internal.py,sha256=qAzQTwTI9Qy6f01Olw3d1TIJ4HoWGf0gQzgWVcdD2x4,6653
-dataeval/utils/torch/models.py,sha256=Df3B_9x5uu-Y5ZOyhRZYpKJnDvxt0hgMeJLy1E4oxpU,8519
-dataeval/utils/torch/trainer.py,sha256=Qay0LK63RuyoGYiJ5zI2C5BVym309ORvp6shhpcrIU4,5589
-dataeval/workflows/__init__.py,sha256=L9yfBipNFGnYuN2JbMknIHDvziwfa2XAGFnOwifZbls,216
-dataeval/workflows/sufficiency.py,sha256=jf53J1PAlfRHSjGpMCWRJzImitLtCQvTMCaMm28ZuPM,18675
-dataeval-0.76.1.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
-dataeval-0.76.1.dist-info/METADATA,sha256=w02IzEy_S5kgRZFRGbWayMg98uFdn3jJT4Gl6MOQzek,5196
-dataeval-0.76.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
-dataeval-0.76.1.dist-info/RECORD,,
/dataeval/{log.py → _log.py}
RENAMED
File without changes
/dataeval/utils/torch/{blocks.py → _blocks.py}
RENAMED
File without changes
{dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt
RENAMED
File without changes
{dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL
RENAMED
File without changes