dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +52 -43
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +198 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.0.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -1,32 +1,31 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataeval.utils.plot import histogram_plot
4
-
5
3
  __all__ = []
6
4
 
7
5
  import re
8
6
  import warnings
7
+ from copy import deepcopy
9
8
  from dataclasses import dataclass
10
9
  from functools import partial
11
10
  from itertools import repeat
12
11
  from multiprocessing import Pool
13
- from typing import Any, Callable, Generic, Iterable, NamedTuple, Optional, TypeVar, Union
12
+ from typing import Any, Callable, Generic, Iterable, Optional, Sequence, Sized, TypeVar, Union
14
13
 
15
14
  import numpy as np
16
15
  import tqdm
17
- from numpy.typing import ArrayLike, NDArray
16
+ from numpy.typing import NDArray
18
17
 
19
- from dataeval.interop import to_numpy_iter
20
- from dataeval.output import Output
21
- from dataeval.utils.image import normalize_image_shape, rescale
18
+ from dataeval._output import Output
19
+ from dataeval.config import get_max_processes
20
+ from dataeval.typing import ArrayLike
21
+ from dataeval.utils._array import to_numpy_iter
22
+ from dataeval.utils._image import normalize_image_shape, rescale
23
+ from dataeval.utils._plot import histogram_plot
22
24
 
23
25
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
24
26
  SOURCE_INDEX = "source_index"
25
27
  BOX_COUNT = "box_count"
26
28
 
27
- # TODO: Replace with global config
28
- DEFAULT_PROCESSES: int | None = None
29
-
30
29
  OptionalRange = Optional[Union[int, Iterable[int]]]
31
30
 
32
31
 
@@ -49,7 +48,8 @@ def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
49
48
  return bounding_box
50
49
 
51
50
 
52
- class SourceIndex(NamedTuple):
51
+ @dataclass
52
+ class SourceIndex:
53
53
  """
54
54
  Attributes
55
55
  ----------
@@ -205,7 +205,8 @@ class StatsProcessor(Generic[TStatsOutput]):
205
205
  return cls.output_class(**output, source_index=source_index, box_count=np.asarray(box_count, dtype=np.uint16))
206
206
 
207
207
 
208
- class StatsProcessorOutput(NamedTuple):
208
+ @dataclass
209
+ class StatsProcessorOutput:
209
210
  results: list[dict[str, Any]]
210
211
  source_indices: list[SourceIndex]
211
212
  box_counts: list[int]
@@ -272,8 +273,6 @@ def run_stats(
272
273
  A flag which determines if the states should be evaluated on a per-channel basis or not.
273
274
  stats_processor_cls : Iterable[type[StatsProcessor]]
274
275
  An iterable of stats processor classes that calculate stats and return output classes.
275
- processes : int | None, default None
276
- Number of processes to use, defaults to None which uses all available CPU cores.
277
276
 
278
277
  Returns
279
278
  -------
@@ -297,11 +296,11 @@ def run_stats(
297
296
  bbox_iter = repeat(None) if bboxes is None else to_numpy_iter(bboxes)
298
297
 
299
298
  warning_list = []
300
- total_for_status = getattr(images, "__len__")() if hasattr(images, "__len__") else None
299
+ total_for_status = len(images) if isinstance(images, Sized) else None
301
300
  stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
302
301
 
303
302
  # TODO: Introduce global controls for CPU job parallelism and GPU configurations
304
- with Pool(processes=DEFAULT_PROCESSES) as p:
303
+ with Pool(processes=get_max_processes()) as p:
305
304
  for r in tqdm.tqdm(
306
305
  p.imap(
307
306
  partial(process_stats_unpack, per_channel=per_channel, stats_processor_cls=stats_processor_cls),
@@ -330,3 +329,40 @@ def run_stats(
330
329
 
331
330
  outputs = [s.convert_output(output, source_index, box_count) for s in stats_processor_cls]
332
331
  return outputs
332
+
333
+
334
def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
    """
    Concatenate two stats outputs of the same type into a new output.

    Parameters
    ----------
    a : TStatsOutput
        First stats output; its values come first in the combined result.
    b : TStatsOutput
        Second stats output; must be exactly the same type as ``a``.

    Returns
    -------
    TStatsOutput
        A new ``type(a)`` instance built from the merged ``dict()`` contents.
        List values are extended, every other value is joined with
        ``np.concatenate``.

    Raises
    ------
    TypeError
        If ``a`` and ``b`` are not the same type.
    """
    if type(a) is not type(b):
        raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")

    # Deep copy so that extending list values below never mutates ``a``.
    sum_dict = deepcopy(a.dict())
    # Build the second mapping once instead of once per key.
    other = b.dict()

    for k, v in sum_dict.items():
        if isinstance(v, list):
            v.extend(other[k])
        else:
            # Rebinding an existing key is safe while iterating ``items()``.
            sum_dict[k] = np.concatenate((v, other[k]))

    return type(a)(**sum_dict)
347
+
348
+
349
def combine_stats(stats: Sequence[TStatsOutput]) -> tuple[TStatsOutput, list[int]]:
    """
    Merge a sequence of stats outputs into one and record dataset boundaries.

    Parameters
    ----------
    stats : Sequence[TStatsOutput]
        Stats outputs to merge, in order.

    Returns
    -------
    tuple[TStatsOutput, list[int]]
        The merged output (the first element itself when only one is given)
        and the running total of ``len()`` after each element.

    Raises
    ------
    TypeError
        If ``stats`` is empty.
    """
    combined = None
    boundaries = []
    running_total = 0

    for current in stats:
        if combined is None:
            combined = current
        else:
            combined = add_stats(combined, current)
        running_total += len(current)
        boundaries.append(running_total)

    if combined is None:
        raise TypeError("Cannot combine empty sequence of stats.")

    return combined, boundaries
360
+
361
+
362
def get_dataset_step_from_idx(idx: int, dataset_steps: list[int]) -> tuple[int, int]:
    """
    Map an index in combined stats back to a dataset and a local index.

    Parameters
    ----------
    idx : int
        Index into the combined output.
    dataset_steps : list[int]
        Step boundaries; each entry is compared against ``idx`` in order.

    Returns
    -------
    tuple[int, int]
        Position of the first step greater than ``idx`` and ``idx`` minus the
        preceding step, or ``(-1, idx)`` when no step is greater than ``idx``.
    """
    previous_boundary = 0
    for position, boundary in enumerate(dataset_steps):
        if idx >= boundary:
            # Not in this dataset; remember where the next one starts.
            previous_boundary = boundary
            continue
        return position, idx - previous_boundary
    return -1, idx
@@ -8,9 +8,9 @@ from typing import Any, Callable, Generic, TypeVar, cast
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
10
10
 
11
- from dataeval.metrics.stats.base import BOX_COUNT, SOURCE_INDEX, BaseStatsOutput
12
- from dataeval.metrics.stats.dimensionstats import DimensionStatsOutput
13
- from dataeval.output import set_metadata
11
+ from dataeval._output import set_metadata
12
+ from dataeval.metrics.stats._base import BOX_COUNT, SOURCE_INDEX, BaseStatsOutput
13
+ from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput
14
14
 
15
15
  TStatOutput = TypeVar("TStatOutput", bound=BaseStatsOutput, contravariant=True)
16
16
  ArraySlice = tuple[int, int]
@@ -50,7 +50,7 @@ RATIOSTATS_OVERRIDE_MAP: dict[type, dict[str, Callable[..., NDArray[Any]]]] = {
50
50
  "depth": lambda x: x.box["depth"],
51
51
  "distance": lambda x: x.box["distance"],
52
52
  }
53
- )
53
+ ),
54
54
  }
55
55
 
56
56
 
@@ -87,11 +87,8 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsO
87
87
  stats = BoxImageStatsOutputSlice(box_stats, (box_i, box_j), img_stats, (img_i, img_j))
88
88
  out_type = type(box_stats)
89
89
  use_override = out_type in RATIOSTATS_OVERRIDE_MAP and key in RATIOSTATS_OVERRIDE_MAP[out_type]
90
- ratio = (
91
- RATIOSTATS_OVERRIDE_MAP[out_type][key](stats)
92
- if use_override
93
- else np.nan_to_num(stats.box[key] / stats.img[key])
94
- )
90
+ with np.errstate(divide="ignore", invalid="ignore"):
91
+ ratio = RATIOSTATS_OVERRIDE_MAP[out_type][key](stats) if use_override else stats.box[key] / stats.img[key]
95
92
  out_stats[box_i:box_j] = ratio.reshape(-1, *out_stats[box_i].shape)
96
93
  return out_stats
97
94
 
@@ -5,24 +5,20 @@ __all__ = []
5
5
  from dataclasses import dataclass
6
6
  from typing import Any, Iterable
7
7
 
8
- from numpy.typing import ArrayLike
9
-
10
- from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, _is_plottable, run_stats
11
- from dataeval.metrics.stats.dimensionstats import (
12
- DimensionStatsOutput,
13
- DimensionStatsProcessor,
14
- )
15
- from dataeval.metrics.stats.labelstats import LabelStatsOutput, labelstats
16
- from dataeval.metrics.stats.pixelstats import PixelStatsOutput, PixelStatsProcessor
17
- from dataeval.metrics.stats.visualstats import VisualStatsOutput, VisualStatsProcessor
18
- from dataeval.output import Output, set_metadata
19
- from dataeval.utils.plot import channel_histogram_plot
8
+ from dataeval._output import Output, set_metadata
9
+ from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, _is_plottable, run_stats
10
+ from dataeval.metrics.stats._dimensionstats import DimensionStatsOutput, DimensionStatsProcessor
11
+ from dataeval.metrics.stats._labelstats import LabelStatsOutput, labelstats
12
+ from dataeval.metrics.stats._pixelstats import PixelStatsOutput, PixelStatsProcessor
13
+ from dataeval.metrics.stats._visualstats import VisualStatsOutput, VisualStatsProcessor
14
+ from dataeval.typing import ArrayLike
15
+ from dataeval.utils._plot import channel_histogram_plot
20
16
 
21
17
 
22
18
  @dataclass(frozen=True)
23
19
  class DatasetStatsOutput(Output, HistogramPlotMixin):
24
20
  """
25
- Output class for :func:`datasetstats` stats metric.
21
+ Output class for :func:`.datasetstats` stats metric.
26
22
 
27
23
  This class represents the outputs of various stats functions against a single
28
24
  dataset, such that each index across all stat outputs are representative of
@@ -82,7 +78,7 @@ def _get_channels(cls, channel_limit: int | None = None, channel_index: int | It
82
78
  @dataclass(frozen=True)
83
79
  class ChannelStatsOutput(Output):
84
80
  """
85
- Output class for :func:`channelstats` stats metric.
81
+ Output class for :func:`.channelstats` stats metric.
86
82
 
87
83
  This class represents the outputs of various per-channel stats functions against
88
84
  a single dataset, such that each index across all stat outputs are representative
@@ -6,17 +6,18 @@ from dataclasses import dataclass
6
6
  from typing import Any, Callable, Iterable
7
7
 
8
8
  import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
10
10
 
11
- from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
12
- from dataeval.output import set_metadata
13
- from dataeval.utils.image import get_bitdepth
11
+ from dataeval._output import set_metadata
12
+ from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
13
+ from dataeval.typing import ArrayLike
14
+ from dataeval.utils._image import get_bitdepth
14
15
 
15
16
 
16
17
  @dataclass(frozen=True)
17
18
  class DimensionStatsOutput(BaseStatsOutput, HistogramPlotMixin):
18
19
  """
19
- Output class for :func:`dimensionstats` stats metric.
20
+ Output class for :func:`.dimensionstats` stats metric.
20
21
 
21
22
  Attributes
22
23
  ----------
@@ -9,14 +9,14 @@ from typing import Callable, Iterable
9
9
 
10
10
  import numpy as np
11
11
  import xxhash as xxh
12
- from numpy.typing import ArrayLike
13
12
  from PIL import Image
14
13
  from scipy.fftpack import dct
15
14
 
16
- from dataeval.interop import as_numpy
17
- from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
18
- from dataeval.output import set_metadata
19
- from dataeval.utils.image import normalize_image_shape, rescale
15
+ from dataeval._output import set_metadata
16
+ from dataeval.metrics.stats._base import BaseStatsOutput, StatsProcessor, run_stats
17
+ from dataeval.typing import ArrayLike
18
+ from dataeval.utils._array import as_numpy
19
+ from dataeval.utils._image import normalize_image_shape, rescale
20
20
 
21
21
  HASH_SIZE = 8
22
22
  MAX_FACTOR = 4
@@ -25,7 +25,7 @@ MAX_FACTOR = 4
25
25
  @dataclass(frozen=True)
26
26
  class HashStatsOutput(BaseStatsOutput):
27
27
  """
28
- Output class for :func:`hashstats` stats metric.
28
+ Output class for :func:`.hashstats` stats metric.
29
29
 
30
30
  Attributes
31
31
  ----------
@@ -2,25 +2,25 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- # import contextlib
5
+ import contextlib
6
6
  from collections import Counter, defaultdict
7
7
  from dataclasses import dataclass
8
8
  from typing import Any, Iterable, Mapping, TypeVar
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import ArrayLike
12
11
 
13
- from dataeval.interop import as_numpy
14
- from dataeval.output import Output, set_metadata
12
+ from dataeval._output import Output, set_metadata
13
+ from dataeval.typing import ArrayLike
14
+ from dataeval.utils._array import as_numpy
15
15
 
16
- # with contextlib.suppress(ImportError):
17
- # import pandas as pd
16
+ with contextlib.suppress(ImportError):
17
+ import pandas as pd
18
18
 
19
19
 
20
20
  @dataclass(frozen=True)
21
21
  class LabelStatsOutput(Output):
22
22
  """
23
- Output class for :func:`labelstats` stats metric.
23
+ Output class for :func:`.labelstats` stats metric.
24
24
 
25
25
  Attributes
26
26
  ----------
@@ -73,24 +73,24 @@ class LabelStatsOutput(Output):
73
73
 
74
74
  return table_str
75
75
 
76
- # def to_dataframe(self) -> pd.DataFrame:
77
- # import pandas as pd
78
-
79
- # class_list = []
80
- # total_count = []
81
- # image_count = []
82
- # for cls in self.label_counts_per_class:
83
- # class_list.append(cls)
84
- # total_count.append(self.label_counts_per_class[cls])
85
- # image_count.append(self.image_counts_per_label[cls])
86
-
87
- # return pd.DataFrame(
88
- # {
89
- # "Label": class_list,
90
- # "Total Count": total_count,
91
- # "Image Count": image_count,
92
- # }
93
- # )
76
def to_dataframe(self) -> pd.DataFrame:
    """
    Build a per-class summary table as a pandas DataFrame.

    Returns
    -------
    pd.DataFrame
        One row per key of ``label_counts_per_class`` with the columns
        "Label", "Total Count" and "Image Count".

    Notes
    -----
    pandas is imported inside the method, so it is only needed when this
    method is called.
    """
    import pandas as pd

    labels = list(self.label_counts_per_class)

    return pd.DataFrame(
        {
            "Label": labels,
            "Total Count": [self.label_counts_per_class[label] for label in labels],
            "Image Count": [self.image_counts_per_label[label] for label in labels],
        }
    )
94
94
 
95
95
 
96
96
  TKey = TypeVar("TKey", int, str)
@@ -6,17 +6,18 @@ from dataclasses import dataclass
6
6
  from typing import Any, Callable, Iterable
7
7
 
8
8
  import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
10
10
  from scipy.stats import entropy, kurtosis, skew
11
11
 
12
- from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
13
- from dataeval.output import set_metadata
12
+ from dataeval._output import set_metadata
13
+ from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
14
+ from dataeval.typing import ArrayLike
14
15
 
15
16
 
16
17
  @dataclass(frozen=True)
17
18
  class PixelStatsOutput(BaseStatsOutput, HistogramPlotMixin):
18
19
  """
19
- Output class for :func:`pixelstats` stats metric.
20
+ Output class for :func:`.pixelstats` stats metric.
20
21
 
21
22
  Attributes
22
23
  ----------
@@ -6,11 +6,12 @@ from dataclasses import dataclass
6
6
  from typing import Any, Callable, Iterable
7
7
 
8
8
  import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
10
10
 
11
- from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
12
- from dataeval.output import set_metadata
13
- from dataeval.utils.image import edge_filter
11
+ from dataeval._output import set_metadata
12
+ from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
13
+ from dataeval.typing import ArrayLike
14
+ from dataeval.utils._image import edge_filter
14
15
 
15
16
  QUARTILES = (0, 25, 50, 75, 100)
16
17
 
@@ -18,7 +19,7 @@ QUARTILES = (0, 25, 50, 75, 100)
18
19
  @dataclass(frozen=True)
19
20
  class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
20
21
  """
21
- Output class for :func:`visualstats` stats metric.
22
+ Output class for :func:`.visualstats` stats metric.
22
23
 
23
24
  Attributes
24
25
  ----------
@@ -53,9 +54,9 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
53
54
  output_class: type = VisualStatsOutput
54
55
  image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
55
56
  "brightness": lambda x: x.get("percentiles")[1],
56
- "contrast": lambda x: np.nan_to_num(
57
- (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles"))
58
- ),
57
+ "contrast": lambda x: 0
58
+ if np.mean(x.get("percentiles")) == 0
59
+ else (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles")),
59
60
  "darkness": lambda x: x.get("percentiles")[-2],
60
61
  "missing": lambda x: np.count_nonzero(np.isnan(np.sum(x.image, axis=0))) / np.prod(x.shape[-2:]),
61
62
  "sharpness": lambda x: np.std(edge_filter(np.mean(x.image, axis=0))),
dataeval/typing.py ADDED
@@ -0,0 +1,54 @@
1
+ """
2
+ Common type hints used for interoperability with DataEval.
3
+ """
4
+
5
+ __all__ = ["Array", "ArrayLike"]
6
+
7
+ from typing import Any, Iterator, Protocol, Sequence, TypeVar, Union, runtime_checkable
8
+
9
+
10
@runtime_checkable
class Array(Protocol):
    """
    Structural type for array objects that interoperate with DataEval.

    Covers common array representations from popular libraries like
    PyTorch, Tensorflow and JAX, as well as NumPy arrays. Because the
    protocol is runtime checkable, ``isinstance`` can be used to test
    whether an object provides the members declared below.

    Example
    -------
    >>> import numpy as np
    >>> import torch
    >>> from dataeval.typing import Array

    Create array objects

    >>> ndarray = np.random.random((10, 10))
    >>> tensor = torch.tensor([1, 2, 3])

    Check type at runtime

    >>> isinstance(ndarray, Array)
    True

    >>> isinstance(tensor, Array)
    True
    """

    # Container behaviour
    def __len__(self) -> int: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __getitem__(self, key: Any, /) -> Any: ...

    # Array-specific members
    def __array__(self) -> Any: ...
    @property
    def shape(self) -> tuple[int, ...]: ...
44
+
45
+
46
+ TArray = TypeVar("TArray", bound=Array)
47
+
48
+ ArrayLike = Union[Sequence[Any], Array]
49
+ """
50
+ Type alias for array-like objects used for interoperability with DataEval.
51
+
52
+ This includes native Python sequences, as well as objects that conform to
53
+ the `Array` protocol.
54
+ """
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
4
4
  DataEval metrics.
5
5
  """
6
6
 
7
- __all__ = ["dataset", "metadata", "torch"]
7
+ __all__ = ["data", "metadata", "torch"]
8
8
 
9
- from dataeval.utils import dataset, metadata, torch
9
+ from . import data, metadata, torch
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import logging
6
+ import warnings
7
+ from importlib import import_module
8
+ from types import ModuleType
9
+ from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
10
+
11
+ import numpy as np
12
+ import torch
13
+ from numpy.typing import NDArray
14
+
15
+ from dataeval._log import LogMessage
16
+ from dataeval.typing import ArrayLike
17
+
18
+ _logger = logging.getLogger(__name__)
19
+
20
+ _MODULE_CACHE = {}
21
+
22
+ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
23
+ _np_dtype = TypeVar("_np_dtype", bound=np.generic)
24
+
25
+
26
def _try_import(module_name) -> ModuleType | None:
    """
    Import a module by name, caching the outcome in ``_MODULE_CACHE``.

    Returns the imported module, or None when the import raises ImportError.
    Failed imports are cached as None, so they are not retried.
    """
    if module_name not in _MODULE_CACHE:
        try:
            _MODULE_CACHE[module_name] = import_module(module_name)
        except ImportError:  # pragma: no cover
            _logger.info(f"Unable to import {module_name}.")
            _MODULE_CACHE[module_name] = None

    return _MODULE_CACHE[module_name]
38
+
39
+
40
def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
    """
    Convert an ArrayLike to a NumPy array, avoiding a copy when possible.

    Delegates to ``to_numpy`` with copying disabled.
    """
    return to_numpy(array, False)
43
+
44
+
45
def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
    """
    Convert an ArrayLike to a NumPy array.

    Parameters
    ----------
    array : ArrayLike or None
        Object to convert. NumPy arrays, Tensorflow tensors and PyTorch
        tensors are handled explicitly; anything else is passed to NumPy.
    copy : bool, default True
        If True, always return a new array. If False, reuse the underlying
        data where the conversion allows it.

    Returns
    -------
    NDArray[Any]
        The converted array, or an empty array when ``array`` is None.
    """
    if array is None:
        # ``np.ndarray([])`` would allocate an uninitialized 0-d array whose
        # single value is arbitrary memory; return a deterministic empty array.
        return np.array([])

    if isinstance(array, np.ndarray):
        return array.copy() if copy else array

    # Dispatch on the defining module so optional frameworks are only
    # imported when an object from them is actually passed in.
    source_module = array.__class__.__module__

    if source_module.startswith("tensorflow"):  # pragma: no cover - removed tf from deps
        tf = _try_import("tensorflow")
        if tf and tf.is_tensor(array):
            _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
            converted = array.numpy()  # type: ignore
            return converted.copy() if copy else converted

    if source_module.startswith("torch"):
        # Named ``torch_module`` so the module-level ``torch`` import is not shadowed.
        torch_module = _try_import("torch")
        if torch_module and isinstance(array, torch_module.Tensor):
            _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
            converted = array.detach().cpu().numpy()  # type: ignore
            if copy:
                converted = converted.copy()
            _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(converted)}"))
            return converted

    return np.array(array) if copy else np.asarray(array)
68
+
69
+
70
def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
    """
    Lazily convert each item of an iterable to a NumPy array.

    Every item is passed through ``to_numpy`` with its default arguments as
    the returned iterator is consumed.
    """
    yield from map(to_numpy, iterable)
74
+
75
+
76
@overload
def ensure_embeddings(
    embeddings: T,
    dtype: torch.dtype,
    unit_interval: Literal[True, False, "force"] = False,
) -> torch.Tensor: ...


@overload
def ensure_embeddings(
    embeddings: T,
    dtype: type[_np_dtype],
    unit_interval: Literal[True, False, "force"] = False,
) -> NDArray[_np_dtype]: ...


@overload
def ensure_embeddings(
    embeddings: T,
    dtype: None,
    unit_interval: Literal[True, False, "force"] = False,
) -> T: ...


def ensure_embeddings(
    embeddings: T,
    dtype: type[_np_dtype] | torch.dtype | None = None,
    unit_interval: Literal[True, False, "force"] = False,
) -> torch.Tensor | NDArray[_np_dtype] | T:
    """
    Validates the embeddings array and converts it to the specified type

    Parameters
    ----------
    embeddings : ArrayLike
        Embeddings array
    dtype : numpy dtype or torch dtype or None, default None
        The desired dtype of the output array, None to skip conversion
    unit_interval : bool or "force", default False
        Whether to validate or force the embeddings to unit interval

    Returns
    -------
    Converted embeddings array. With ``dtype=None`` the input object itself is
    returned unless ``unit_interval="force"`` had to rescale it, in which case
    the rescaled values are returned (as a tensor for tensor input, otherwise
    as a NumPy array).

    Raises
    ------
    ValueError
        If the embeddings array is not 2D
    ValueError
        If the embeddings array is not unit interval [0, 1]
    """
    if isinstance(dtype, torch.dtype):
        arr = torch.as_tensor(embeddings, dtype=dtype)
    elif isinstance(embeddings, torch.Tensor):
        arr = embeddings.detach().cpu().numpy().astype(dtype)
    else:
        arr = np.asarray(embeddings, dtype=dtype)

    if arr.ndim != 2:
        raise ValueError(f"Expected a 2D array, but got a {arr.ndim}D array.")

    rescaled = False
    if unit_interval:
        arr_min, arr_max = arr.min(), arr.max()
        if arr_min < 0 or arr_max > 1:
            if unit_interval != "force":
                raise ValueError("Embeddings must be unit interval [0, 1].")
            warnings.warn("Embeddings are not unit interval [0, 1]. Forcing to unit interval.")
            # NOTE(review): a constant-valued input makes the range zero and
            # this division produce non-finite values - confirm desired handling.
            arr = (arr - arr_min) / (arr_max - arr_min)
            rescaled = True

    if dtype is not None:
        return arr

    if not rescaled:
        # No conversion requested and nothing changed: hand back the input as-is.
        return embeddings

    # The caller asked to force the unit interval, so the rescaled values must
    # be returned; previously they were dropped in favor of the raw input.
    if isinstance(embeddings, torch.Tensor):
        keep_dtype = embeddings.dtype if embeddings.is_floating_point() else None
        return torch.as_tensor(arr, dtype=keep_dtype, device=embeddings.device)
    return arr
153
+
154
+
155
def flatten(array: ArrayLike) -> NDArray[Any]:
    """
    Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension

    Parameters
    ----------
    array : ArrayLike, shape - (N, ... )
        Input array

    Returns
    -------
    NDArray, shape - (N, -1)
    """
    data = as_numpy(array)
    sample_count = data.shape[0]
    return np.reshape(data, (sample_count, -1))