dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (78)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/_version.py +2 -2
  4. dataeval/config.py +4 -19
  5. dataeval/data/_embeddings.py +78 -35
  6. dataeval/data/_images.py +41 -8
  7. dataeval/data/_metadata.py +348 -66
  8. dataeval/data/_selection.py +22 -7
  9. dataeval/data/_split.py +3 -2
  10. dataeval/data/selections/_classbalance.py +4 -3
  11. dataeval/data/selections/_classfilter.py +9 -8
  12. dataeval/data/selections/_indices.py +4 -3
  13. dataeval/data/selections/_prioritize.py +249 -29
  14. dataeval/data/selections/_reverse.py +1 -1
  15. dataeval/data/selections/_shuffle.py +5 -4
  16. dataeval/detectors/drift/_base.py +2 -1
  17. dataeval/detectors/drift/_mmd.py +2 -1
  18. dataeval/detectors/drift/_nml/_base.py +1 -1
  19. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  20. dataeval/detectors/drift/_nml/_result.py +3 -2
  21. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  22. dataeval/detectors/drift/_uncertainty.py +2 -1
  23. dataeval/detectors/linters/duplicates.py +2 -1
  24. dataeval/detectors/linters/outliers.py +4 -3
  25. dataeval/detectors/ood/__init__.py +2 -1
  26. dataeval/detectors/ood/ae.py +1 -1
  27. dataeval/detectors/ood/base.py +39 -1
  28. dataeval/detectors/ood/knn.py +95 -0
  29. dataeval/detectors/ood/mixin.py +2 -1
  30. dataeval/metadata/_utils.py +1 -1
  31. dataeval/metrics/bias/_balance.py +29 -22
  32. dataeval/metrics/bias/_diversity.py +4 -4
  33. dataeval/metrics/bias/_parity.py +2 -2
  34. dataeval/metrics/stats/_base.py +3 -29
  35. dataeval/metrics/stats/_boxratiostats.py +2 -1
  36. dataeval/metrics/stats/_dimensionstats.py +2 -1
  37. dataeval/metrics/stats/_hashstats.py +21 -3
  38. dataeval/metrics/stats/_pixelstats.py +2 -1
  39. dataeval/metrics/stats/_visualstats.py +2 -1
  40. dataeval/outputs/_base.py +2 -3
  41. dataeval/outputs/_bias.py +2 -1
  42. dataeval/outputs/_estimators.py +1 -1
  43. dataeval/outputs/_linters.py +3 -3
  44. dataeval/outputs/_stats.py +3 -3
  45. dataeval/outputs/_utils.py +1 -1
  46. dataeval/outputs/_workflows.py +49 -31
  47. dataeval/typing.py +23 -9
  48. dataeval/utils/__init__.py +2 -2
  49. dataeval/utils/_array.py +3 -2
  50. dataeval/utils/_bin.py +9 -7
  51. dataeval/utils/_method.py +2 -3
  52. dataeval/utils/_multiprocessing.py +34 -0
  53. dataeval/utils/_plot.py +2 -1
  54. dataeval/utils/data/__init__.py +6 -5
  55. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  56. dataeval/utils/data/_validate.py +170 -0
  57. dataeval/utils/data/collate.py +2 -1
  58. dataeval/utils/torch/_internal.py +2 -1
  59. dataeval/utils/torch/trainer.py +1 -1
  60. dataeval/workflows/sufficiency.py +13 -9
  61. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
  62. dataeval-0.88.0.dist-info/RECORD +105 -0
  63. dataeval/utils/data/_dataset.py +0 -246
  64. dataeval/utils/datasets/__init__.py +0 -21
  65. dataeval/utils/datasets/_antiuav.py +0 -189
  66. dataeval/utils/datasets/_base.py +0 -266
  67. dataeval/utils/datasets/_cifar10.py +0 -201
  68. dataeval/utils/datasets/_fileio.py +0 -142
  69. dataeval/utils/datasets/_milco.py +0 -197
  70. dataeval/utils/datasets/_mixin.py +0 -54
  71. dataeval/utils/datasets/_mnist.py +0 -202
  72. dataeval/utils/datasets/_seadrone.py +0 -512
  73. dataeval/utils/datasets/_ships.py +0 -144
  74. dataeval/utils/datasets/_types.py +0 -48
  75. dataeval/utils/datasets/_voc.py +0 -583
  76. dataeval-0.86.9.dist-info/RECORD +0 -115
  77. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  78. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
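Based on the file moves above (utils/data/metadata.py → utils/data/_merge.py, the new utils/data/_validate.py, and the removal of the dataeval/utils/datasets package and utils/data/_dataset.py), here is a minimal sketch of how the import surface shifts in 0.88.0. Only the names re-exported in the dataeval/utils/data/__init__.py hunk further below are confirmed; the removed imports are listed as they existed in 0.86.9.

```python
# Confirmed by the dataeval/utils/data/__init__.py hunk later in this diff:
from dataeval.utils.data import collate, flatten, merge, validate_dataset

# Gone in 0.88.0 (files deleted above); these 0.86.9 imports no longer resolve:
#   from dataeval.utils import datasets
#   from dataeval.utils.data import metadata
#   from dataeval.utils.data import to_image_classification_dataset, to_object_detection_dataset
```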
@@ -2,11 +2,11 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Generic, Mapping, Sequence, TypeVar, Union
+from typing import Generic, TypeAlias, TypeVar
 
 import pandas as pd
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
 TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
 
 IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
-OutlierStatsOutput: TypeAlias = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
+OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
 TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
 
 
@@ -2,13 +2,13 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
 
 import numpy as np
 import polars as pl
 from numpy.typing import NDArray
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.utils._plot import channel_histogram_plot, histogram_plot
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
 
-OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
+OptionalRange: TypeAlias = int | Iterable[int] | None
 
 SOURCE_INDEX = "source_index"
 OBJECT_COUNT = "object_count"
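The import churn in these hunks follows one pattern repeated across the release: abstract collection types now come from collections.abc, TypeAlias comes from typing rather than typing_extensions, and Union/Optional are written as PEP 604 `|` unions, all of which assume Python 3.10 or newer. A minimal before/after sketch (the second alias is hypothetical, for illustration only):

```python
# 0.86.9 style
from typing import Mapping, Optional, Union
from typing_extensions import TypeAlias

# 0.88.0 style (Python 3.10+)
from collections.abc import Iterable, Mapping
from typing import TypeAlias

OptionalRange: TypeAlias = int | Iterable[int] | None   # as in the hunk above
MaybeScores: TypeAlias = Mapping[str, float] | None     # hypothetical extra example
```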
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Sequence
 
 import numpy as np
 from numpy.typing import NDArray
@@ -4,8 +4,9 @@ __all__ = []
 
 import contextlib
 import warnings
-from dataclasses import dataclass
-from typing import Any, Iterable, Mapping, Sequence, cast
+from collections.abc import Iterable, Mapping, MutableMapping, Sequence
+from dataclasses import dataclass, field
+from typing import Any, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -76,10 +77,8 @@ def plot_measure(
     ax.set_title(f"{name} Sufficiency")
     ax.set_ylabel(f"{name}")
     ax.set_xlabel("Steps")
-
     # Plot measure over each step
     ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
-
     # Plot extrapolation
     ax.plot(
         projection,
@@ -92,7 +91,7 @@ def plot_measure(
     return fig
 
 
-def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
+def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.int64]:
     """
     Inverse function for f_out()
 
@@ -106,13 +105,27 @@ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
     Returns
     -------
     NDArray
-        Array of sample sizes
+        Sample size or -1 if unachievable for each data point
     """
-    n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
-    return np.asarray(n_i, dtype=np.uint64)
+    with np.errstate(invalid="ignore"):
+        n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
+    unachievable_targets = np.isnan(n_i) | np.any(n_i > np.iinfo(np.int64).max)
+    if any(unachievable_targets):
+        with np.printoptions(suppress=True):
+            warnings.warn(
+                "Number of samples could not be determined for target(s): "
+                f"""{
+                    np.array2string(
+                        1 - y_i[unachievable_targets], separator=", ", formatter={"float": lambda x: f"{x}"}
+                    )
+                }""",
+                UserWarning,
+            )
+        n_i[unachievable_targets] = -1
+    return np.asarray(n_i, dtype=np.int64)
 
 
-def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
+def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.int64]:
     """Inverse function for project_steps()
 
     Parameters
@@ -125,14 +138,13 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
     Returns
     -------
     NDArray
-        Array of sample sizes, or 0 if overflow
+        Samples required or -1 if unachievable for each target value
     """
     steps = f_inv_out(1 - np.array(targets), params)
-    steps[np.isnan(steps)] = 0
     return np.ceil(steps)
 
 
-def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
+def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
     """
     Retrieves the inverse power curve coefficients for the line of best fit.
     Global minimization is done via basin hopping. More info on this algorithm
@@ -178,11 +190,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
 
 
 def get_curve_params(
-    measures: Mapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
-) -> Mapping[str, NDArray[Any]]:
+    averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
+) -> Mapping[str, NDArray[np.float64]]:
     """Calculates and aggregates parameters for both single and multi-class metrics"""
     output = {}
-    for name, measure in measures.items():
+    for name, measure in averaged_measures.items():
         measure = cast(np.ndarray, measure)
         if measure.ndim > 1:
             result = []
@@ -203,19 +215,25 @@ class SufficiencyOutput(Output):
     ----------
     steps : NDArray
         Array of sample sizes
-    measures : Dict[str, NDArray]
-        Average of values observed for each sample size step for each measure
+    measures : dict[str, NDArray]
+        3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
+    averaged_measures : dict[str, NDArray]
+        Average of values for all runs observed for each sample size step for each measure
     n_iter : int, default 1000
         Number of iterations to perform in the basin-hopping curve-fit process
     """
 
     steps: NDArray[np.uint32]
-    measures: Mapping[str, NDArray[np.float64]]
+    measures: Mapping[str, NDArray[Any]]
+    averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
    n_iter: int = 1000
 
     def __post_init__(self) -> None:
+        if len(self.averaged_measures) == 0:
+            for metric, values in self.measures.items():
+                self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
         c = len(self.steps)
-        for m, v in self.measures.items():
+        for m, v in self.averaged_measures.items():
             c_v = v.shape[1] if v.ndim > 1 else len(v)
             if c != c_v:
                 raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
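The SufficiencyOutput change above keeps the raw per-run values in measures and derives averaged_measures in __post_init__ when it is not supplied. A minimal sketch of that derivation with hypothetical numbers (3 runs × 4 sample-size steps):

```python
import numpy as np

# hypothetical metric values: shape (runs, substeps) = (3, 4)
values = np.array([
    [0.60, 0.70, 0.75, 0.78],
    [0.58, 0.69, 0.76, 0.79],
    [0.62, 0.71, 0.74, 0.80],
])

# what __post_init__ computes when averaged_measures is empty
averaged = np.asarray(np.mean(values, axis=0)).T
print(averaged)  # [0.6  0.7  0.75 0.79] -- one averaged value per step

# For per-class metrics of shape (runs, substeps, classes), the same expression
# yields a (classes, substeps) array, which is what plot() and project() iterate over.
```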
@@ -226,7 +244,7 @@ class SufficiencyOutput(Output):
         if self._params is None:
             self._params = {}
         if self.n_iter not in self._params:
-            self._params[self.n_iter] = get_curve_params(self.measures, self.steps, self.n_iter)
+            self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
         return self._params[self.n_iter]
 
     @set_metadata
@@ -259,16 +277,16 @@ class SufficiencyOutput(Output):
             raise ValueError("'projection' must consist of numerical values")
 
         output = {}
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
                 result = []
-                for i in range(len(measures)):
+                for i in range(len(averaged_measures)):
                     projected = project_steps(self.params[name][i], projection)
                     result.append(projected)
                 output[name] = np.array(result)
             else:
                 output[name] = project_steps(self.params[name], projection)
-        proj = SufficiencyOutput(projection, output, self.n_iter)
+        proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
         proj._params = self._params
         return proj
 
@@ -304,11 +322,11 @@ class SufficiencyOutput(Output):
         plots = []
 
         # Create a plot for each measure on one figure
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
-                if class_names is not None and len(measures) != len(class_names):
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
+                if class_names is not None and len(averaged_measures) != len(class_names):
                     raise IndexError("Class name count does not align with measures")
-                for i, measure in enumerate(measures):
+                for i, measure in enumerate(averaged_measures):
                     class_name = str(i) if class_names is None else class_names[i]
                     fig = plot_measure(
                         f"{name}_{class_name}",
@@ -320,7 +338,7 @@ class SufficiencyOutput(Output):
                     plots.append(fig)
 
             else:
-                fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
+                fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
                 plots.append(fig)
 
         return plots
@@ -350,10 +368,10 @@ class SufficiencyOutput(Output):
 
         for name, target in targets.items():
             tarray = as_numpy(target)
-            if name not in self.measures:
+            if name not in self.averaged_measures:
                 continue
 
-            measure = self.measures[name]
+            measure = self.averaged_measures[name]
             if measure.ndim > 1:
                 projection[name] = np.zeros((len(measure), len(tarray)))
                 for i in range(len(measure)):
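For context on the f_inv_out / inv_project_steps changes above: the sufficiency fit models a measure as an inverse power curve, and reading the expression ((y_i - x[2]) / x[0]) ** (-1 / x[1]), the parameters appear to be x = [a, b, c] for error(n) = a * n**(-b) + c. Targets at or below the asymptote c are now reported as -1 rather than overflowing an unsigned integer. A small sketch under that assumed parameter ordering, applying the inverse formula directly to error-rate targets:

```python
import numpy as np

a, b, c = 0.9, 0.4, 0.05          # assumed fit: error(n) = a * n**(-b) + c
x = np.array([a, b, c])
targets = np.array([0.15, 0.04])  # 0.04 sits below the asymptote c, so it is unreachable

with np.errstate(invalid="ignore"):
    n_i = ((targets - x[2]) / x[0]) ** (-1 / x[1])  # inverse of the power curve

n_i[np.isnan(n_i)] = -1           # 0.88.0 behavior: flag unachievable targets with -1
print(np.ceil(n_i))               # [243.  -1.] -> ~243 samples for 0.15, none for 0.04
```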
dataeval/typing.py CHANGED
@@ -3,11 +3,12 @@ Common type protocols used for interoperability with DataEval.
 """
 
 __all__ = [
+    "AnnotatedDataset",
     "Array",
     "ArrayLike",
     "Dataset",
-    "AnnotatedDataset",
     "DatasetMetadata",
+    "DeviceLike",
     "ImageClassificationDatum",
     "ImageClassificationDataset",
     "ObjectDetectionTarget",
@@ -20,18 +21,21 @@ __all__ = [
 ]
 
 
-import sys
-from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, runtime_checkable
+from collections.abc import Iterator, Mapping
+from typing import (
+    Any,
+    Generic,
+    Protocol,
+    TypeAlias,
+    TypedDict,
+    TypeVar,
+    runtime_checkable,
+)
 
 import numpy.typing
+import torch
 from typing_extensions import NotRequired, ReadOnly, Required
 
-if sys.version_info >= (3, 10):
-    from typing import TypeAlias
-else:
-    from typing_extensions import TypeAlias
-
-
 ArrayLike: TypeAlias = numpy.typing.ArrayLike
 """
 Type alias for a `Union` representing objects that can be coerced into an array.
@@ -42,6 +46,16 @@ See Also
 """
 
 
+DeviceLike: TypeAlias = int | str | tuple[str, int] | torch.device
+"""
+Type alias for a `Union` representing types that specify a torch.device.
+
+See Also
+--------
+`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
 DataEval metrics.
 """
 
-__all__ = ["data", "datasets", "torch"]
+__all__ = ["data", "torch"]
 
-from . import data, datasets, torch
+from . import data, torch
dataeval/utils/_array.py CHANGED
@@ -4,9 +4,10 @@ __all__ = []
 
 import logging
 import warnings
+from collections.abc import Iterable, Iterator
 from importlib import import_module
 from types import ModuleType
-from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
+from typing import Any, Literal, TypeVar, overload
 
 import numpy as np
 import torch
@@ -79,7 +80,7 @@ def rescale_array(array: NDArray[_np_dtype]) -> NDArray[_np_dtype]: ...
 def rescale_array(array: torch.Tensor) -> torch.Tensor: ...
 def rescale_array(array: Array | NDArray[_np_dtype] | torch.Tensor) -> Array | NDArray[_np_dtype] | torch.Tensor:
     """Rescale an array to the range [0, 1]"""
-    if isinstance(array, (np.ndarray, torch.Tensor)):
+    if isinstance(array, np.ndarray | torch.Tensor):
         arr_min = array.min()
         arr_max = array.max()
         return (array - arr_min) / (arr_max - arr_min)
dataeval/utils/_bin.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import Any, Iterable
+from collections.abc import Iterable
+from typing import Any
 
 import numpy as np
 from numpy.typing import NDArray
@@ -94,7 +95,7 @@ def bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
     return np.digitize(data, bin_edges)
 
 
-def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
+def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]] | None = None) -> bool:
     """
     Determines whether the data is continuous or discrete using the Wasserstein distance.
 
@@ -113,11 +114,12 @@ def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
     measured from a uniform distribution is greater or less than 0.054, respectively.
     """
     # Check if the metadata is image specific
-    _, data_indices_unsorted = np.unique(data, return_index=True)
-    if data_indices_unsorted.size == image_indices.size:
-        data_indices = np.sort(data_indices_unsorted)
-        if (data_indices == image_indices).all():
-            data = data[data_indices]
+    if image_indices is not None:
+        _, data_indices_unsorted = np.unique(data, return_index=True)
+        if data_indices_unsorted.size == image_indices.size:
+            data_indices = np.sort(data_indices_unsorted)
+            if (data_indices == image_indices).all():
+                data = data[data_indices]
 
     n_examples = len(data)
 
dataeval/utils/_method.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import Callable, TypeVar
-
-from typing_extensions import ParamSpec
+from collections.abc import Callable
+from typing import ParamSpec, TypeVar
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Callable, Iterable, Iterator
+from multiprocessing import Pool
+from typing import Any, TypeVar
+
+_S = TypeVar("_S")
+_T = TypeVar("_T")
+
+
+class PoolWrapper:
+    """
+    Wraps `multiprocessing.Pool` to allow for easy switching between
+    multiprocessing and single-threaded execution.
+
+    This helps with debugging and profiling, as well as usage with Jupyter notebooks
+    in VS Code, which does not support subprocess debugging.
+    """
+
+    def __init__(self, processes: int | None) -> None:
+        self.pool = Pool(processes) if processes is None or processes > 1 else None
+
+    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
+        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
+
+    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if self.pool is not None:
+            self.pool.close()
+            self.pool.join()
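A minimal usage sketch for the new PoolWrapper helper (note it lives in the private dataeval.utils._multiprocessing module, so it is an internal API): with processes=1 it never spawns subprocesses and imap degrades to a plain map, which is what makes it debugger- and notebook-friendly.

```python
from dataeval.utils._multiprocessing import PoolWrapper  # private module; internal API


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    # processes=1 -> no subprocesses; imap is a plain map (easy to step through)
    with PoolWrapper(processes=1) as pool:
        print(list(pool.imap(square, range(5))))  # [0, 1, 4, 9, 16]

    # processes=None or > 1 -> backed by multiprocessing.Pool.imap
    with PoolWrapper(processes=2) as pool:
        print(list(pool.imap(square, range(5))))  # [0, 1, 4, 9, 16]
```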
dataeval/utils/_plot.py CHANGED
@@ -4,7 +4,8 @@ __all__ = []
 
 import contextlib
 import math
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any
 
 import numpy as np
 
@@ -1,11 +1,12 @@
 """Provides access to common Computer Vision datasets."""
 
-from dataeval.utils.data import collate, metadata
-from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
+from dataeval.utils.data import collate
+from dataeval.utils.data._merge import flatten, merge
+from dataeval.utils.data._validate import validate_dataset
 
 __all__ = [
     "collate",
-    "metadata",
-    "to_image_classification_dataset",
-    "to_object_detection_dataset",
+    "flatten",
+    "merge",
+    "validate_dataset",
 ]
@@ -7,8 +7,9 @@ from __future__ import annotations
 __all__ = ["merge", "flatten"]
 
 import warnings
+from collections.abc import Iterable, Mapping, Sequence
 from enum import Enum
-from typing import Any, Iterable, Literal, Mapping, Sequence, overload
+from typing import Any, Literal, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -132,7 +133,7 @@ def _flatten_dict_inner(
         if isinstance(v, dict):
            fd, size = _flatten_dict_inner(v, dropped, new_keys, size=size, nested=nested)
            items.update(fd)
-        elif isinstance(v, (list, tuple)):
+        elif isinstance(v, list | tuple):
            if nested:
                dropped.setdefault(parent_keys + (k,), set()).add(DropReason.NESTED_LIST)
            elif size is not None and size != len(v):
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Sequence, Sized
+from typing import Any, Literal
+
+from dataeval.config import EPSILON
+from dataeval.typing import Array, ObjectDetectionTarget
+from dataeval.utils._array import as_numpy
+
+
+class ValidationMessages:
+    DATASET_SIZED = "Dataset must be sized."
+    DATASET_INDEXABLE = "Dataset must be indexable."
+    DATASET_NONEMPTY = "Dataset must be non-empty."
+    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
+    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
+    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
+    DATUM_TYPE = "Dataset datum must be a tuple."
+    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
+    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
+    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
+    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
+    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
+    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must be have 'boxes', 'labels' and 'scores'."
+    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
+    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xxyy format."
+    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
+    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
+    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
+    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
+
+
+def _validate_dataset_type(dataset: Any) -> list[str]:
+    issues = []
+    is_sized = isinstance(dataset, Sized)
+    is_indexable = hasattr(dataset, "__getitem__")
+    if not is_sized:
+        issues.append(ValidationMessages.DATASET_SIZED)
+    if not is_indexable:
+        issues.append(ValidationMessages.DATASET_INDEXABLE)
+    if is_sized and len(dataset) == 0:
+        issues.append(ValidationMessages.DATASET_NONEMPTY)
+    return issues
+
+
+def _validate_dataset_metadata(dataset: Any) -> list[str]:
+    issues = []
+    if not hasattr(dataset, "metadata"):
+        issues.append(ValidationMessages.DATASET_METADATA)
+    metadata = getattr(dataset, "metadata", None)
+    if not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATASET_METADATA_TYPE)
+    if not isinstance(metadata, dict) or "id" not in metadata:
+        issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
+    return issues
+
+
+def _validate_datum_type(datum: Any) -> list[str]:
+    issues = []
+    if not isinstance(datum, tuple):
+        issues.append(ValidationMessages.DATUM_TYPE)
+    if datum is None or isinstance(datum, Sized) and len(datum) != 3:
+        issues.append(ValidationMessages.DATUM_FORMAT)
+    return issues
+
+
+def _validate_datum_image(image: Any) -> list[str]:
+    issues = []
+    if not isinstance(image, Array) or len(image.shape) != 3:
+        issues.append(ValidationMessages.DATUM_IMAGE_TYPE)
+    if (
+        not isinstance(image, Array)
+        or len(image.shape) == 3
+        and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2])
+    ):
+        issues.append(ValidationMessages.DATUM_IMAGE_FORMAT)
+    return issues
+
+
+def _validate_datum_target_ic(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, Array) or len(target.shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)
+    if target is None or sum(target) > 1 + EPSILON or sum(target) < 1 - EPSILON:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
+    return issues
+
+
+def _validate_datum_target_od(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, ObjectDetectionTarget):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_TYPE)
+    od_target: ObjectDetectionTarget | None = target if isinstance(target, ObjectDetectionTarget) else None
+    if od_target is None or len(as_numpy(od_target.labels).shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
+    if (
+        od_target is None
+        or len(as_numpy(od_target.boxes).shape) != 2
+        or (len(as_numpy(od_target.boxes).shape) == 2 and as_numpy(od_target.boxes).shape[1] != 4)
+    ):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
+    if od_target is None or len(as_numpy(od_target.scores).shape) not in (1, 2):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
+    return issues
+
+
+def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
+    if isinstance(target, Array):
+        return "ic"
+    if isinstance(target, ObjectDetectionTarget):
+        return "od"
+    return "auto"
+
+
+def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
+    issues = []
+    target_type = _detect_target_type(target) if target_type == "auto" else target_type
+    if target_type == "ic":
+        issues.extend(_validate_datum_target_ic(target))
+    elif target_type == "od":
+        issues.extend(_validate_datum_target_od(target))
+    else:
+        issues.append(ValidationMessages.DATUM_TARGET_TYPE)
+    return issues
+
+
+def _validate_datum_metadata(metadata: Any) -> list[str]:
+    issues = []
+    if metadata is None or not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATUM_METADATA_TYPE)
+    if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
+        issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
+    return issues
+
+
+def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
+    """
+    Validate a dataset for compliance with MAITE protocol.
+
+    Parameters
+    ----------
+    dataset: Any
+        Dataset to validate.
+    dataset_type: "ic", "od", or "auto", default "auto"
+        Dataset type, if known.
+
+    Raises
+    ------
+    ValueError
+        Raises exception if dataset is invalid with a list of validation issues.
+    """
+    issues = []
+    issues.extend(_validate_dataset_type(dataset))
+    datum = None if issues else dataset[0]  # type: ignore
+    issues.extend(_validate_dataset_metadata(dataset))
+    issues.extend(_validate_datum_type(datum))
+
+    is_seq = isinstance(datum, Sequence)
+    datum_len = len(datum) if is_seq else 0
+    image = datum[0] if is_seq and datum_len > 0 else None
+    target = datum[1] if is_seq and datum_len > 1 else None
+    metadata = datum[2] if is_seq and datum_len > 2 else None
+    issues.extend(_validate_datum_image(image))
+    issues.extend(_validate_datum_target(target, dataset_type))
+    issues.extend(_validate_datum_metadata(metadata))
+
+    if issues:
+        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
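A sketch of a minimal image-classification dataset that passes validate_dataset above. The class and its values are hypothetical, but each requirement maps to a validator in the new module: sized and indexable, dataset-level metadata with an 'id' key, datum tuples of (image, target, metadata), CHW images, one-hot (or pseudo-probability) targets, and per-datum metadata with an 'id' key.

```python
import numpy as np

from dataeval.utils.data import validate_dataset  # re-exported per the __init__.py hunk above


class TinyDataset:  # hypothetical example dataset
    metadata = {"id": "tiny-ic-dataset"}  # dataset-level metadata must carry an "id"

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> tuple:
        image = np.zeros((3, 16, 16), dtype=np.float32)  # CHW, channels first
        target = np.array([1.0, 0.0])                    # one-hot / pseudo-probabilities
        return image, target, {"id": idx}                # per-datum metadata must carry an "id"


validate_dataset(TinyDataset(), dataset_type="ic")  # raises ValueError listing any issues
```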
@@ -6,7 +6,8 @@ from __future__ import annotations
 
 __all__ = ["list_collate_fn", "numpy_collate_fn", "torch_collate_fn"]
 
-from typing import Any, Iterable, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Any, TypeVar
 
 import numpy as np
 import torch
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import torch