dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
- dataeval/detectors/ood/aegmm.py +66 -0
- dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
- dataeval/detectors/ood/vaegmm.py +75 -0
- dataeval/interop.py +56 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
- dataeval/metrics/bias/metadata.py +358 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
- dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
- dataeval-0.73.0.dist-info/RECORD +73 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/aegmm.py +0 -78
- dataeval/_internal/detectors/ood/vaegmm.py +0 -89
- dataeval/_internal/interop.py +0 -49
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,156 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = ["HashStatsOutput", "hashstats"]
|
4
|
+
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from typing import Callable, Iterable
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import xxhash as xxh
|
10
|
+
from numpy.typing import ArrayLike
|
11
|
+
from PIL import Image
|
12
|
+
from scipy.fftpack import dct
|
13
|
+
|
14
|
+
from dataeval.interop import as_numpy
|
15
|
+
from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
|
16
|
+
from dataeval.output import set_metadata
|
17
|
+
from dataeval.utils.image import normalize_image_shape, rescale
|
18
|
+
|
19
|
+
# Side length of the square block of low-frequency DCT coefficients kept by
# ``pchash`` (8 x 8 = 64 hash bits); images must be strictly larger than this.
HASH_SIZE = 8
# Upper bound on the resize multiplier: ``pchash`` resamples to at most
# HASH_SIZE * MAX_FACTOR pixels per side before the DCT.
MAX_FACTOR = 4
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass(frozen=True)
class HashStatsOutput(BaseStatsOutput):
    """
    Output class for :func:`hashstats` stats metric

    Attributes
    ----------
    xxhash : list[str]
        Exact xxHash digest of each image, as a hex string
    pchash : list[str]
        :term:`Perception-based Hash` of each image, as a hex string
    """

    # One entry per image, in input order
    xxhash: list[str]
    pchash: list[str]
|
38
|
+
|
39
|
+
|
40
|
+
def pchash(image: ArrayLike) -> str:
    """
    Performs a perceptual hash on an image.

    The image is averaged across channels, rescaled to 8-bit and resampled with
    the Lanczos filter to a square N x N image. N is the largest multiple of
    HASH_SIZE (8) that is strictly smaller than the image's smallest spatial side,
    capped at HASH_SIZE * MAX_FACTOR (32). A 2-D discrete cosine transform is
    applied, and the lowest-frequency HASH_SIZE x HASH_SIZE block of coefficients
    is encoded as a bit array: each bit records whether that coefficient is greater
    than the block's median. The bits are returned as a hex string with leading
    zeros stripped ("0" if every bit is zero).

    Parameters
    ----------
    image : ArrayLike
        An image as a numpy array in CxHxW format (a 2-D HxW array is also accepted)

    Returns
    -------
    str
        The hex string hash of the image using perceptual hashing

    Raises
    ------
    ValueError
        If the smaller spatial dimension is not larger than HASH_SIZE.
    """
    pixels = as_numpy(image)

    # The image must be strictly larger than HASH_SIZE x HASH_SIZE
    smallest_side = min(pixels.shape[-2:])
    if smallest_side <= HASH_SIZE:
        raise ValueError(f"Image must be larger than {HASH_SIZE}x{HASH_SIZE} for fuzzy hashing.")

    # Largest multiple of HASH_SIZE strictly below the smallest side, capped by MAX_FACTOR
    factor = min((smallest_side - 1) // HASH_SIZE, MAX_FACTOR)
    side = HASH_SIZE * factor

    # Collapse the channels into one grayscale plane, then bring values onto an 8-bit scale
    gray = np.mean(normalize_image_shape(pixels), axis=0).squeeze()
    gray_u8 = rescale(gray, 8).astype(np.uint8)

    # Lanczos-resample down to a side x side square
    resized = Image.fromarray(gray_u8).resize((side, side), Image.Resampling.LANCZOS)
    square = np.array(resized)

    # Separable 2-D DCT (axis 0, then axis 1); keep the low-frequency corner block
    coeffs = dct(dct(square.T).T)[:HASH_SIZE, :HASH_SIZE]

    # One bit per coefficient: set where it exceeds the median coefficient
    bits = (coeffs > np.median(coeffs)).ravel()

    # Left-pad with False so packbits receives a whole number of bytes
    pad = (-bits.size) % 8
    bits = np.concatenate((np.zeros(pad, dtype=bool), bits))

    digest = np.packbits(bits).tobytes().hex().lstrip("0")
    return digest or "0"
|
90
|
+
|
91
|
+
|
92
|
+
def xxhash(image: ArrayLike) -> str:
    """
    Computes an exact, non-cryptographic hash of an image with xxHash (xxhash.com).

    The image's raw bytes (flattened in C order) are hashed with the 64-bit XXH3
    variant, and the digest is returned as a hex string.

    Parameters
    ----------
    image : ArrayLike
        An image as a numpy array

    Returns
    -------
    str
        The hex string hash of the image using the xxHash algorithm
    """
    flat = as_numpy(image).ravel()
    payload = flat.tobytes()
    return xxh.xxh3_64_hexdigest(payload)
|
109
|
+
|
110
|
+
|
111
|
+
class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
    # Each entry maps an output field name to a function of the processor that
    # hashes the processor's current ``image``.
    output_class: type = HashStatsOutput
    image_function_map: dict[str, Callable[[StatsProcessor[HashStatsOutput]], str]] = dict(
        xxhash=lambda proc: xxhash(proc.image),
        pchash=lambda proc: pchash(proc.image),
    )
|
117
|
+
|
118
|
+
|
119
|
+
@set_metadata()
def hashstats(
    images: Iterable[ArrayLike],
    bboxes: Iterable[ArrayLike] | None = None,
) -> HashStatsOutput:
    """
    Calculates hashes for each image

    This function computes hashes from the images including exact hashes and perception-based
    hashes. These hash values can be used to determine if images are exact or near matches.

    Parameters
    ----------
    images : Iterable[ArrayLike]
        Images to hash
    bboxes : Iterable[ArrayLike] or None, default None
        Bounding boxes in `xyxy` format for each image

    Returns
    -------
    HashStatsOutput
        A dictionary-like object containing the computed hashes for each image.

    See Also
    --------
    Duplicates

    Examples
    --------
    Calculating the statistics on the images, whose shape is (C, H, W)

    >>> results = hashstats(images)
    >>> print(results.xxhash)
    ['a72434443d6e7336', 'efc12c2f14581d79', '4a1e03483a27d674', '3a3ecedbcf814226']
    >>> print(results.pchash)
    ['8f25506af46a7c6a', '8000808000008080', '8e71f18e0ef18e0e', 'a956d6a956d6a928']
    """
    # Only HashStatsProcessor is requested, so the first (presumably only) output is ours
    outputs = run_stats(images, bboxes, False, [HashStatsProcessor])
    return outputs[0]
|
@@ -1,13 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["LabelStatsOutput", "labelstats"]
|
4
|
+
|
3
5
|
from collections import Counter, defaultdict
|
4
6
|
from dataclasses import dataclass
|
5
7
|
from typing import Any, Iterable, Mapping, TypeVar
|
6
8
|
|
7
9
|
from numpy.typing import ArrayLike
|
8
10
|
|
9
|
-
from dataeval.
|
10
|
-
from dataeval.
|
11
|
+
from dataeval.interop import to_numpy
|
12
|
+
from dataeval.output import OutputMetadata, set_metadata
|
11
13
|
|
12
14
|
|
13
15
|
@dataclass(frozen=True)
|
@@ -55,7 +57,7 @@ def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
|
|
55
57
|
return dict(sorted(d.items(), key=lambda x: x[0]))
|
56
58
|
|
57
59
|
|
58
|
-
@set_metadata(
|
60
|
+
@set_metadata()
|
59
61
|
def labelstats(
|
60
62
|
labels: Iterable[ArrayLike],
|
61
63
|
) -> LabelStatsOutput:
|
@@ -1,14 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["PixelStatsOutput", "pixelstats"]
|
4
|
+
|
3
5
|
from dataclasses import dataclass
|
4
|
-
from typing import Iterable
|
6
|
+
from typing import Any, Callable, Iterable
|
5
7
|
|
6
8
|
import numpy as np
|
7
9
|
from numpy.typing import ArrayLike, NDArray
|
8
10
|
from scipy.stats import entropy, kurtosis, skew
|
9
11
|
|
10
|
-
from dataeval.
|
11
|
-
from dataeval.
|
12
|
+
from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
|
13
|
+
from dataeval.output import set_metadata
|
12
14
|
|
13
15
|
|
14
16
|
@dataclass(frozen=True)
|
@@ -44,9 +46,8 @@ class PixelStatsOutput(BaseStatsOutput):
|
|
44
46
|
|
45
47
|
|
46
48
|
class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
|
47
|
-
output_class = PixelStatsOutput
|
48
|
-
|
49
|
-
image_function_map = {
|
49
|
+
output_class: type = PixelStatsOutput
|
50
|
+
image_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
|
50
51
|
"mean": lambda self: np.mean(self.scaled),
|
51
52
|
"std": lambda x: np.std(x.scaled),
|
52
53
|
"var": lambda x: np.var(x.scaled),
|
@@ -55,7 +56,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
|
|
55
56
|
"histogram": lambda x: np.histogram(x.scaled, 256, (0, 1))[0],
|
56
57
|
"entropy": lambda x: entropy(x.get("histogram")),
|
57
58
|
}
|
58
|
-
channel_function_map = {
|
59
|
+
channel_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
|
59
60
|
"mean": lambda x: np.mean(x.scaled, axis=1),
|
60
61
|
"std": lambda x: np.std(x.scaled, axis=1),
|
61
62
|
"var": lambda x: np.var(x.scaled, axis=1),
|
@@ -66,7 +67,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
|
|
66
67
|
}
|
67
68
|
|
68
69
|
|
69
|
-
@set_metadata(
|
70
|
+
@set_metadata()
|
70
71
|
def pixelstats(
|
71
72
|
images: Iterable[ArrayLike],
|
72
73
|
bboxes: Iterable[ArrayLike] | None = None,
|
@@ -1,14 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["VisualStatsOutput", "visualstats"]
|
4
|
+
|
3
5
|
from dataclasses import dataclass
|
4
|
-
from typing import Iterable
|
6
|
+
from typing import Any, Callable, Iterable
|
5
7
|
|
6
8
|
import numpy as np
|
7
9
|
from numpy.typing import ArrayLike, NDArray
|
8
10
|
|
9
|
-
from dataeval.
|
10
|
-
from dataeval.
|
11
|
-
from dataeval.
|
11
|
+
from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
|
12
|
+
from dataeval.output import set_metadata
|
13
|
+
from dataeval.utils.image import edge_filter
|
12
14
|
|
13
15
|
QUARTILES = (0, 25, 50, 75, 100)
|
14
16
|
|
@@ -46,9 +48,8 @@ class VisualStatsOutput(BaseStatsOutput):
|
|
46
48
|
|
47
49
|
|
48
50
|
class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
49
|
-
output_class = VisualStatsOutput
|
50
|
-
|
51
|
-
image_function_map = {
|
51
|
+
output_class: type = VisualStatsOutput
|
52
|
+
image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
|
52
53
|
"brightness": lambda x: x.get("percentiles")[1],
|
53
54
|
"contrast": lambda x: np.nan_to_num(
|
54
55
|
(np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles"))
|
@@ -59,7 +60,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
|
59
60
|
"zeros": lambda x: np.count_nonzero(np.sum(x.image, axis=0) == 0) / np.prod(x.shape[-2:]),
|
60
61
|
"percentiles": lambda x: np.nanpercentile(x.scaled, q=QUARTILES),
|
61
62
|
}
|
62
|
-
channel_function_map = {
|
63
|
+
channel_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
|
63
64
|
"brightness": lambda x: x.get("percentiles")[:, 1],
|
64
65
|
"contrast": lambda x: np.nan_to_num(
|
65
66
|
(np.max(x.get("percentiles"), axis=1) - np.min(x.get("percentiles"), axis=1))
|
@@ -73,7 +74,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
|
73
74
|
}
|
74
75
|
|
75
76
|
|
76
|
-
@set_metadata(
|
77
|
+
@set_metadata()
|
77
78
|
def visualstats(
|
78
79
|
images: Iterable[ArrayLike],
|
79
80
|
bboxes: Iterable[ArrayLike] | None = None,
|
@@ -1,12 +1,20 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = []
|
4
|
+
|
3
5
|
import inspect
|
6
|
+
import sys
|
4
7
|
from datetime import datetime, timezone
|
5
8
|
from functools import wraps
|
6
|
-
from typing import Any
|
9
|
+
from typing import Any, Callable, Iterable, TypeVar
|
7
10
|
|
8
11
|
import numpy as np
|
9
12
|
|
13
|
+
if sys.version_info >= (3, 10):
|
14
|
+
from typing import ParamSpec
|
15
|
+
else:
|
16
|
+
from typing_extensions import ParamSpec
|
17
|
+
|
10
18
|
from dataeval import __version__
|
11
19
|
|
12
20
|
|
@@ -25,10 +33,18 @@ class OutputMetadata:
|
|
25
33
|
return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
|
26
34
|
|
27
35
|
|
28
|
-
|
29
|
-
|
36
|
+
P = ParamSpec("P")
|
37
|
+
R = TypeVar("R", bound=OutputMetadata)
|
38
|
+
|
39
|
+
|
40
|
+
def set_metadata(
|
41
|
+
state_attr: Iterable[str] | None = None,
|
42
|
+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
43
|
+
"""Decorator to stamp OutputMetadata classes with runtime metadata"""
|
44
|
+
|
45
|
+
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
|
30
46
|
@wraps(fn)
|
31
|
-
def wrapper(*args, **kwargs):
|
47
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
32
48
|
def fmt(v):
|
33
49
|
if np.isscalar(v):
|
34
50
|
return v
|
@@ -52,9 +68,13 @@ def set_metadata(module_name: str = "", state_attr: list[str] | None = None):
|
|
52
68
|
if "self" in arguments and state_attr
|
53
69
|
else {}
|
54
70
|
)
|
55
|
-
name =
|
71
|
+
name = (
|
72
|
+
f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
|
73
|
+
if "self" in arguments
|
74
|
+
else f"{fn.__module__}.{fn.__qualname__}"
|
75
|
+
)
|
56
76
|
metadata = {
|
57
|
-
"_name":
|
77
|
+
"_name": name,
|
58
78
|
"_execution_time": time,
|
59
79
|
"_execution_duration": duration,
|
60
80
|
"_arguments": {k: v for k, v in arguments.items() if k != "self"},
|
dataeval/utils/__init__.py
CHANGED
@@ -5,15 +5,20 @@ metrics. Currently DataEval supports both :term:`TensorFlow` and PyTorch backend
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
|
8
|
+
from dataeval.utils.metadata import merge_metadata
|
9
|
+
from dataeval.utils.split_dataset import split_dataset
|
8
10
|
|
9
|
-
__all__ = []
|
11
|
+
__all__ = ["split_dataset", "merge_metadata"]
|
10
12
|
|
11
13
|
if _IS_TORCH_AVAILABLE: # pragma: no cover
|
12
|
-
from . import torch
|
14
|
+
from dataeval.utils import torch
|
13
15
|
|
14
16
|
__all__ += ["torch"]
|
15
17
|
|
16
18
|
if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
|
17
|
-
from . import tensorflow
|
19
|
+
from dataeval.utils import tensorflow
|
18
20
|
|
19
21
|
__all__ += ["tensorflow"]
|
22
|
+
|
23
|
+
del _IS_TENSORFLOW_AVAILABLE
|
24
|
+
del _IS_TORCH_AVAILABLE
|
dataeval/utils/image.py
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any, NamedTuple
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import ArrayLike, NDArray
|
9
|
+
from scipy.signal import convolve2d
|
10
|
+
|
11
|
+
# 3x3 edge-detection kernel used by ``edge_filter``: centre weight 8, all eight
# neighbours -1, so the weights sum to 0 and flat regions filter to 0.
EDGE_KERNEL = np.array(
    [
        [-1, -1, -1],
        [-1, 8, -1],
        [-1, -1, -1],
    ],
    dtype=np.int8,
)
# Candidate bit depths in ascending order; ``get_bitdepth`` picks the first one
# whose range holds the image maximum.
BIT_DEPTH = (1, 8, 12, 16, 32)
|
13
|
+
|
14
|
+
|
15
|
+
class BitDepth(NamedTuple):
|
16
|
+
depth: int
|
17
|
+
pmin: float | int
|
18
|
+
pmax: float | int
|
19
|
+
|
20
|
+
|
21
|
+
def get_bitdepth(image: NDArray[Any]) -> BitDepth:
    """
    Estimates an image's bit depth from its minimum and maximum pixel values.

    If any pixel is negative, the depth is reported as 0 and the observed
    (min, max) is kept as the range. Otherwise the depth is the first entry of
    ``BIT_DEPTH`` for which ``2**depth`` exceeds the maximum pixel value, falling
    back to the largest entry. The range is then ``[0, 2**depth - 1]``.
    """
    lowest, highest = np.min(image), np.max(image)
    if lowest < 0:
        return BitDepth(0, lowest, highest)

    depth = next((bits for bits in BIT_DEPTH if 2**bits > highest), max(BIT_DEPTH))
    return BitDepth(depth, 0, 2**depth - 1)
|
32
|
+
|
33
|
+
|
34
|
+
def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
    """
    Rescales the image using the bit depth provided.

    The image's current range ``[pmin, pmax]`` is estimated with
    :func:`get_bitdepth`. If the estimated depth already equals ``depth``, the
    image is returned unchanged. Otherwise values are mapped linearly from
    ``[pmin, pmax]`` onto ``[0, 2**depth - 1]``.

    Parameters
    ----------
    image : NDArray
        Image to rescale.
    depth : int, default 1
        Target bit depth.

    Returns
    -------
    NDArray
        The input array itself when no rescaling is needed, otherwise a new
        floating-point array.
    """
    bitdepth = get_bitdepth(image)
    if bitdepth.depth == depth:
        return image

    # Work with Python floats so integer inputs are promoted to floating point
    # before any arithmetic, instead of wrapping around in their native dtype
    # (e.g. int8 images, where pmax - pmin or image - pmin can exceed 127).
    pmin = float(bitdepth.pmin)
    span = float(bitdepth.pmax) - pmin

    # Subtract the minimum so that pmin maps to 0. (Adding it was only harmless
    # when pmin == 0, i.e. for non-negative images.)
    shifted = np.asanyarray(image) - pmin

    # A constant image (span 0, only reachable for negative input) is already all
    # zeros after the shift; skip the 0/0 division that would yield NaN.
    normalized = shifted / span if span else shifted
    return normalized * (2**depth - 1)
|
44
|
+
|
45
|
+
|
46
|
+
def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
    """
    Normalizes the image shape into (C,H,W).

    2-D inputs gain a leading channel axis of length 1. 3-D inputs are returned
    as-is. For higher-rank inputs, every leading axis beyond the last three is
    indexed at 0.

    Raises
    ------
    ValueError
        If the image has fewer than 2 dimensions.
    """
    rank = image.ndim
    if rank < 2:
        raise ValueError("Images must have 2 or more dimensions.")
    if rank == 2:
        return image[np.newaxis, ...]
    if rank == 3:
        return image
    # Take the first element along every axis in front of the trailing (C, H, W)
    leading_index = tuple(0 for _ in range(rank - 3))
    return image[leading_index]
|
60
|
+
|
61
|
+
|
62
|
+
def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[Any]:
    """
    Returns the image filtered using a 3x3 edge detection kernel:
    [[ -1, -1, -1 ],
     [ -1,  8, -1 ],
     [ -1, -1, -1 ]]

    ``offset`` is added to every filtered value, and the result is clipped to
    ``[0, 255]``. The result is *not* cast to ``uint8``: with the default
    (float) ``offset`` it is a floating-point array.

    Parameters
    ----------
    image : ArrayLike
        2-D image to filter.
    offset : float, default 0.5
        Constant added to the filtered values before clipping.

    Returns
    -------
    NDArray
        Filtered image with the same shape as the input.
    """
    # mode="same" keeps the input shape; boundary="symm" mirrors the border pixels
    # so edge responses at the image boundary use reflected neighbours, not zeros.
    edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
    # Clip in place to avoid allocating a second array.
    np.clip(edges, 0, 255, out=edges)
    return edges
|
dataeval/utils/lazy.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from functools import cached_property
|
4
|
+
from importlib import import_module
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
|
8
|
+
class LazyModule:
    """
    Proxy that defers ``import_module(name)`` until one of its attributes is read.

    Attribute access on the proxy is forwarded to the imported module. The
    module object is cached on the instance after the first import.
    """

    def __init__(self, name: str) -> None:
        self._name = name

    def __getattr__(self, key: str) -> Any:
        # __getattr__ only runs when normal lookup fails. If the proxy's own
        # internals are missing (e.g. an instance built via __new__ by copy/pickle
        # before its state is restored), forwarding to self._module would look up
        # self._name, land back here, and recurse until RecursionError. Fail with
        # AttributeError instead so hasattr()/copy behave normally.
        if key in ("_name", "_module"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")
        return getattr(self._module, key)

    @cached_property
    def _module(self):
        # Imported on first use; cached_property stores the result on the instance.
        return import_module(self._name)
|
18
|
+
|
19
|
+
|
20
|
+
# Process-wide registry so every lazyload() call for a name shares one proxy.
LAZY_MODULES: dict[str, LazyModule] = {}


def lazyload(name: str) -> LazyModule:
    """Return the shared ``LazyModule`` proxy for ``name``, creating it on first request."""
    try:
        return LAZY_MODULES[name]
    except KeyError:
        proxy = LAZY_MODULES[name] = LazyModule(name)
        return proxy
|