dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/outputs/_utils.py ADDED
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from dataeval.outputs._base import Output
+
+
+ @dataclass(frozen=True)
+ class TrainValSplit:
+     """
+     Dataclass containing train and validation indices.
+
+     Attributes
+     ----------
+     train: NDArray[np.intp]
+         Indices for the training set
+     val: NDArray[np.intp]
+         Indices for the validation set
+     """
+
+     train: NDArray[np.intp]
+     val: NDArray[np.intp]
+
+
+ @dataclass(frozen=True)
+ class SplitDatasetOutput(Output):
+     """
+     Output class containing test indices and a list of TrainValSplits.
+
+     Attributes
+     ----------
+     test: NDArray[np.intp]
+         Indices for the test set
+     folds: list[TrainValSplit]
+         List of train and validation split indices
+     """
+
+     test: NDArray[np.intp]
+     folds: list[TrainValSplit]
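The two dataclasses above are plain containers, so they can be constructed directly. A minimal usage sketch with made-up index values (real instances would normally come from `split_dataset` in `dataeval.utils.data`; import paths are taken from this diff):

import numpy as np

from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit

# Hand-built example showing the shape of the data these classes carry.
split = SplitDatasetOutput(
    test=np.array([8, 9], dtype=np.intp),
    folds=[
        TrainValSplit(train=np.array([0, 1, 2, 3], dtype=np.intp),
                      val=np.array([4, 5], dtype=np.intp)),
        TrainValSplit(train=np.array([2, 3, 4, 5], dtype=np.intp),
                      val=np.array([0, 1], dtype=np.intp)),
    ],
)

for i, fold in enumerate(split.folds):
    print(f"fold {i}: {len(fold.train)} train / {len(fold.val)} val samples")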
dataeval/outputs/_workflows.py ADDED
@@ -0,0 +1,364 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import contextlib
+ import warnings
+ from dataclasses import dataclass
+ from typing import Any, Iterable, Mapping, Sequence, cast
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ with contextlib.suppress(ImportError):
+     from matplotlib.figure import Figure
+
+ from scipy.optimize import basinhopping
+
+ from dataeval.outputs._base import Output, set_metadata
+ from dataeval.typing import ArrayLike
+ from dataeval.utils._array import as_numpy
+
+
+ def f_out(n_i: NDArray[Any], x: NDArray[Any]) -> NDArray[Any]:
+     """
+     Calculates the line of best fit based on its free parameters
+
+     Parameters
+     ----------
+     n_i : NDArray
+         Array of sample sizes
+     x : NDArray
+         Array of inverse power curve coefficients
+
+     Returns
+     -------
+     NDArray
+         Data points for the line of best fit
+     """
+     return x[0] * n_i ** (-x[1]) + x[2]
+
+
+ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any]:
+     """Projects the measures for each value of X
+
+     Parameters
+     ----------
+     params : NDArray
+         Inverse power curve coefficients used to calculate projection
+     projection : NDArray
+         Steps to extrapolate
+
+     Returns
+     -------
+     NDArray
+         Extrapolated measure values at each projection step
+
+     """
+     return 1 - f_out(projection, params)
+
+
+ def plot_measure(
+     name: str,
+     steps: NDArray[Any],
+     measure: NDArray[Any],
+     params: NDArray[Any],
+     projection: NDArray[Any],
+ ) -> Figure:
+     import matplotlib.pyplot
+
+     fig = matplotlib.pyplot.figure()
+     fig = cast(Figure, fig)
+     fig.tight_layout()
+
+     ax = fig.add_subplot(111)
+
+     ax.set_title(f"{name} Sufficiency")
+     ax.set_ylabel(f"{name}")
+     ax.set_xlabel("Steps")
+
+     # Plot measure over each step
+     ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
+
+     # Plot extrapolation
+     ax.plot(
+         projection,
+         project_steps(params, projection),
+         linestyle="dashed",
+         label=f"Potential Model Results ({name})",
+     )
+
+     ax.legend()
+     return fig
+
+
+ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
+     """
+     Inverse function for f_out()
+
+     Parameters
+     ----------
+     y_i : NDArray
+         Data points for the line of best fit
+     x : NDArray
+         Array of inverse power curve coefficients
+
+     Returns
+     -------
+     NDArray
+         Array of sample sizes
+     """
+     n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
+     return np.asarray(n_i, dtype=np.uint64)
+
+
+ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
+     """Inverse function for project_steps()
+
+     Parameters
+     ----------
+     params : NDArray
+         Inverse power curve coefficients used to calculate projection
+     targets : NDArray
+         Desired measure values
+
+     Returns
+     -------
+     NDArray
+         Array of sample sizes, or 0 if overflow
+     """
+     steps = f_inv_out(1 - np.array(targets), params)
+     steps[np.isnan(steps)] = 0
+     return np.ceil(steps)
+
+
+ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
+     """
+     Retrieves the inverse power curve coefficients for the line of best fit.
+     Global minimization is done via basin hopping. More info on this algorithm
+     can be found here: https://arxiv.org/abs/cond-mat/9803344 .
+
+     Parameters
+     ----------
+     p_i : NDArray
+         Array of corresponding losses
+     n_i : NDArray
+         Array of sample sizes
+     niter : int
+         Number of iterations to perform in the basin-hopping
+         numerical process to curve-fit p_i
+
+     Returns
+     -------
+     NDArray
+         Array of parameters to recreate line of best fit
+     """
+
+     def is_valid(f_new, x_new, f_old, x_old):
+         return f_new != np.nan
+
+     def f(x):
+         try:
+             return np.sum(np.square(p_i - f_out(n_i, x)))
+         except RuntimeWarning:
+             return np.nan
+
+     with warnings.catch_warnings():
+         warnings.filterwarnings("error", category=RuntimeWarning)
+         res = basinhopping(
+             f,
+             np.array([0.5, 0.5, 0.1]),
+             niter=niter,
+             stepsize=1.0,
+             minimizer_kwargs={"method": "Powell"},
+             accept_test=is_valid,
+             niter_success=200,
+         )
+     return res.x
+
+
+ def get_curve_params(measures: dict[str, NDArray[Any]], ranges: NDArray[Any], niter: int) -> dict[str, NDArray[Any]]:
+     """Calculates and aggregates parameters for both single and multi-class metrics"""
+     output = {}
+     for name, measure in measures.items():
+         measure = cast(np.ndarray, measure)
+         if measure.ndim > 1:
+             result = []
+             for value in measure:
+                 result.append(calc_params(1 - value, ranges, niter))
+             output[name] = np.array(result)
+         else:
+             output[name] = calc_params(1 - measure, ranges, niter)
+     return output
+
+
+ @dataclass
+ class SufficiencyOutput(Output):
+     """
+     Output class for :class:`.Sufficiency` workflow.
+
+     Attributes
+     ----------
+     steps : NDArray
+         Array of sample sizes
+     measures : Dict[str, NDArray]
+         Average of values observed for each sample size step for each measure
+     n_iter : int, default 1000
+         Number of iterations to perform in the basin-hopping curve-fit process
+     """
+
+     steps: NDArray[np.uint32]
+     measures: dict[str, NDArray[np.float64]]
+     n_iter: int = 1000
+
+     def __post_init__(self) -> None:
+         c = len(self.steps)
+         for m, v in self.measures.items():
+             c_v = v.shape[1] if v.ndim > 1 else len(v)
+             if c != c_v:
+                 raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
+         self._params = None
+
+     @property
+     def params(self) -> dict[str, NDArray[Any]]:
+         if self._params is None:
+             self._params = {}
+         if self.n_iter not in self._params:
+             self._params[self.n_iter] = get_curve_params(self.measures, self.steps, self.n_iter)
+         return self._params[self.n_iter]
+
+     @set_metadata
+     def project(
+         self,
+         projection: int | Iterable[int],
+     ) -> SufficiencyOutput:
+         """
+         Projects the measures for each step.
+
+         Parameters
+         ----------
+         projection : int | Iterable[int]
+             Step or steps to project
+
+         Returns
+         -------
+         SufficiencyOutput
+             Dataclass containing the projected measures per projection
+
+         Raises
+         ------
+         ValueError
+             If the length of data points in the measures do not match
+             If `projection` is not numerical
+         """
+         projection = np.asarray(list(projection) if isinstance(projection, Iterable) else [projection])
+
+         if not np.issubdtype(projection.dtype, np.number):
+             raise ValueError("'projection' must consist of numerical values")
+
+         output = {}
+         for name, measures in self.measures.items():
+             if measures.ndim > 1:
+                 result = []
+                 for i in range(len(measures)):
+                     projected = project_steps(self.params[name][i], projection)
+                     result.append(projected)
+                 output[name] = np.array(result)
+             else:
+                 output[name] = project_steps(self.params[name], projection)
+         proj = SufficiencyOutput(projection, output, self.n_iter)
+         proj._params = self._params
+         return proj
+
+     def plot(self, class_names: Sequence[str] | None = None) -> list[Figure]:
+         """
+         Plotting function for data :term:`sufficiency<Sufficiency>` tasks.
+
+         Parameters
+         ----------
+         class_names : Sequence[str] | None, default None
+             List of class names
+
+         Returns
+         -------
+         list[Figure]
+             List of Figures for each measure
+
+         Raises
+         ------
+         ValueError
+             If the length of data points in the measures do not match
+
+         Notes
+         -----
+         This method requires `matplotlib <https://matplotlib.org/>`_ to be installed.
+         """
+         # Extrapolation parameters
+         last_X = self.steps[-1]
+         geomshape = (0.01 * last_X, last_X * 4, len(self.steps))
+         extrapolated = np.geomspace(*geomshape).astype(np.int64)
+
+         # Stores all plots
+         plots = []
+
+         # Create a plot for each measure on one figure
+         for name, measures in self.measures.items():
+             if measures.ndim > 1:
+                 if class_names is not None and len(measures) != len(class_names):
+                     raise IndexError("Class name count does not align with measures")
+                 for i, measure in enumerate(measures):
+                     class_name = str(i) if class_names is None else class_names[i]
+                     fig = plot_measure(
+                         f"{name}_{class_name}",
+                         self.steps,
+                         measure,
+                         self.params[name][i],
+                         extrapolated,
+                     )
+                     plots.append(fig)
+
+             else:
+                 fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
+                 plots.append(fig)
+
+         return plots
+
+     def inv_project(
+         self, targets: Mapping[str, ArrayLike], n_iter: int | None = None
+     ) -> dict[str, NDArray[np.float64]]:
+         """
+         Calculate training samples needed to achieve target model metric values.
+
+         Parameters
+         ----------
+         targets : Mapping[str, ArrayLike]
+             Mapping of target metric scores (from 0.0 to 1.0) that we want
+             to achieve, where the key is the name of the metric.
+         n_iter : int or None, default None
+             Iterations to use when calculating the inverse power curve; if None, defaults to 1000
+
+         Returns
+         -------
+         dict[str, NDArray]
+             List of the number of training samples needed to achieve each
+             corresponding entry in targets
+         """
+
+         projection = {}
+
+         for name, target in targets.items():
+             tarray = as_numpy(target)
+             if name not in self.measures:
+                 continue
+
+             measure = self.measures[name]
+             if measure.ndim > 1:
+                 projection[name] = np.zeros((len(measure), len(tarray)))
+                 for i in range(len(measure)):
+                     projection[name][i] = inv_project_steps(
+                         self.params[name][i], tarray[i] if tarray.ndim == measure.ndim else tarray
+                     )
+             else:
+                 projection[name] = inv_project_steps(self.params[name], tarray)
+
+         return projection
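For orientation: the module fits each metric's loss (1 - measure) to an inverse power curve f(n) = x0 * n^(-x1) + x2, so projected performance at sample size n is 1 - f(n), and inv_project solves the same curve backwards. A minimal sketch of the public API with made-up numbers (import path taken from this diff; the curve fit runs basin hopping, so this takes a few seconds):

import numpy as np

from dataeval.outputs._workflows import SufficiencyOutput

# Made-up averaged results at four training-set sizes.
steps = np.array([100, 200, 400, 800], dtype=np.uint32)
measures = {"accuracy": np.array([0.62, 0.71, 0.78, 0.83])}

suff = SufficiencyOutput(steps, measures)

# Extrapolate accuracy to larger hypothetical training-set sizes.
projected = suff.project([1600, 3200])
print(projected.measures["accuracy"])

# Estimate sample counts needed to reach 90% and 95% accuracy.
needed = suff.inv_project({"accuracy": [0.90, 0.95]})
print(needed["accuracy"])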
dataeval/typing.py CHANGED
@@ -2,9 +2,32 @@
  Common type hints used for interoperability with DataEval.
  """
 
- __all__ = ["Array", "ArrayLike"]
-
- from typing import Any, Iterator, Protocol, Sequence, TypeVar, Union, runtime_checkable
+ __all__ = [
+     "Array",
+     "ArrayLike",
+     "Dataset",
+     "AnnotatedDataset",
+     "DatasetMetadata",
+     "ImageClassificationDatum",
+     "ImageClassificationDataset",
+     "ObjectDetectionTarget",
+     "ObjectDetectionDatum",
+     "ObjectDetectionDataset",
+     "SegmentationTarget",
+     "SegmentationDatum",
+     "SegmentationDataset",
+ ]
+
+
+ import sys
+ from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
+
+ from typing_extensions import NotRequired, Required
+
+ if sys.version_info >= (3, 10):
+     from typing import TypeAlias
+ else:
+     from typing_extensions import TypeAlias
 
 
  @runtime_checkable
@@ -43,12 +66,169 @@ class Array(Protocol):
      def __len__(self) -> int: ...
 
 
- TArray = TypeVar("TArray", bound=Array)
-
- ArrayLike = Union[Sequence[Any], Array]
+ _T_co = TypeVar("_T_co", covariant=True)
+ _ScalarType = Union[int, float, bool, str]
+ ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
  """
  Type alias for array-like objects used for interoperability with DataEval.
 
  This includes native Python sequences, as well as objects that conform to
- the `Array` protocol.
+ the :class:`Array` protocol.
+ """
+
+
+ class DatasetMetadata(TypedDict, total=False):
+     """
+     Dataset level metadata required for all `AnnotatedDataset` classes.
+
+     Attributes
+     ----------
+     id : Required[str]
+         A unique identifier for the dataset
+     index2label : NotRequired[dict[int, str]]
+         A lookup table converting label value to class name
+     """
+
+     id: Required[str]
+     index2label: NotRequired[dict[int, str]]
+
+
+ @runtime_checkable
+ class Dataset(Generic[_T_co], Protocol):
+     """
+     Protocol for a generic `Dataset`.
+
+     Methods
+     -------
+     __getitem__(index: int)
+         Returns datum at specified index.
+     __len__()
+         Returns dataset length.
+     """
+
+     def __getitem__(self, index: int, /) -> _T_co: ...
+     def __len__(self) -> int: ...
+
+
+ @runtime_checkable
+ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
+     """
+     Protocol for a generic `AnnotatedDataset`.
+
+     Attributes
+     ----------
+     metadata : :class:`.DatasetMetadata` or derivatives.
+
+     Methods
+     -------
+     __getitem__(index: int)
+         Returns datum at specified index.
+     __len__()
+         Returns dataset length.
+
+     Notes
+     -----
+     Inherits from :class:`.Dataset`.
+     """
+
+     @property
+     def metadata(self) -> DatasetMetadata: ...
+
+
+ # ========== IMAGE CLASSIFICATION DATASETS ==========
+
+
+ ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
+ """
+ A type definition for an image classification datum tuple.
+
+ - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+ - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
+ - dict[str, Any] - Datum level metadata.
+ """
+
+
+ ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
+ """
+ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
+ """
+
+ # ========== OBJECT DETECTION DATASETS ==========
+
+
+ @runtime_checkable
+ class ObjectDetectionTarget(Protocol):
+     """
+     A protocol for targets in an Object Detection dataset.
+
+     Attributes
+     ----------
+     boxes : :class:`ArrayLike` of shape (N, 4)
+     labels : :class:`ArrayLike` of shape (N,)
+     scores : :class:`ArrayLike` of shape (N, M)
+     """
+
+     @property
+     def boxes(self) -> ArrayLike: ...
+
+     @property
+     def labels(self) -> ArrayLike: ...
+
+     @property
+     def scores(self) -> ArrayLike: ...
+
+
+ ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
+ """
+ A type definition for an object detection datum tuple.
+
+ - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+ - :class:`ObjectDetectionTarget` - Object detection target information for the image.
+ - dict[str, Any] - Datum level metadata.
+ """
+
+
+ ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
+ """
+ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
+ """
+
+
+ # ========== SEGMENTATION DATASETS ==========
+
+
+ @runtime_checkable
+ class SegmentationTarget(Protocol):
+     """
+     A protocol for targets in a Segmentation dataset.
+
+     Attributes
+     ----------
+     mask : :class:`ArrayLike`
+     labels : :class:`ArrayLike`
+     scores : :class:`ArrayLike`
+     """
+
+     @property
+     def mask(self) -> ArrayLike: ...
+
+     @property
+     def labels(self) -> ArrayLike: ...
+
+     @property
+     def scores(self) -> ArrayLike: ...
+
+
+ SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
+ """
+ A type definition for a segmentation datum tuple.
+
+ - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+ - :class:`SegmentationTarget` - Segmentation target information for the image.
+ - dict[str, Any] - Datum level metadata.
+ """
+
+ SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
+ """
+ A type definition for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
  """
dataeval/utils/_method.py CHANGED
@@ -1,12 +1,8 @@
  from __future__ import annotations
 
- import sys
  from typing import Callable, TypeVar
 
- if sys.version_info >= (3, 10):
-     from typing import ParamSpec
- else:
-     from typing_extensions import ParamSpec
+ from typing_extensions import ParamSpec
 
  P = ParamSpec("P")
  R = TypeVar("R")
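The net effect is that `ParamSpec` now always comes from `typing_extensions` rather than being version-gated. For context, a sketch of the kind of signature-preserving helper `ParamSpec` enables (a hypothetical decorator, not the actual `_method.py` contents):

from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def logged(func: Callable[P, R]) -> Callable[P, R]:
    """Wraps func without losing its parameter types for static checkers."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper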
dataeval/utils/_plot.py CHANGED
@@ -49,8 +49,8 @@ def heatmap(
      from matplotlib.ticker import FuncFormatter
 
      np_data = to_numpy(data)
-     rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
-     cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
+     rows: list[str] = [str(n) for n in to_numpy(row_labels)]
+     cols: list[str] = [str(n) for n in to_numpy(col_labels)]
 
      fig, ax = plt.subplots(figsize=(10, 10))
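The heatmap change normalizes row and column labels to lists of strings instead of passing list inputs through untouched. An illustration of the new behavior, using numpy directly in place of the module's `to_numpy` helper:

import numpy as np

row_labels = ["a", "b", "c"]   # previously passed through unchanged
col_labels = np.arange(3)      # previously left as a numpy array

rows: list[str] = [str(n) for n in np.asarray(row_labels)]
cols: list[str] = [str(n) for n in np.asarray(col_labels)]

assert rows == ["a", "b", "c"]
assert cols == ["0", "1", "2"]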
dataeval/utils/data/__init__.py CHANGED
@@ -10,13 +10,17 @@ __all__ = [
      "SplitDatasetOutput",
      "Targets",
      "split_dataset",
+     "to_image_classification_dataset",
+     "to_object_detection_dataset",
  ]
 
+ from dataeval.outputs._utils import SplitDatasetOutput
+ from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
  from dataeval.utils.data._embeddings import Embeddings
  from dataeval.utils.data._images import Images
  from dataeval.utils.data._metadata import Metadata
  from dataeval.utils.data._selection import Select
- from dataeval.utils.data._split import SplitDatasetOutput, split_dataset
+ from dataeval.utils.data._split import split_dataset
  from dataeval.utils.data._targets import Targets
 
  from . import collate, datasets
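This re-export keeps the public import path stable even though `SplitDatasetOutput` moved out of `dataeval.utils.data._split`. A quick check grounded in the imports above:

# Both paths resolve to the same class after this change.
from dataeval.outputs._utils import SplitDatasetOutput as moved
from dataeval.utils.data import SplitDatasetOutput as public

assert moved is public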