dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/data/_embeddings.py +345 -0
  4. dataeval/{utils/data → data}/_images.py +2 -2
  5. dataeval/{utils/data → data}/_metadata.py +8 -7
  6. dataeval/{utils/data → data}/_selection.py +22 -9
  7. dataeval/{utils/data → data}/_split.py +1 -1
  8. dataeval/data/selections/__init__.py +19 -0
  9. dataeval/data/selections/_classbalance.py +37 -0
  10. dataeval/data/selections/_classfilter.py +109 -0
  11. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  12. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  13. dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
  14. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  15. dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
  16. dataeval/detectors/drift/__init__.py +2 -2
  17. dataeval/detectors/drift/_base.py +55 -203
  18. dataeval/detectors/drift/_cvm.py +19 -30
  19. dataeval/detectors/drift/_ks.py +18 -30
  20. dataeval/detectors/drift/_mmd.py +189 -53
  21. dataeval/detectors/drift/_uncertainty.py +52 -56
  22. dataeval/detectors/drift/updates.py +13 -12
  23. dataeval/detectors/linters/duplicates.py +6 -4
  24. dataeval/detectors/linters/outliers.py +3 -3
  25. dataeval/detectors/ood/ae.py +1 -1
  26. dataeval/metadata/_distance.py +1 -1
  27. dataeval/metadata/_ood.py +4 -4
  28. dataeval/metrics/bias/_balance.py +1 -1
  29. dataeval/metrics/bias/_diversity.py +1 -1
  30. dataeval/metrics/bias/_parity.py +1 -1
  31. dataeval/metrics/stats/_base.py +7 -7
  32. dataeval/metrics/stats/_dimensionstats.py +2 -2
  33. dataeval/metrics/stats/_hashstats.py +2 -2
  34. dataeval/metrics/stats/_imagestats.py +4 -4
  35. dataeval/metrics/stats/_labelstats.py +2 -2
  36. dataeval/metrics/stats/_pixelstats.py +2 -2
  37. dataeval/metrics/stats/_visualstats.py +2 -2
  38. dataeval/outputs/_bias.py +1 -1
  39. dataeval/typing.py +53 -19
  40. dataeval/utils/__init__.py +2 -2
  41. dataeval/utils/_array.py +18 -7
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/_dataset.py +6 -4
  44. dataeval/utils/data/collate.py +2 -0
  45. dataeval/utils/datasets/__init__.py +17 -0
  46. dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
  47. dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
  48. dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
  49. dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
  50. dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
  51. dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
  52. dataeval/utils/torch/_internal.py +12 -35
  53. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
  54. dataeval-1.0.0.dist-info/RECORD +107 -0
  55. dataeval/detectors/drift/_torch.py +0 -222
  56. dataeval/utils/data/_embeddings.py +0 -186
  57. dataeval/utils/data/datasets/__init__.py +0 -17
  58. dataeval/utils/data/selections/__init__.py +0 -17
  59. dataeval/utils/data/selections/_classfilter.py +0 -59
  60. dataeval-0.84.0.dist-info/RECORD +0 -106
  61. /dataeval/{utils/data → data}/_targets.py +0 -0
  62. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  63. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  64. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  65. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  66. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
  67. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
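
The bulk of this release is a package reorganization: the data-curation classes move out of dataeval.utils.data into a new top-level dataeval.data package, and the bundled datasets move from dataeval.utils.data.datasets to dataeval.utils.datasets. A migration sketch based on the file moves above; the exact top-level exports of dataeval.data are not shown in this diff, so the names below are assumptions:

# dataeval 0.84.0
# from dataeval.utils.data import Embeddings, Images, Metadata, Select, Targets, split_dataset
# from dataeval.utils.data import datasets

# dataeval 1.0.0 (names assumed from the moved modules)
from dataeval.data import Embeddings, Images, Metadata, Select, Targets, split_dataset
from dataeval.utils import datasets
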
dataeval/detectors/drift/updates.py CHANGED
@@ -8,11 +8,12 @@ from __future__ import annotations
 __all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
 
 from abc import ABC, abstractmethod
-from typing import Any
 
 import numpy as np
 from numpy.typing import NDArray
 
+from dataeval.utils._array import flatten
+
 
 class BaseUpdateStrategy(ABC):
     """
@@ -28,8 +29,7 @@ class BaseUpdateStrategy(ABC):
         self.n = n
 
     @abstractmethod
-    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
-        """Abstract implementation of update strategy"""
+    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]: ...
 
 
 class LastSeenUpdate(BaseUpdateStrategy):
@@ -42,9 +42,8 @@ class LastSeenUpdate(BaseUpdateStrategy):
     Update with last n instances seen by the detector.
     """
 
-    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
-        x_updated = np.concatenate([x_ref, x], axis=0)
-        return x_updated[-self.n :]
+    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]:
+        return np.concatenate([x_ref, flatten(x_new)], axis=0)[-self.n :]
 
 
 class ReservoirSamplingUpdate(BaseUpdateStrategy):
@@ -57,16 +56,18 @@ class ReservoirSamplingUpdate(BaseUpdateStrategy):
     Update with last n instances seen by the detector.
     """
 
-    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
-        if x.shape[0] + count <= self.n:
-            return np.concatenate([x_ref, x], axis=0)
+    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]:
+        if x_new.shape[0] + count <= self.n:
+            return np.concatenate([x_ref, flatten(x_new)], axis=0)
 
         n_ref = x_ref.shape[0]
-        output_size = min(self.n, n_ref + x.shape[0])
-        shape = (output_size,) + x.shape[1:]
+        output_size = min(self.n, n_ref + x_new.shape[0])
+        shape = (output_size,) + x_new.shape[1:]
+
         x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
         x_reservoir[:n_ref] = x_ref
-        for item in x:
+
+        for item in x_new:
             count += 1
             if n_ref < self.n:
                 x_reservoir[n_ref, :] = item
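
Both update strategies now operate on flattened float32 arrays and flatten the incoming batch themselves via dataeval.utils._array.flatten. A minimal sketch of the new __call__(x_ref, x_new, count) signature; count is assumed to be the number of instances the detector has already seen, and in normal use these strategies are handed to a drift detector rather than called directly:

import numpy as np

from dataeval.detectors.drift.updates import LastSeenUpdate, ReservoirSamplingUpdate

x_ref = np.random.rand(48, 16).astype(np.float32)    # current (already flattened) reference set
x_new = np.random.rand(32, 4, 4).astype(np.float32)  # incoming batch, flattened internally to (32, 16)

x_last = LastSeenUpdate(64)(x_ref, x_new, count=48)            # keeps only the 64 most recently seen rows
x_resv = ReservoirSamplingUpdate(256)(x_ref, x_new, count=48)  # appends while fewer than n=256 instances have been seen

print(x_last.shape, x_resv.shape)  # (64, 16) (80, 16)
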
dataeval/detectors/linters/duplicates.py CHANGED
@@ -4,13 +4,13 @@ __all__ = []
 
 from typing import Any, Sequence, overload
 
+from dataeval.data._images import Images
 from dataeval.metrics.stats import hashstats
 from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
 from dataeval.outputs import DuplicatesOutput, HashStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import DatasetDuplicateGroupMap, DuplicateGroup
-from dataeval.typing import Array, Dataset
-from dataeval.utils.data._images import Images
+from dataeval.typing import ArrayLike, Dataset
 
 
 class Duplicates:
@@ -110,13 +110,15 @@ class Duplicates:
         return DuplicatesOutput(**duplicates)
 
     @set_metadata(state=["only_exact"])
-    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> DuplicatesOutput[DuplicateGroup]:
+    def evaluate(
+        self, data: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
+    ) -> DuplicatesOutput[DuplicateGroup]:
         """
         Returns duplicate image indices for both exact matches and near matches
 
         Parameters
         ----------
-        data : Iterable[Array], shape - (N, C, H, W) | Dataset[tuple[Array, Any, Any]]
+        data : Iterable[ArrayLike], shape - (N, C, H, W) | Dataset[tuple[ArrayLike, Any, Any]]
             A dataset of images in an Array format or the output(s) from a hashstats analysis
 
         Returns
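
Because evaluate now accepts Dataset[ArrayLike], a plain Python sequence of NumPy arrays (which already satisfies the Dataset protocol's __getitem__ and __len__) can be passed without wrapping each image in an Array-protocol type. A hedged sketch; the output field name is an assumption:

import numpy as np

from dataeval.detectors.linters import Duplicates

images = [np.random.rand(3, 16, 16) for _ in range(8)]
images[3] = images[0]                    # plant an exact duplicate

results = Duplicates().evaluate(images)  # a plain list of arrays is a valid Dataset[ArrayLike]
print(results.exact)                     # groups of exact-duplicate indices (field name assumed)
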
dataeval/detectors/linters/outliers.py CHANGED
@@ -7,14 +7,14 @@ from typing import Any, Literal, Sequence, overload
 import numpy as np
 from numpy.typing import NDArray
 
+from dataeval.data._images import Images
 from dataeval.metrics.stats._base import combine_stats, get_dataset_step_from_idx
 from dataeval.metrics.stats._imagestats import imagestats
 from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
 from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
-from dataeval.typing import Array, Dataset
-from dataeval.utils.data._images import Images
+from dataeval.typing import ArrayLike, Dataset
 
 
 def _get_outlier_mask(
@@ -197,7 +197,7 @@ class Outliers:
         return OutliersOutput(output_list)
 
     @set_metadata(state=["use_dimension", "use_pixel", "use_visual", "outlier_method", "outlier_threshold"])
-    def evaluate(self, data: Dataset[Array] | Dataset[tuple[Array, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
+    def evaluate(self, data: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]) -> OutliersOutput[IndexIssueMap]:
         """
         Returns indices of Outliers with the issues identified for each
 
dataeval/detectors/ood/ae.py CHANGED
@@ -81,7 +81,7 @@ class OOD_AE(OODBase):
 
     def _score(self, X: NDArray[np.float32], batch_size: int = int(1e10)) -> OODScoreOutput:
         # reconstruct instances
-        X_recon = predict_batch(X, self.model, batch_size=batch_size)
+        X_recon = predict_batch(X, self.model, batch_size=batch_size).detach().cpu().numpy()
 
         # compute feature and instance level scores
         fscore = np.power(X - X_recon, 2)
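
predict_batch appears to return a torch.Tensor in 1.0.0 (the NumPy-returning drift wrapper in detectors/drift/_torch.py was removed in this release), so the OOD scorer converts explicitly. The conversion is the standard PyTorch pattern; to_numpy below is a hypothetical helper shown only to illustrate it:

import torch

def to_numpy(batch_output: torch.Tensor):
    # Detach from the autograd graph, move to host memory, convert to ndarray.
    return batch_output.detach().cpu().numpy()

print(to_numpy(torch.rand(2, 3)).shape)  # (2, 3)
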
dataeval/metadata/_distance.py CHANGED
@@ -9,11 +9,11 @@ import numpy as np
 from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd
 
+from dataeval.data import Metadata
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
 from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
-from dataeval.utils.data import Metadata
 
 
 class KSType(NamedTuple):
dataeval/metadata/_ood.py CHANGED
@@ -9,10 +9,10 @@ from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif
 
 from dataeval.config import get_seed
+from dataeval.data import Metadata
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
 from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.utils.data import Metadata
 
 
 def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
@@ -201,7 +201,7 @@ def find_most_deviated_factors(
     MostDeviatedFactorsOutput([])
     """
 
-    ood_mask: NDArray[np.bool] = ood.is_ood
+    ood_mask: NDArray[np.bool_] = ood.is_ood
 
     # No metadata correlated with out of distribution data
     if not any(ood_mask):
@@ -303,7 +303,7 @@ def find_ood_predictors(
     OODPredictorOutput({})
     """
 
-    ood_mask: NDArray[np.bool] = ood.is_ood
+    ood_mask: NDArray[np.bool_] = ood.is_ood
 
     discrete_features_count = len(metadata.discrete_factor_names)
     factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
@@ -320,7 +320,7 @@ def find_ood_predictors(
     # Calculate mean, std of each factor over all samples
     scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
 
-    discrete_features = np.zeros_like(factors, dtype=np.bool)
+    discrete_features = np.zeros_like(factors, dtype=np.bool_)
    discrete_features[:discrete_features_count] = True
 
    mutual_info_values = (
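
The np.bool annotations become np.bool_ because NumPy deprecated the np.bool alias in 1.20 and removed it in 1.24, so the old spelling raises AttributeError on current NumPy. For example:

import numpy as np

factors = np.array(["factor_a", "factor_b", "factor_c"])
discrete_features = np.zeros_like(factors, dtype=np.bool_)  # np.bool no longer exists
discrete_features[:2] = True
print(discrete_features)  # [ True  True False]
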
dataeval/metrics/bias/_balance.py CHANGED
@@ -9,10 +9,10 @@ import scipy as sp
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
 from dataeval.config import EPSILON, get_seed
+from dataeval.data import Metadata
 from dataeval.outputs import BalanceOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils._bin import get_counts
-from dataeval.utils.data import Metadata
 
 
 def _validate_num_neighbors(num_neighbors: int) -> int:
dataeval/metrics/bias/_diversity.py CHANGED
@@ -8,11 +8,11 @@ import numpy as np
 import scipy as sp
 from numpy.typing import NDArray
 
+from dataeval.data import Metadata
 from dataeval.outputs import DiversityOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils._bin import get_counts
 from dataeval.utils._method import get_method
-from dataeval.utils.data import Metadata
 
 
 def diversity_shannon(
dataeval/metrics/bias/_parity.py CHANGED
@@ -10,11 +10,11 @@ from numpy.typing import NDArray
 from scipy.stats import chisquare
 from scipy.stats.contingency import chi2_contingency, crosstab
 
+from dataeval.data import Metadata
 from dataeval.outputs import LabelParityOutput, ParityOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import as_numpy
-from dataeval.utils.data import Metadata
 
 
 def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
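
The bias metrics and the metadata distance/OOD helpers all switch to the relocated Metadata class; only the import path changes. A sketch, with the metric entry points assumed to be the public dataeval.metrics.bias functions:

from dataeval.data import Metadata            # dataeval 1.0.0
# from dataeval.utils.data import Metadata    # dataeval 0.84.0

from dataeval.metrics.bias import balance, diversity, parity  # assumed public names

# metadata = Metadata(annotated_dataset)  # construction is unchanged; only the import moved
# balance(metadata); diversity(metadata); parity(metadata)
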
dataeval/metrics/stats/_base.py CHANGED
@@ -10,7 +10,7 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
 
 import numpy as np
 import tqdm
@@ -19,7 +19,7 @@ from numpy.typing import NDArray
 from dataeval.config import get_max_processes
 from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
 from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
-from dataeval.utils._array import to_numpy
+from dataeval.utils._array import as_numpy, to_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
@@ -138,19 +138,19 @@ def process_stats(
 
 
 def process_stats_unpack(
-    args: tuple[int, Array, list[BoundingBox] | None],
+    args: tuple[int, ArrayLike, list[BoundingBox] | None],
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
     return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
 
 
-def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
+def _enumerate(dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool):
     for i in range(len(dataset)):
         d = dataset[i]
         image = d[0] if isinstance(d, tuple) else d
         if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
-            boxes = cast(Array, d[1].boxes)
+            boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
             target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
         else:
             target = None
@@ -159,7 +159,7 @@ def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_bo
 
 
 def run_stats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     per_box: bool,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
@@ -173,7 +173,7 @@ def run_stats(
 
     Parameters
     ----------
-    data : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
+    data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
        A dataset of images and targets to compute statistics on.
    per_box : bool
        A flag which determines if the statistics should be evaluated on a per-box basis or not.
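
The _enumerate change drops the cast and instead coerces ObjectDetectionTarget.boxes to NumPy whenever it is not already an Array, so targets whose fields are plain Python lists are now handled. A sketch of the equivalent coercion, using NumPy stand-ins for dataeval's Array check and as_numpy helper:

from dataclasses import dataclass, field

import numpy as np

@dataclass
class ListTarget:
    # Hypothetical minimal object-detection target whose fields are plain lists.
    boxes: list = field(default_factory=lambda: [[0, 0, 8, 8], [4, 4, 12, 12]])
    labels: list = field(default_factory=lambda: [0, 1])
    scores: list = field(default_factory=lambda: [1.0, 1.0])

target = ListTarget()
boxes = target.boxes if isinstance(target.boxes, np.ndarray) else np.asarray(target.boxes)
print(boxes.shape)  # (2, 4)
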
dataeval/metrics/stats/_dimensionstats.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import DimensionStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import Array, Dataset
+from dataeval.typing import ArrayLike, Dataset
 from dataeval.utils._image import get_bitdepth
 
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
 
 @set_metadata
 def dimensionstats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
 ) -> DimensionStatsOutput:
dataeval/metrics/stats/_hashstats.py CHANGED
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import HashStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import Array, ArrayLike, Dataset
+from dataeval.typing import ArrayLike, Dataset
 from dataeval.utils._array import as_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
 
 @set_metadata
 def hashstats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
 ) -> HashStatsOutput:
dataeval/metrics/stats/_imagestats.py CHANGED
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
 from dataeval.metrics.stats._visualstats import VisualStatsProcessor
 from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import Array, Dataset
+from dataeval.typing import ArrayLike, Dataset
 
 
 @overload
 def imagestats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: Literal[True],
@@ -24,7 +24,7 @@ def imagestats(
 
 @overload
 def imagestats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: Literal[False] = False,
@@ -33,7 +33,7 @@ def imagestats(
 
 @set_metadata
 def imagestats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
dataeval/metrics/stats/_labelstats.py CHANGED
@@ -5,10 +5,10 @@ __all__ = []
 from collections import Counter, defaultdict
 from typing import Any, Mapping, TypeVar
 
+from dataeval.data._metadata import Metadata
 from dataeval.outputs import LabelStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import AnnotatedDataset
-from dataeval.utils.data._metadata import Metadata
 
 TValue = TypeVar("TValue")
 
@@ -38,7 +38,7 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
    --------
    Calculate basic :term:`statistics<Statistics>` on labels for a dataset.
 
-    >>> from dataeval.utils.data import Metadata
+    >>> from dataeval.data import Metadata
    >>> stats = labelstats(Metadata(dataset))
    >>> print(stats.to_table())
    Class Count: 5
dataeval/metrics/stats/_pixelstats.py CHANGED
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import PixelStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import Array, Dataset
+from dataeval.typing import ArrayLike, Dataset
 
 
 class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
 
 @set_metadata
 def pixelstats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
dataeval/metrics/stats/_visualstats.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import VisualStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import Array, Dataset
+from dataeval.typing import ArrayLike, Dataset
 from dataeval.utils._image import edge_filter
 
 QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
 
 @set_metadata
 def visualstats(
-    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
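
Each stats entry point (dimensionstats, hashstats, imagestats, labelstats, pixelstats, visualstats) now takes Dataset[ArrayLike], so a plain sequence of (C, H, W) arrays is a valid input. A hedged sketch; the output field name is an assumption:

import numpy as np

from dataeval.metrics.stats import imagestats

images = [np.random.rand(3, 32, 32) for _ in range(4)]
stats = imagestats(images)
print(stats.mean)  # per-image pixel means (field name assumed)
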
dataeval/outputs/_bias.py CHANGED
@@ -13,11 +13,11 @@ with contextlib.suppress(ImportError):
     import pandas as pd
     from matplotlib.figure import Figure
 
+from dataeval.data._images import Images
 from dataeval.outputs._base import Output
 from dataeval.typing import ArrayLike, Dataset
 from dataeval.utils._array import as_numpy, channels_first_to_last
 from dataeval.utils._plot import heatmap
-from dataeval.utils.data._images import Images
 
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
dataeval/typing.py CHANGED
@@ -21,8 +21,9 @@ __all__ = [
 
 
 import sys
-from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
+from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable
 
+import numpy.typing
 from typing_extensions import NotRequired, ReadOnly, Required
 
 if sys.version_info >= (3, 10):
@@ -31,6 +32,16 @@ else:
     from typing_extensions import TypeAlias
 
 
+ArrayLike: TypeAlias = numpy.typing.ArrayLike
+"""
+Type alias for a `Union` representing objects that can be coerced into an array.
+
+See Also
+--------
+`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
@@ -67,16 +78,8 @@ class Array(Protocol):
     def __len__(self) -> int: ...
 
 
-T = TypeVar("T")
+_T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
-_ScalarType = Union[int, float, bool, str]
-ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
-"""
-Type alias for array-like objects used for interoperability with DataEval.
-
-This includes native Python sequences, as well as objects that conform to
-the :class:`Array` protocol.
-"""
 
 
 class DatasetMetadata(TypedDict, total=False):
@@ -95,6 +98,22 @@ class DatasetMetadata(TypedDict, total=False):
     index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
+class ModelMetadata(TypedDict, total=False):
+    """
+    Model metadata required for all `AnnotatedModel` classes.
+
+    Attributes
+    ----------
+    id : Required[str]
+        A unique identifier for the model
+    index2label : NotRequired[dict[int, str]]
+        A lookup table converting label value to class name
+    """
+
+    id: Required[ReadOnly[str]]
+    index2label: NotRequired[ReadOnly[dict[int, str]]]
+
+
 @runtime_checkable
 class Dataset(Generic[_T_co], Protocol):
     """
@@ -140,12 +159,12 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 # ========== IMAGE CLASSIFICATION DATASETS ==========
 
 
-ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
+ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
 """
 Type alias for an image classification datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
-- :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
 - dict[str, Any] - Datum level metadata.
 """
 
@@ -180,11 +199,11 @@ class ObjectDetectionTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
+ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
 """
 Type alias for an object detection datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`ObjectDetectionTarget` - Object detection target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -221,11 +240,11 @@ class SegmentationTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
+SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
 """
 Type alias for an image classification datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`SegmentationTarget` - Segmentation target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -235,9 +254,24 @@ SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
 Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
 """
 
+# ========== MODEL ==========
+
+
+@runtime_checkable
+class AnnotatedModel(Protocol):
+    """
+    Protocol for an annotated model.
+    """
+
+    @property
+    def metadata(self) -> ModelMetadata: ...
+
+
+# ========== TRANSFORM ==========
+
 
 @runtime_checkable
-class Transform(Generic[T], Protocol):
+class Transform(Generic[_T], Protocol):
     """
     Protocol defining a transform function.
 
@@ -262,4 +296,4 @@ class Transform(Generic[T], Protocol):
     array([0.004, 0.008, 0.012])
     """
 
-    def __call__(self, data: T, /) -> T: ...
+    def __call__(self, data: _T, /) -> _T: ...
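
Two things to note from the typing changes: ArrayLike is now simply numpy.typing.ArrayLike, so lists, scalars, and ndarrays all satisfy it, and the new ModelMetadata/AnnotatedModel pair mirrors DatasetMetadata/AnnotatedDataset for models. A sketch of a class satisfying the new protocol; the runtime check only verifies that a metadata attribute exists:

from dataeval.typing import AnnotatedModel, ModelMetadata

class WrappedModel:
    # Toy stand-in for a model object; only the metadata attribute matters here.
    metadata: ModelMetadata = {"id": "my-classifier-v1", "index2label": {0: "cat", 1: "dog"}}

print(isinstance(WrappedModel(), AnnotatedModel))  # True
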
dataeval/utils/__init__.py CHANGED
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
 DataEval metrics.
 """
 
-__all__ = ["data", "metadata", "torch"]
+__all__ = ["data", "datasets", "torch"]
 
-from . import data, metadata, torch
+from . import data, datasets, torch
dataeval/utils/_array.py CHANGED
@@ -92,7 +92,7 @@ def ensure_embeddings(
 @overload
 def ensure_embeddings(
     embeddings: T,
-    dtype: None,
+    dtype: None = None,
     unit_interval: Literal[True, False, "force"] = False,
 ) -> T: ...
 
@@ -152,21 +152,32 @@ def ensure_embeddings(
     return arr
 
 
-def flatten(array: ArrayLike) -> NDArray[Any]:
+@overload
+def flatten(array: torch.Tensor) -> torch.Tensor: ...
+@overload
+def flatten(array: ArrayLike) -> NDArray[Any]: ...
+
+
+def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
     """
     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
 
     Parameters
     ----------
-    X : NDArray, shape - (N, ... )
+    array : ArrayLike
         Input array
 
     Returns
     -------
-    NDArray, shape - (N, -1)
+    np.ndarray or torch.Tensor, shape: (N, -1)
     """
-    nparr = as_numpy(array)
-    return nparr.reshape((nparr.shape[0], -1))
+    if isinstance(array, np.ndarray):
+        nparr = as_numpy(array)
+        return nparr.reshape((nparr.shape[0], -1))
+    elif isinstance(array, torch.Tensor):
+        return torch.flatten(array, start_dim=1)
+    else:
+        raise TypeError(f"Unsupported array type {type(array)}.")
 
 
 _TArray = TypeVar("_TArray", bound=Array)
@@ -191,4 +202,4 @@ def channels_first_to_last(array: _TArray) -> _TArray:
     elif isinstance(array, torch.Tensor):
         return torch.permute(array, (1, 2, 0))
     else:
-        raise TypeError(f"Unsupported array type {type(array)} for conversion.")
+        raise TypeError(f"Unsupported array type {type(array)}.")
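
flatten is now overloaded so tensors stay tensors and NumPy arrays stay NumPy, which is what lets the drift update strategies above work on either backend. A quick check against the private helper (shown for illustration only):

import numpy as np
import torch

from dataeval.utils._array import flatten

print(flatten(np.zeros((8, 3, 2, 2))).shape)   # (8, 12) -- ndarray in, ndarray out
print(flatten(torch.zeros(8, 3, 2, 2)).shape)  # torch.Size([8, 12]) -- tensors are flattened from dim 1
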
dataeval/utils/data/__init__.py CHANGED
@@ -1,26 +1,11 @@
-"""Provides utility functions for interacting with Computer Vision datasets."""
+"""Provides access to common Computer Vision datasets."""
+
+from dataeval.utils.data import collate, metadata
+from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
 
 __all__ = [
     "collate",
-    "datasets",
-    "Embeddings",
-    "Images",
-    "Metadata",
-    "Select",
-    "SplitDatasetOutput",
-    "Targets",
-    "split_dataset",
+    "metadata",
     "to_image_classification_dataset",
     "to_object_detection_dataset",
 ]
-
-from dataeval.outputs._utils import SplitDatasetOutput
-from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
-from dataeval.utils.data._embeddings import Embeddings
-from dataeval.utils.data._images import Images
-from dataeval.utils.data._metadata import Metadata
-from dataeval.utils.data._selection import Select
-from dataeval.utils.data._split import split_dataset
-from dataeval.utils.data._targets import Targets
-
-from . import collate, datasets
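
After the reorganization, dataeval.utils.data keeps only the collation and dataset-construction helpers; per the new __all__, these are the imports that remain valid:

from dataeval.utils.data import (
    collate,
    metadata,
    to_image_classification_dataset,
    to_object_detection_dataset,
)
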