dataeval 0.82.1__py3-none-any.whl → 0.83.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +10 -0
  3. dataeval/metadata/__init__.py +2 -2
  4. dataeval/metadata/_ood.py +144 -27
  5. dataeval/metrics/bias/_balance.py +3 -3
  6. dataeval/metrics/estimators/_ber.py +2 -1
  7. dataeval/metrics/stats/_base.py +17 -18
  8. dataeval/metrics/stats/_dimensionstats.py +2 -2
  9. dataeval/metrics/stats/_hashstats.py +2 -2
  10. dataeval/metrics/stats/_imagestats.py +4 -4
  11. dataeval/metrics/stats/_pixelstats.py +2 -2
  12. dataeval/metrics/stats/_visualstats.py +2 -2
  13. dataeval/outputs/__init__.py +2 -1
  14. dataeval/outputs/_metadata.py +7 -0
  15. dataeval/typing.py +40 -9
  16. dataeval/utils/_mst.py +1 -2
  17. dataeval/utils/data/_embeddings.py +15 -10
  18. dataeval/utils/data/_selection.py +22 -11
  19. dataeval/utils/data/datasets/_base.py +4 -2
  20. dataeval/utils/data/datasets/_cifar10.py +17 -9
  21. dataeval/utils/data/datasets/_milco.py +18 -12
  22. dataeval/utils/data/datasets/_mnist.py +24 -8
  23. dataeval/utils/data/datasets/_ships.py +18 -8
  24. dataeval/utils/data/datasets/_types.py +1 -5
  25. dataeval/utils/data/datasets/_voc.py +47 -24
  26. dataeval/utils/data/selections/__init__.py +2 -0
  27. dataeval/utils/data/selections/_classfilter.py +1 -1
  28. dataeval/utils/data/selections/_prioritize.py +296 -0
  29. dataeval/utils/data/selections/_shuffle.py +13 -4
  30. dataeval/utils/torch/_gmm.py +3 -2
  31. {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
  32. {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/RECORD +34 -34
  33. dataeval/detectors/ood/metadata_ood_mi.py +0 -91
  34. {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
  35. {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/typing.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- Common type hints used for interoperability with DataEval.
2
+ Common type protocols used for interoperability with DataEval.
3
3
  """
4
4
 
5
5
  __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
16
16
  "SegmentationTarget",
17
17
  "SegmentationDatum",
18
18
  "SegmentationDataset",
19
+ "Transform",
19
20
  ]
20
21
 
21
22
 
@@ -66,6 +67,7 @@ class Array(Protocol):
66
67
  def __len__(self) -> int: ...
67
68
 
68
69
 
70
+ T = TypeVar("T")
69
71
  _T_co = TypeVar("_T_co", covariant=True)
70
72
  _ScalarType = Union[int, float, bool, str]
71
73
  ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
@@ -140,7 +142,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
140
142
 
141
143
  ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
142
144
  """
143
- A type definition for an image classification datum tuple.
145
+ Type alias for an image classification datum tuple.
144
146
 
145
147
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
146
148
  - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
@@ -150,7 +152,7 @@ A type definition for an image classification datum tuple.
150
152
 
151
153
  ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
152
154
  """
153
- A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
155
+ Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
154
156
  """
155
157
 
156
158
  # ========== OBJECT DETECTION DATASETS ==========
@@ -159,7 +161,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificatio
159
161
  @runtime_checkable
160
162
  class ObjectDetectionTarget(Protocol):
161
163
  """
162
- A protocol for targets in an Object Detection dataset.
164
+ Protocol for targets in an Object Detection dataset.
163
165
 
164
166
  Attributes
165
167
  ----------
@@ -180,7 +182,7 @@ class ObjectDetectionTarget(Protocol):
180
182
 
181
183
  ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
182
184
  """
183
- A type definition for an object detection datum tuple.
185
+ Type alias for an object detection datum tuple.
184
186
 
185
187
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
186
188
  - :class:`ObjectDetectionTarget` - Object detection target information for the image.
@@ -190,7 +192,7 @@ A type definition for an object detection datum tuple.
190
192
 
191
193
  ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
192
194
  """
193
- A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
195
+ Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
194
196
  """
195
197
 
196
198
 
@@ -200,7 +202,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDat
200
202
  @runtime_checkable
201
203
  class SegmentationTarget(Protocol):
202
204
  """
203
- A protocol for targets in a Segmentation dataset.
205
+ Protocol for targets in a Segmentation dataset.
204
206
 
205
207
  Attributes
206
208
  ----------
@@ -221,7 +223,7 @@ class SegmentationTarget(Protocol):
221
223
 
222
224
  SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
223
225
  """
224
- A type definition for an image classification datum tuple.
226
+ Type alias for a segmentation datum tuple.
225
227
 
226
228
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
227
229
  - :class:`SegmentationTarget` - Segmentation target information for the image.
@@ -230,5 +232,34 @@ A type definition for an image classification datum tuple.
230
232
 
231
233
  SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
232
234
  """
233
- A type definition for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
235
+ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
234
236
  """
237
+
238
+
239
+ @runtime_checkable
240
+ class Transform(Generic[T], Protocol):
241
+ """
242
+ Protocol defining a transform function.
243
+
244
+ Requires a `__call__` method that returns transformed data.
245
+
246
+ Example
247
+ -------
248
+ >>> from typing import Any
249
+ >>> from numpy.typing import NDArray
250
+
251
+ >>> class MyTransform:
252
+ ... def __init__(self, divisor: float) -> None:
253
+ ... self.divisor = divisor
254
+ ...
255
+ ... def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
256
+ ... return data / self.divisor
257
+
258
+ >>> my_transform = MyTransform(divisor=255.0)
259
+ >>> isinstance(my_transform, Transform)
260
+ True
261
+ >>> my_transform(np.array([1, 2, 3]))
262
+ array([0.004, 0.008, 0.012])
263
+ """
264
+
265
+ def __call__(self, data: T, /) -> T: ...
dataeval/utils/_mst.py CHANGED
@@ -10,10 +10,9 @@ from scipy.sparse.csgraph import minimum_spanning_tree as mst
10
10
  from scipy.spatial.distance import pdist, squareform
11
11
  from sklearn.neighbors import NearestNeighbors
12
12
 
13
+ from dataeval.config import EPSILON
13
14
  from dataeval.utils._array import flatten
14
15
 
15
- EPSILON = 1e-5
16
-
17
16
 
18
17
  def minimum_spanning_tree(X: NDArray[Any]) -> Any:
19
18
  """
dataeval/utils/data/_embeddings.py CHANGED
@@ -57,20 +57,27 @@ class Embeddings:
57
57
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
58
58
  self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
59
59
 
60
- def to_tensor(self) -> torch.Tensor:
60
+ def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
61
61
  """
62
- Converts entire dataset to embeddings.
62
+ Converts dataset to embeddings.
63
63
 
64
- Warning
65
- -------
66
- Will process the entire dataset in batches and return
67
- embeddings as a single Tensor in memory.
64
+ Parameters
65
+ ----------
66
+ indices : Sequence[int] or None, default None
67
+ The indices to convert to embeddings. If None, the entire dataset is converted.
68
68
 
69
69
  Returns
70
70
  -------
71
71
  torch.Tensor
72
+
73
+ Warning
74
+ -------
75
+ Processing large quantities of data can be resource intensive.
72
76
  """
73
- return self[:]
77
+ if indices is not None:
78
+ return torch.vstack(list(self._batch(indices))).to(self.device)
79
+ else:
80
+ return self[:]
74
81
 
75
82
  # Reduce overhead cost by not tracking tensor gradients
76
83
  @torch.no_grad
@@ -85,9 +92,7 @@ class Embeddings:
85
92
  embeddings = self._encoder(torch.stack(images).to(self.device))
86
93
  yield embeddings
87
94
 
88
- def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
89
- if isinstance(key, list):
90
- return torch.vstack(list(self._batch(key))).to(self.device)
95
+ def __getitem__(self, key: int | slice, /) -> torch.Tensor:
91
96
  if isinstance(key, slice):
92
97
  return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
93
98
  elif isinstance(key, int):
dataeval/utils/data/_selection.py CHANGED
@@ -5,9 +5,9 @@ __all__ = []
5
5
  from enum import IntEnum
6
6
  from typing import Generic, Iterator, Sequence, TypeVar
7
7
 
8
- from dataeval.typing import AnnotatedDataset, DatasetMetadata
8
+ from dataeval.typing import AnnotatedDataset, DatasetMetadata, Transform
9
9
 
10
- _TDatum = TypeVar("_TDatum", covariant=True)
10
+ _TDatum = TypeVar("_TDatum")
11
11
 
12
12
 
13
13
  class SelectionStage(IntEnum):
@@ -35,6 +35,8 @@ class Select(AnnotatedDataset[_TDatum]):
35
35
  The dataset to wrap.
36
36
  selections : Selection or list[Selection], optional
37
37
  The selection criteria to apply to the dataset.
38
+ transforms : Transform or list[Transform], optional
39
+ The transforms to apply to the dataset.
38
40
 
39
41
  Examples
40
42
  --------
@@ -67,13 +69,17 @@ class Select(AnnotatedDataset[_TDatum]):
67
69
  def __init__(
68
70
  self,
69
71
  dataset: AnnotatedDataset[_TDatum],
70
- selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
72
+ selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
73
+ transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
71
74
  ) -> None:
72
75
  self.__dict__.update(dataset.__dict__)
73
76
  self._dataset = dataset
74
77
  self._size_limit = len(dataset)
75
78
  self._selection = list(range(self._size_limit))
76
- self._selections = self._sort_selections(selections)
79
+ self._selections = self._sort(selections)
80
+ self._transforms = (
81
+ [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
82
+ )
77
83
 
78
84
  # Ensure metadata is populated correctly as DatasetMetadata TypedDict
79
85
  _metadata = getattr(dataset, "metadata", {})
@@ -81,8 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
81
87
  _metadata["id"] = dataset.__class__.__name__
82
88
  self._metadata = DatasetMetadata(**_metadata)
83
89
 
84
- if self._selections:
85
- self._apply_selections()
90
+ self._select()
86
91
 
87
92
  @property
88
93
  def metadata(self) -> DatasetMetadata:
@@ -92,10 +97,11 @@ class Select(AnnotatedDataset[_TDatum]):
92
97
  nt = "\n "
93
98
  title = f"{self.__class__.__name__} Dataset"
94
99
  sep = "-" * len(title)
95
- selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
96
- return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
100
+ selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
101
+ transforms = f"Transforms: [{', '.join([str(t) for t in self._transforms])}]"
102
+ return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
97
103
 
98
- def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
104
+ def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
99
105
  if not selections:
100
106
  return []
101
107
 
@@ -106,13 +112,18 @@ class Select(AnnotatedDataset[_TDatum]):
106
112
  selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
107
113
  return selection_list
108
114
 
109
- def _apply_selections(self) -> None:
115
+ def _select(self) -> None:
110
116
  for selection in self._selections:
111
117
  selection(self)
112
118
  self._selection = self._selection[: self._size_limit]
113
119
 
120
+ def _transform(self, datum: _TDatum) -> _TDatum:
121
+ for t in self._transforms:
122
+ datum = t(datum)
123
+ return datum
124
+
114
125
  def __getitem__(self, index: int) -> _TDatum:
115
- return self._dataset[self._selection[index]]
126
+ return self._transform(self._dataset[self._selection[index]])
116
127
 
117
128
  def __iter__(self) -> Iterator[_TDatum]:
118
129
  for i in range(len(self)):
dataeval/utils/data/datasets/_base.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  from abc import abstractmethod
6
6
  from pathlib import Path
7
- from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
7
+ from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
8
8
 
9
9
  from dataeval.utils.data.datasets._fileio import _ensure_exists
10
10
  from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
@@ -16,9 +16,11 @@ from dataeval.utils.data.datasets._types import (
16
16
  ObjectDetectionTarget,
17
17
  SegmentationDataset,
18
18
  SegmentationTarget,
19
- Transform,
20
19
  )
21
20
 
21
+ if TYPE_CHECKING:
22
+ from dataeval.typing import Transform
23
+
22
24
  _TArray = TypeVar("_TArray")
23
25
  _TTarget = TypeVar("_TTarget")
24
26
  _TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
dataeval/utils/data/datasets/_cifar10.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from pathlib import Path
6
- from typing import Any, Literal, Sequence, TypeVar
6
+ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -11,7 +11,9 @@ from PIL import Image
11
11
 
12
12
  from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
13
13
  from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
14
- from dataeval.utils.data.datasets._types import Transform
14
+
15
+ if TYPE_CHECKING:
16
+ from dataeval.typing import Transform
15
17
 
16
18
  CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
17
19
  TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
@@ -30,21 +32,27 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
30
32
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
31
33
  image_set : "train", "test" or "base", default "train"
32
34
  If "base", returns all of the data to allow the user to create their own splits.
33
- transforms : Transform | Sequence[Transform] | None, default None
35
+ transforms : Transform, Sequence[Transform] or None, default None
34
36
  Transform(s) to apply to the data.
35
37
  verbose : bool, default False
36
38
  If True, outputs print statements.
37
39
 
38
40
  Attributes
39
41
  ----------
40
- index2label : dict
42
+ path : pathlib.Path
43
+ Location of the folder containing the data.
44
+ image_set : "train", "test" or "base"
45
+ The selected image set from the dataset.
46
+ index2label : dict[int, str]
41
47
  Dictionary which translates from class integers to the associated class strings.
42
- label2index : dict
48
+ label2index : dict[str, int]
43
49
  Dictionary which translates from class strings to the associated class integers.
44
- path : Path
45
- Location of the folder containing the data.
46
- metadata : dict
47
- Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
50
+ metadata : DatasetMetadata
51
+ Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
52
+ transforms : Sequence[Transform]
53
+ The transforms to be applied to the data.
54
+ size : int
55
+ The size of the dataset.
48
56
  """
49
57
 
50
58
  _resources = [
dataeval/utils/data/datasets/_milco.py CHANGED
@@ -1,23 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
4
-
5
3
  __all__ = []
6
4
 
7
5
  from pathlib import Path
8
- from typing import Any, Sequence
6
+ from typing import TYPE_CHECKING, Any, Sequence
9
7
 
10
8
  from numpy.typing import NDArray
11
9
 
12
10
  from dataeval.utils.data.datasets._base import BaseODDataset, DataLocation
13
- from dataeval.utils.data.datasets._types import Transform
11
+ from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
12
+
13
+ if TYPE_CHECKING:
14
+ from dataeval.typing import Transform
14
15
 
15
16
 
16
17
  class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
17
18
  """
18
19
  A side-scan sonar dataset focused on mine (object) detection.
19
20
 
20
-
21
21
  The dataset comes from the paper
22
22
  `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
23
23
  by N.P. Santos et al. (2024).
@@ -43,21 +43,27 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
43
43
  download : bool, default False
44
44
  If True, downloads the dataset from the internet and puts it in root directory.
45
45
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
46
- transforms : Transform | Sequence[Transform] | None, default None
46
+ transforms : Transform, Sequence[Transform] or None, default None
47
47
  Transform(s) to apply to the data.
48
48
  verbose : bool, default False
49
49
  If True, outputs print statements.
50
50
 
51
51
  Attributes
52
52
  ----------
53
- index2label : dict
53
+ path : pathlib.Path
54
+ Location of the folder containing the data.
55
+ image_set : "base"
56
+ The base image set is the only available image set for the MILCO dataset.
57
+ index2label : dict[int, str]
54
58
  Dictionary which translates from class integers to the associated class strings.
55
- label2index : dict
59
+ label2index : dict[str, int]
56
60
  Dictionary which translates from class strings to the associated class integers.
57
- path : Path
58
- Location of the folder containing the data.
59
- metadata : dict
60
- Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
61
+ metadata : DatasetMetadata
62
+ Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
63
+ transforms : Sequence[Transform]
64
+ The transforms to be applied to the data.
65
+ size : int
66
+ The size of the dataset.
61
67
  """
62
68
 
63
69
  _resources = [
dataeval/utils/data/datasets/_mnist.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from pathlib import Path
6
- from typing import Any, Literal, Sequence, TypeVar
6
+ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
10
10
 
11
11
  from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
12
12
  from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
13
- from dataeval.utils.data.datasets._types import Transform
13
+
14
+ if TYPE_CHECKING:
15
+ from dataeval.typing import Transform
14
16
 
15
17
  MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
16
18
  TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
@@ -52,19 +54,33 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
52
54
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
53
55
  image_set : "train", "test" or "base", default "train"
54
56
  If "base", returns all of the data to allow the user to create their own splits.
57
+ corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
58
+ "shear", "scale", "rotate", "brightness", "translate", "stripe", "fog", "spatter", \
59
+ "dotted_line", "zigzag", "canny_edges" or None, default None
60
+ Corruption to apply to the data.
61
+ transforms : Transform, Sequence[Transform] or None, default None
62
+ Transform(s) to apply to the data.
55
63
  verbose : bool, default False
56
64
  If True, outputs print statements.
57
65
 
58
66
  Attributes
59
67
  ----------
60
- index2label : dict
68
+ path : pathlib.Path
69
+ Location of the folder containing the data.
70
+ image_set : "train", "test" or "base"
71
+ The selected image set from the dataset.
72
+ index2label : dict[int, str]
61
73
  Dictionary which translates from class integers to the associated class strings.
62
- label2index : dict
74
+ label2index : dict[str, int]
63
75
  Dictionary which translates from class strings to the associated class integers.
64
- path : Path
65
- Location of the folder containing the data.
66
- metadata : dict
67
- Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
76
+ metadata : DatasetMetadata
77
+ Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
78
+ corruption : str or None
79
+ Corruption applied to the data.
80
+ transforms : Sequence[Transform]
81
+ The transforms to be applied to the data.
82
+ size : int
83
+ The size of the dataset.
68
84
  """
69
85
 
70
86
  _resources = [
dataeval/utils/data/datasets/_ships.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from pathlib import Path
6
- from typing import Any, Sequence
6
+ from typing import TYPE_CHECKING, Any, Sequence
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
10
10
 
11
11
  from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
12
12
  from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
13
- from dataeval.utils.data.datasets._types import Transform
13
+
14
+ if TYPE_CHECKING:
15
+ from dataeval.typing import Transform
14
16
 
15
17
 
16
18
  class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
@@ -32,19 +34,27 @@ class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
32
34
  download : bool, default False
33
35
  If True, downloads the dataset from the internet and puts it in root directory.
34
36
  Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
37
+ transforms : Transform, Sequence[Transform] or None, default None
38
+ Transform(s) to apply to the data.
35
39
  verbose : bool, default False
36
40
  If True, outputs print statements.
37
41
 
38
42
  Attributes
39
43
  ----------
40
- index2label : dict
44
+ path : pathlib.Path
45
+ Location of the folder containing the data.
46
+ image_set : "base"
47
+ The base image set is the only available image set for the Ships dataset.
48
+ index2label : dict[int, str]
41
49
  Dictionary which translates from class integers to the associated class strings.
42
- label2index : dict
50
+ label2index : dict[str, int]
43
51
  Dictionary which translates from class strings to the associated class integers.
44
- path : Path
45
- Location of the folder containing the data.
46
- metadata : dict
47
- Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
52
+ metadata : DatasetMetadata
53
+ Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
54
+ transforms : Sequence[Transform]
55
+ The transforms to be applied to the data.
56
+ size : int
57
+ The size of the dataset.
48
58
  """
49
59
 
50
60
  _resources = [
dataeval/utils/data/datasets/_types.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from dataclasses import dataclass
6
- from typing import Any, Generic, Protocol, TypedDict, TypeVar
6
+ from typing import Any, Generic, TypedDict, TypeVar
7
7
 
8
8
  from torch.utils.data import Dataset
9
9
  from typing_extensions import NotRequired, Required
@@ -46,7 +46,3 @@ class SegmentationTarget(Generic[_TArray]):
46
46
 
47
47
 
48
48
  class SegmentationDataset(AnnotatedDataset[tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]]): ...
49
-
50
-
51
- class Transform(Generic[_TArray], Protocol):
52
- def __call__(self, data: _TArray, /) -> _TArray: ...