dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +52 -43
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +198 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.0.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -1,32 +1,31 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from dataeval.utils.plot import histogram_plot
|
4
|
-
|
5
3
|
__all__ = []
|
6
4
|
|
7
5
|
import re
|
8
6
|
import warnings
|
7
|
+
from copy import deepcopy
|
9
8
|
from dataclasses import dataclass
|
10
9
|
from functools import partial
|
11
10
|
from itertools import repeat
|
12
11
|
from multiprocessing import Pool
|
13
|
-
from typing import Any, Callable, Generic, Iterable,
|
12
|
+
from typing import Any, Callable, Generic, Iterable, Optional, Sequence, Sized, TypeVar, Union
|
14
13
|
|
15
14
|
import numpy as np
|
16
15
|
import tqdm
|
17
|
-
from numpy.typing import
|
16
|
+
from numpy.typing import NDArray
|
18
17
|
|
19
|
-
from dataeval.
|
20
|
-
from dataeval.
|
21
|
-
from dataeval.
|
18
|
+
from dataeval._output import Output
|
19
|
+
from dataeval.config import get_max_processes
|
20
|
+
from dataeval.typing import ArrayLike
|
21
|
+
from dataeval.utils._array import to_numpy_iter
|
22
|
+
from dataeval.utils._image import normalize_image_shape, rescale
|
23
|
+
from dataeval.utils._plot import histogram_plot
|
22
24
|
|
23
25
|
DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
|
24
26
|
SOURCE_INDEX = "source_index"
|
25
27
|
BOX_COUNT = "box_count"
|
26
28
|
|
27
|
-
# TODO: Replace with global config
|
28
|
-
DEFAULT_PROCESSES: int | None = None
|
29
|
-
|
30
29
|
OptionalRange = Optional[Union[int, Iterable[int]]]
|
31
30
|
|
32
31
|
|
@@ -49,7 +48,8 @@ def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
|
|
49
48
|
return bounding_box
|
50
49
|
|
51
50
|
|
52
|
-
|
51
|
+
@dataclass
|
52
|
+
class SourceIndex:
|
53
53
|
"""
|
54
54
|
Attributes
|
55
55
|
----------
|
@@ -205,7 +205,8 @@ class StatsProcessor(Generic[TStatsOutput]):
|
|
205
205
|
return cls.output_class(**output, source_index=source_index, box_count=np.asarray(box_count, dtype=np.uint16))
|
206
206
|
|
207
207
|
|
208
|
-
|
208
|
+
@dataclass
|
209
|
+
class StatsProcessorOutput:
|
209
210
|
results: list[dict[str, Any]]
|
210
211
|
source_indices: list[SourceIndex]
|
211
212
|
box_counts: list[int]
|
@@ -272,8 +273,6 @@ def run_stats(
|
|
272
273
|
A flag which determines if the stats should be evaluated on a per-channel basis or not.
|
273
274
|
stats_processor_cls : Iterable[type[StatsProcessor]]
|
274
275
|
An iterable of stats processor classes that calculate stats and return output classes.
|
275
|
-
processes : int | None, default None
|
276
|
-
Number of processes to use, defaults to None which uses all available CPU cores.
|
277
276
|
|
278
277
|
Returns
|
279
278
|
-------
|
@@ -297,11 +296,11 @@ def run_stats(
|
|
297
296
|
bbox_iter = repeat(None) if bboxes is None else to_numpy_iter(bboxes)
|
298
297
|
|
299
298
|
warning_list = []
|
300
|
-
total_for_status =
|
299
|
+
total_for_status = len(images) if isinstance(images, Sized) else None
|
301
300
|
stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
|
302
301
|
|
303
302
|
# TODO: Introduce global controls for CPU job parallelism and GPU configurations
|
304
|
-
with Pool(processes=
|
303
|
+
with Pool(processes=get_max_processes()) as p:
|
305
304
|
for r in tqdm.tqdm(
|
306
305
|
p.imap(
|
307
306
|
partial(process_stats_unpack, per_channel=per_channel, stats_processor_cls=stats_processor_cls),
|
@@ -330,3 +329,40 @@ def run_stats(
|
|
330
329
|
|
331
330
|
outputs = [s.convert_output(output, source_index, box_count) for s in stats_processor_cls]
|
332
331
|
return outputs
|
332
|
+
|
333
|
+
|
334
|
+
def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
    """
    Concatenate two stats outputs of the same type into a new output.

    Parameters
    ----------
    a : TStatsOutput
        Stats output whose entries come first in the result.
    b : TStatsOutput
        Stats output whose entries are appended after those of ``a``.

    Returns
    -------
    TStatsOutput
        A new instance of ``type(a)`` constructed from the merged fields.

    Raises
    ------
    TypeError
        If ``a`` and ``b`` are not of exactly the same type.
    """
    if type(a) is not type(b):
        raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")

    # Deep copy so extending the lists below never mutates the fields held by `a`.
    sum_dict = deepcopy(a.dict())
    # Evaluate b.dict() once instead of once per key.
    other_dict = b.dict()

    for k, value in sum_dict.items():
        if isinstance(value, list):
            value.extend(other_dict[k])
        else:
            # Non-list fields are joined with np.concatenate.
            sum_dict[k] = np.concatenate((value, other_dict[k]))

    return type(a)(**sum_dict)
|
347
|
+
|
348
|
+
|
349
|
+
def combine_stats(stats: Sequence[TStatsOutput]) -> tuple[TStatsOutput, list[int]]:
    """
    Merge a sequence of stats outputs into one and record the running lengths.

    Parameters
    ----------
    stats : Sequence[TStatsOutput]
        Stats outputs to merge, in order.

    Returns
    -------
    tuple[TStatsOutput, list[int]]
        The merged output and, for each input, the cumulative ``len`` of all
        inputs up to and including it.

    Raises
    ------
    TypeError
        If ``stats`` is empty.
    """
    merged = None
    boundaries = []
    running_total = 0
    for item in stats:
        if merged is None:
            # The first element seeds the result unchanged.
            merged = item
        else:
            merged = add_stats(merged, item)
        running_total += len(item)
        boundaries.append(running_total)
    if merged is None:
        raise TypeError("Cannot combine empty sequence of stats.")
    return merged, boundaries
|
360
|
+
|
361
|
+
|
362
|
+
def get_dataset_step_from_idx(idx: int, dataset_steps: list[int]) -> tuple[int, int]:
    """
    Locate the first step that exceeds ``idx`` and rebase ``idx`` onto it.

    Parameters
    ----------
    idx : int
        Index to look up.
    dataset_steps : list[int]
        Exclusive upper bounds, checked in order. Presumably the cumulative
        lengths returned by ``combine_stats`` - confirm against callers.

    Returns
    -------
    tuple[int, int]
        The position of the first step greater than ``idx`` and ``idx`` minus
        the preceding step (0 for the first position). ``(-1, idx)`` when no
        step is greater than ``idx``.
    """
    # Pair every step with the one before it; the first step is paired with 0.
    lower_bounds = [0, *dataset_steps]
    for position, (start, stop) in enumerate(zip(lower_bounds, dataset_steps)):
        if idx < stop:
            return position, idx - start
    return -1, idx
|
@@ -8,9 +8,9 @@ from typing import Any, Callable, Generic, TypeVar, cast
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.metrics.stats.
|
13
|
-
from dataeval.
|
11
|
+
from dataeval._output import set_metadata
|
12
|
+
from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, BaseStatsOutput
|
13
|
+
from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
|
14
14
|
|
15
15
|
TStatOutput = TypeVar("TStatOutput", bound=BaseStatsOutput, contravariant=True)
|
16
16
|
ArraySlice = tuple[int, int]
|
@@ -50,7 +50,7 @@ RATIOSTATS_OVERRIDE_MAP: dict[type, dict[str, Callable[..., NDArray[Any]]]] = {
|
|
50
50
|
"depth": lambda x: x.box["depth"],
|
51
51
|
"distance": lambda x: x.box["distance"],
|
52
52
|
}
|
53
|
-
)
|
53
|
+
),
|
54
54
|
}
|
55
55
|
|
56
56
|
|
@@ -87,11 +87,8 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsO
|
|
87
87
|
stats = BoxImageStatsOutputSlice(box_stats, (box_i, box_j), img_stats, (img_i, img_j))
|
88
88
|
out_type = type(box_stats)
|
89
89
|
use_override = out_type in RATIOSTATS_OVERRIDE_MAP and key in RATIOSTATS_OVERRIDE_MAP[out_type]
|
90
|
-
|
91
|
-
RATIOSTATS_OVERRIDE_MAP[out_type][key](stats)
|
92
|
-
if use_override
|
93
|
-
else np.nan_to_num(stats.box[key] / stats.img[key])
|
94
|
-
)
|
90
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
91
|
+
ratio = RATIOSTATS_OVERRIDE_MAP[out_type][key](stats) if use_override else stats.box[key] / stats.img[key]
|
95
92
|
out_stats[box_i:box_j] = ratio.reshape(-1, *out_stats[box_i].shape)
|
96
93
|
return out_stats
|
97
94
|
|
@@ -5,24 +5,20 @@ __all__ = []
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from typing import Any, Iterable
|
7
7
|
|
8
|
-
from
|
9
|
-
|
10
|
-
from dataeval.metrics.stats.
|
11
|
-
from dataeval.metrics.stats.
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
from dataeval.
|
16
|
-
from dataeval.metrics.stats.pixelstats import PixelStatsOutput, PixelStatsProcessor
|
17
|
-
from dataeval.metrics.stats.visualstats import VisualStatsOutput, VisualStatsProcessor
|
18
|
-
from dataeval.output import Output, set_metadata
|
19
|
-
from dataeval.utils.plot import channel_histogram_plot
|
8
|
+
from dataeval._output import Output, set_metadata
|
9
|
+
from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, _is_plottable, run_stats
|
10
|
+
from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput, DimensionStatsProcessor
|
11
|
+
from dataeval.metrics.stats._labelstats import LabelStatsOutput, labelstats
|
12
|
+
from dataeval.metrics.stats._pixelstats import PixelStatsOutput, PixelStatsProcessor
|
13
|
+
from dataeval.metrics.stats._visualstats import VisualStatsOutput, VisualStatsProcessor
|
14
|
+
from dataeval.typing import ArrayLike
|
15
|
+
from dataeval.utils._plot import channel_histogram_plot
|
20
16
|
|
21
17
|
|
22
18
|
@dataclass(frozen=True)
|
23
19
|
class DatasetStatsOutput(Output, HistogramPlotMixin):
|
24
20
|
"""
|
25
|
-
Output class for :func
|
21
|
+
Output class for :func:`.datasetstats` stats metric.
|
26
22
|
|
27
23
|
This class represents the outputs of various stats functions against a single
|
28
24
|
dataset, such that each index across all stat outputs are representative of
|
@@ -82,7 +78,7 @@ def _get_channels(cls, channel_limit: int | None = None, channel_index: int | It
|
|
82
78
|
@dataclass(frozen=True)
|
83
79
|
class ChannelStatsOutput(Output):
|
84
80
|
"""
|
85
|
-
Output class for :func
|
81
|
+
Output class for :func:`.channelstats` stats metric.
|
86
82
|
|
87
83
|
This class represents the outputs of various per-channel stats functions against
|
88
84
|
a single dataset, such that each index across all stat outputs are representative
|
@@ -6,17 +6,18 @@ from dataclasses import dataclass
|
|
6
6
|
from typing import Any, Callable, Iterable
|
7
7
|
|
8
8
|
import numpy as np
|
9
|
-
from numpy.typing import
|
9
|
+
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
11
|
+
from dataeval._output import set_metadata
|
12
|
+
from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
|
13
|
+
from dataeval.typing import ArrayLike
|
14
|
+
from dataeval.utils._image import get_bitdepth
|
14
15
|
|
15
16
|
|
16
17
|
@dataclass(frozen=True)
|
17
18
|
class DimensionStatsOutput(BaseStatsOutput, HistogramPlotMixin):
|
18
19
|
"""
|
19
|
-
Output class for :func
|
20
|
+
Output class for :func:`.dimensionstats` stats metric.
|
20
21
|
|
21
22
|
Attributes
|
22
23
|
----------
|
@@ -9,14 +9,14 @@ from typing import Callable, Iterable
|
|
9
9
|
|
10
10
|
import numpy as np
|
11
11
|
import xxhash as xxh
|
12
|
-
from numpy.typing import ArrayLike
|
13
12
|
from PIL import Image
|
14
13
|
from scipy.fftpack import dct
|
15
14
|
|
16
|
-
from dataeval.
|
17
|
-
from dataeval.metrics.stats.
|
18
|
-
from dataeval.
|
19
|
-
from dataeval.utils.
|
15
|
+
from dataeval._output import set_metadata
|
16
|
+
from dataeval.metrics.stats._base import BaseStatsOutput, StatsProcessor, run_stats
|
17
|
+
from dataeval.typing import ArrayLike
|
18
|
+
from dataeval.utils._array import as_numpy
|
19
|
+
from dataeval.utils._image import normalize_image_shape, rescale
|
20
20
|
|
21
21
|
HASH_SIZE = 8
|
22
22
|
MAX_FACTOR = 4
|
@@ -25,7 +25,7 @@ MAX_FACTOR = 4
|
|
25
25
|
@dataclass(frozen=True)
|
26
26
|
class HashStatsOutput(BaseStatsOutput):
|
27
27
|
"""
|
28
|
-
Output class for :func
|
28
|
+
Output class for :func:`.hashstats` stats metric.
|
29
29
|
|
30
30
|
Attributes
|
31
31
|
----------
|
@@ -2,25 +2,25 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
|
5
|
+
import contextlib
|
6
6
|
from collections import Counter, defaultdict
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from typing import Any, Iterable, Mapping, TypeVar
|
9
9
|
|
10
10
|
import numpy as np
|
11
|
-
from numpy.typing import ArrayLike
|
12
11
|
|
13
|
-
from dataeval.
|
14
|
-
from dataeval.
|
12
|
+
from dataeval._output import Output, set_metadata
|
13
|
+
from dataeval.typing import ArrayLike
|
14
|
+
from dataeval.utils._array import as_numpy
|
15
15
|
|
16
|
-
|
17
|
-
|
16
|
+
with contextlib.suppress(ImportError):
|
17
|
+
import pandas as pd
|
18
18
|
|
19
19
|
|
20
20
|
@dataclass(frozen=True)
|
21
21
|
class LabelStatsOutput(Output):
|
22
22
|
"""
|
23
|
-
Output class for :func
|
23
|
+
Output class for :func:`.labelstats` stats metric.
|
24
24
|
|
25
25
|
Attributes
|
26
26
|
----------
|
@@ -73,24 +73,24 @@ class LabelStatsOutput(Output):
|
|
73
73
|
|
74
74
|
return table_str
|
75
75
|
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
76
|
+
def to_dataframe(self) -> pd.DataFrame:
    """
    Build a per-class table of label and image counts.

    Returns
    -------
    pd.DataFrame
        One row per key of ``label_counts_per_class`` with the columns
        "Label", "Total Count" and "Image Count".
    """
    # Imported here as well as at module level, where the import may be suppressed.
    import pandas as pd

    labels = list(self.label_counts_per_class)
    return pd.DataFrame(
        {
            "Label": labels,
            "Total Count": [self.label_counts_per_class[label] for label in labels],
            "Image Count": [self.image_counts_per_label[label] for label in labels],
        }
    )
|
94
94
|
|
95
95
|
|
96
96
|
TKey = TypeVar("TKey", int, str)
|
@@ -6,17 +6,18 @@ from dataclasses import dataclass
|
|
6
6
|
from typing import Any, Callable, Iterable
|
7
7
|
|
8
8
|
import numpy as np
|
9
|
-
from numpy.typing import
|
9
|
+
from numpy.typing import NDArray
|
10
10
|
from scipy.stats import entropy, kurtosis, skew
|
11
11
|
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
12
|
+
from dataeval._output import set_metadata
|
13
|
+
from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
|
14
|
+
from dataeval.typing import ArrayLike
|
14
15
|
|
15
16
|
|
16
17
|
@dataclass(frozen=True)
|
17
18
|
class PixelStatsOutput(BaseStatsOutput, HistogramPlotMixin):
|
18
19
|
"""
|
19
|
-
Output class for :func
|
20
|
+
Output class for :func:`.pixelstats` stats metric.
|
20
21
|
|
21
22
|
Attributes
|
22
23
|
----------
|
@@ -6,11 +6,12 @@ from dataclasses import dataclass
|
|
6
6
|
from typing import Any, Callable, Iterable
|
7
7
|
|
8
8
|
import numpy as np
|
9
|
-
from numpy.typing import
|
9
|
+
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
11
|
+
from dataeval._output import set_metadata
|
12
|
+
from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
|
13
|
+
from dataeval.typing import ArrayLike
|
14
|
+
from dataeval.utils._image import edge_filter
|
14
15
|
|
15
16
|
QUARTILES = (0, 25, 50, 75, 100)
|
16
17
|
|
@@ -18,7 +19,7 @@ QUARTILES = (0, 25, 50, 75, 100)
|
|
18
19
|
@dataclass(frozen=True)
|
19
20
|
class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
|
20
21
|
"""
|
21
|
-
Output class for :func
|
22
|
+
Output class for :func:`.visualstats` stats metric.
|
22
23
|
|
23
24
|
Attributes
|
24
25
|
----------
|
@@ -53,9 +54,9 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
|
53
54
|
output_class: type = VisualStatsOutput
|
54
55
|
image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
|
55
56
|
"brightness": lambda x: x.get("percentiles")[1],
|
56
|
-
"contrast": lambda x:
|
57
|
-
|
58
|
-
),
|
57
|
+
"contrast": lambda x: 0
|
58
|
+
if np.mean(x.get("percentiles")) == 0
|
59
|
+
else (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles")),
|
59
60
|
"darkness": lambda x: x.get("percentiles")[-2],
|
60
61
|
"missing": lambda x: np.count_nonzero(np.isnan(np.sum(x.image, axis=0))) / np.prod(x.shape[-2:]),
|
61
62
|
"sharpness": lambda x: np.std(edge_filter(np.mean(x.image, axis=0))),
|
dataeval/typing.py
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
"""
|
2
|
+
Common type hints used for interoperability with DataEval.
|
3
|
+
"""
|
4
|
+
|
5
|
+
__all__ = ["Array", "ArrayLike"]
|
6
|
+
|
7
|
+
from typing import Any, Iterator, Protocol, Sequence, TypeVar, Union, runtime_checkable
|
8
|
+
|
9
|
+
|
10
|
+
@runtime_checkable
class Array(Protocol):
    """
    Structural type for array objects that interoperate with DataEval.

    Any object providing ``shape``, ``__array__``, ``__getitem__``,
    ``__iter__`` and ``__len__`` satisfies this protocol. This covers common
    array representations from popular libraries like PyTorch, Tensorflow
    and JAX, as well as NumPy arrays.

    Example
    -------
    >>> import numpy as np
    >>> import torch
    >>> from dataeval.typing import Array

    Create array objects

    >>> ndarray = np.random.random((10, 10))
    >>> tensor = torch.tensor([1, 2, 3])

    Check type at runtime

    >>> isinstance(ndarray, Array)
    True

    >>> isinstance(tensor, Array)
    True
    """

    def __len__(self) -> int: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __getitem__(self, key: Any, /) -> Any: ...
    def __array__(self) -> Any: ...
    @property
    def shape(self) -> tuple[int, ...]: ...
|
44
|
+
|
45
|
+
|
46
|
+
TArray = TypeVar("TArray", bound=Array)
|
47
|
+
|
48
|
+
ArrayLike = Union[Sequence[Any], Array]
|
49
|
+
"""
|
50
|
+
Type alias for array-like objects used for interoperability with DataEval.
|
51
|
+
|
52
|
+
This includes native Python sequences, as well as objects that conform to
|
53
|
+
the `Array` protocol.
|
54
|
+
"""
|
dataeval/utils/__init__.py
CHANGED
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
|
|
4
4
|
DataEval metrics.
|
5
5
|
"""
|
6
6
|
|
7
|
-
__all__ = ["
|
7
|
+
__all__ = ["data", "metadata", "torch"]
|
8
8
|
|
9
|
-
from
|
9
|
+
from . import data, metadata, torch
|
dataeval/utils/_array.py
ADDED
@@ -0,0 +1,169 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import warnings
|
7
|
+
from importlib import import_module
|
8
|
+
from types import ModuleType
|
9
|
+
from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
import torch
|
13
|
+
from numpy.typing import NDArray
|
14
|
+
|
15
|
+
from dataeval._log import LogMessage
|
16
|
+
from dataeval.typing import ArrayLike
|
17
|
+
|
18
|
+
_logger = logging.getLogger(__name__)
|
19
|
+
|
20
|
+
_MODULE_CACHE = {}
|
21
|
+
|
22
|
+
T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
|
23
|
+
_np_dtype = TypeVar("_np_dtype", bound=np.generic)
|
24
|
+
|
25
|
+
|
26
|
+
def _try_import(module_name) -> ModuleType | None:
    """
    Import a module by name, caching the outcome in ``_MODULE_CACHE``.

    Returns the imported module, or None when the import raises ImportError.
    Failed imports are cached as None as well, so each name is attempted once.
    """
    if module_name not in _MODULE_CACHE:
        try:
            resolved = import_module(module_name)
        except ImportError:  # pragma: no cover
            _logger.log(logging.INFO, f"Unable to import {module_name}.")
            resolved = None
        _MODULE_CACHE[module_name] = resolved

    return _MODULE_CACHE[module_name]
|
38
|
+
|
39
|
+
|
40
|
+
def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
    """
    Returns the input as a NumPy array without copying (if possible).

    Thin wrapper around ``to_numpy`` with copying disabled.
    """
    return to_numpy(array, False)
|
43
|
+
|
44
|
+
|
45
|
+
def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
    """
    Converts an ArrayLike to a NumPy array.

    Parameters
    ----------
    array : ArrayLike or None
        Object to convert. None yields an empty array.
    copy : bool, default True
        When True the result is a new array. When False the input's data is
        reused where the underlying conversion allows it.

    Returns
    -------
    NDArray
        NumPy representation of ``array``.
    """
    if array is None:
        # np.ndarray([]) reads the list as a *shape* and returns an
        # uninitialized 0-d array; return a deterministic empty array instead.
        return np.array([])

    if isinstance(array, np.ndarray):
        return array.copy() if copy else array

    source_module = array.__class__.__module__

    if source_module.startswith("tensorflow"):  # pragma: no cover - removed tf from deps
        tf = _try_import("tensorflow")
        if tf and tf.is_tensor(array):
            _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
            return array.numpy().copy() if copy else array.numpy()  # type: ignore

    if source_module.startswith("torch"):
        torch_module = _try_import("torch")
        if torch_module and isinstance(array, torch_module.Tensor):
            _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
            converted = array.detach().cpu().numpy()  # type: ignore
            if copy:
                converted = converted.copy()
            _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(converted)}"))
            return converted

    # Anything else (sequences, other array protocols) goes through NumPy directly.
    return np.array(array) if copy else np.asarray(array)
|
68
|
+
|
69
|
+
|
70
|
+
def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
    """Lazily yields one NumPy array per element of the iterable, converted with ``to_numpy``"""
    yield from (to_numpy(item) for item in iterable)
|
74
|
+
|
75
|
+
|
76
|
+
@overload
def ensure_embeddings(
    embeddings: T,
    dtype: torch.dtype,
    unit_interval: Literal[True, False, "force"] = False,
) -> torch.Tensor: ...


@overload
def ensure_embeddings(
    embeddings: T,
    dtype: type[_np_dtype],
    unit_interval: Literal[True, False, "force"] = False,
) -> NDArray[_np_dtype]: ...


@overload
def ensure_embeddings(
    embeddings: T,
    dtype: None,
    unit_interval: Literal[True, False, "force"] = False,
) -> T: ...


def ensure_embeddings(
    embeddings: T,
    dtype: type[_np_dtype] | torch.dtype | None = None,
    unit_interval: Literal[True, False, "force"] = False,
) -> torch.Tensor | NDArray[_np_dtype] | T:
    """
    Checks that the embeddings are 2D (and optionally unit interval) and converts them

    Parameters
    ----------
    embeddings : ArrayLike
        Embeddings array
    dtype : numpy dtype or torch dtype or None, default None
        Target dtype of the returned array; a torch dtype yields a Tensor and a
        numpy dtype yields an NDArray. With None the input object itself is returned
    unit_interval : bool or "force", default False
        True raises if any value lies outside [0, 1]; "force" warns and min-max
        rescales instead; False skips the check

    Returns
    -------
    Converted embeddings array, or the original ``embeddings`` when ``dtype`` is None

    Raises
    ------
    ValueError
        If the embeddings array is not 2D
    ValueError
        If ``unit_interval`` is True and the embeddings array is not unit interval [0, 1]
    """
    # Build a working array for validation and conversion.
    if isinstance(dtype, torch.dtype):
        arr = torch.as_tensor(embeddings, dtype=dtype)
    elif isinstance(embeddings, torch.Tensor):
        arr = embeddings.detach().cpu().numpy().astype(dtype)
    else:
        arr = np.asarray(embeddings, dtype=dtype)

    if arr.ndim != 2:
        raise ValueError(f"Expected a 2D array, but got a {arr.ndim}D array.")

    if unit_interval:
        arr_min, arr_max = arr.min(), arr.max()
        outside_unit_interval = arr_min < 0 or arr_max > 1
        if outside_unit_interval and unit_interval != "force":
            raise ValueError("Embeddings must be unit interval [0, 1].")
        if outside_unit_interval:
            warnings.warn("Embeddings are not unit interval [0, 1]. Forcing to unit interval.")
            arr = (arr - arr_min) / (arr_max - arr_min)

    # NOTE(review): with dtype=None the untouched input is returned, so a "force"
    # rescale computed above is not reflected in the result - confirm this is intended.
    return embeddings if dtype is None else arr
|
153
|
+
|
154
|
+
|
155
|
+
def flatten(array: ArrayLike) -> NDArray[Any]:
    """
    Collapses every dimension after the first, turning (N, ... ) into (N, -1)

    Parameters
    ----------
    array : ArrayLike, shape - (N, ... )
        Input array

    Returns
    -------
    NDArray, shape - (N, -1)
    """
    samples = as_numpy(array)
    n_samples = samples.shape[0]
    return samples.reshape((n_samples, -1))
|