dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. /dataeval/{log.py → _log.py} +0 -0
  111. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
@@ -2,17 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, NamedTuple
5
+ from dataclasses import dataclass
6
+ from typing import Any
6
7
 
7
8
  import numpy as np
8
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
9
10
  from scipy.signal import convolve2d
10
11
 
11
12
  EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
12
13
  BIT_DEPTH = (1, 8, 12, 16, 32)
13
14
 
14
15
 
15
- class BitDepth(NamedTuple):
16
+ @dataclass
17
+ class BitDepth:
16
18
  depth: int
17
19
  pmin: float | int
18
20
  pmax: float | int
@@ -59,7 +61,7 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
59
61
  raise ValueError("Images must have 2 or more dimensions.")
60
62
 
61
63
 
62
- def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[np.uint8]:
64
+ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
63
65
  """
64
66
  Returns the image filtered using a 3x3 edge detection kernel:
65
67
  [[ -1, -1, -1 ],
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable, TypeVar
4
+
5
+ from typing_extensions import ParamSpec
6
+
7
+ P = ParamSpec("P")
8
+ R = TypeVar("R")
9
+
10
+
11
def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
    """
    Look up a callable by name in a mapping of supported methods.

    Parameters
    ----------
    method_map : dict[str, Callable[P, R]]
        Mapping of method names to their implementations.
    method : str
        Name of the method to retrieve.

    Returns
    -------
    Callable[P, R]
        The callable registered under ``method``.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``method_map``.
    """
    try:
        return method_map[method]
    except KeyError:
        # List only the valid names; interpolating the whole mapping buries them in callable reprs
        raise ValueError(f"Specified method {method} is not a valid method: {list(method_map)}.") from None
@@ -2,53 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import sys
6
- from typing import Any, Callable, Literal, TypeVar
5
+ from typing import Any, Literal
7
6
 
8
- import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
7
+ from numpy.typing import NDArray
10
8
  from scipy.sparse import csr_matrix
11
9
  from scipy.sparse.csgraph import minimum_spanning_tree as mst
12
10
  from scipy.spatial.distance import pdist, squareform
13
11
  from sklearn.neighbors import NearestNeighbors
14
12
 
15
- if sys.version_info >= (3, 10):
16
- from typing import ParamSpec
17
- else:
18
- from typing_extensions import ParamSpec
19
-
20
- from dataeval.interop import as_numpy
13
+ from dataeval.utils._array import flatten
21
14
 
22
15
  EPSILON = 1e-5
23
- HASH_SIZE = 8
24
- MAX_FACTOR = 4
25
-
26
-
27
- P = ParamSpec("P")
28
- R = TypeVar("R")
29
-
30
-
31
- def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
32
- if method not in method_map:
33
- raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
34
- return method_map[method]
35
-
36
-
37
- def flatten(array: ArrayLike) -> NDArray[Any]:
38
- """
39
- Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
40
-
41
- Parameters
42
- ----------
43
- X : NDArray, shape - (N, ... )
44
- Input array
45
-
46
- Returns
47
- -------
48
- NDArray, shape - (N, -1)
49
- """
50
- nparr = as_numpy(array)
51
- return nparr.reshape((nparr.shape[0], -1))
52
16
 
53
17
 
54
18
  def minimum_spanning_tree(X: NDArray[Any]) -> Any:
@@ -73,32 +37,6 @@ def minimum_spanning_tree(X: NDArray[Any]) -> Any:
73
37
  return mst(eudist_csr)
74
38
 
75
39
 
76
- def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
77
- """
78
- Returns the classes and counts of from an array of labels
79
-
80
- Parameters
81
- ----------
82
- label : NDArray
83
- Numpy labels array
84
-
85
- Returns
86
- -------
87
- Classes and counts
88
-
89
- Raises
90
- ------
91
- ValueError
92
- If the number of unique classes is less than 2
93
- """
94
- classes, counts = np.unique(labels, return_counts=True)
95
- M = len(classes)
96
- if M < 2:
97
- raise ValueError("Label vector contains less than 2 classes!")
98
- N = int(np.sum(counts))
99
- return M, N
100
-
101
-
102
40
  def compute_neighbors(
103
41
  A: NDArray[Any],
104
42
  B: NDArray[Any],
@@ -6,9 +6,9 @@ import contextlib
6
6
  from typing import Any
7
7
 
8
8
  import numpy as np
9
- from numpy.typing import ArrayLike
10
9
 
11
- from dataeval.interop import to_numpy
10
+ from dataeval.typing import ArrayLike
11
+ from dataeval.utils._array import to_numpy
12
12
 
13
13
  with contextlib.suppress(ImportError):
14
14
  from matplotlib.figure import Figure
@@ -49,8 +49,8 @@ def heatmap(
49
49
  from matplotlib.ticker import FuncFormatter
50
50
 
51
51
  np_data = to_numpy(data)
52
- rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
53
- cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
52
+ rows: list[str] = [str(n) for n in to_numpy(row_labels)]
53
+ cols: list[str] = [str(n) for n in to_numpy(col_labels)]
54
54
 
55
55
  fig, ax = plt.subplots(figsize=(10, 10))
56
56
 
@@ -171,7 +171,7 @@ def histogram_plot(
171
171
  data_dict,
172
172
  ):
173
173
  # Plot the histogram for the chosen metric
174
- ax.hist(data_dict[metric], bins=20, log=log)
174
+ ax.hist(data_dict[metric].astype(np.float64), bins=20, log=log)
175
175
 
176
176
  # Add labels to the histogram
177
177
  ax.set_title(metric)
@@ -229,7 +229,7 @@ def channel_histogram_plot(
229
229
  # Plot the histogram for the chosen metric
230
230
  data = data_dict[metric][ch_mask].reshape(-1, max_channels)
231
231
  ax.hist(
232
- data,
232
+ data.astype(np.float64),
233
233
  bins=20,
234
234
  density=True,
235
235
  log=log,
@@ -0,0 +1,26 @@
1
+ """Provides utility functions for interacting with Computer Vision datasets."""
2
+
3
+ __all__ = [
4
+ "collate",
5
+ "datasets",
6
+ "Embeddings",
7
+ "Images",
8
+ "Metadata",
9
+ "Select",
10
+ "SplitDatasetOutput",
11
+ "Targets",
12
+ "split_dataset",
13
+ "to_image_classification_dataset",
14
+ "to_object_detection_dataset",
15
+ ]
16
+
17
+ from dataeval.outputs._utils import SplitDatasetOutput
18
+ from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
19
+ from dataeval.utils.data._embeddings import Embeddings
20
+ from dataeval.utils.data._images import Images
21
+ from dataeval.utils.data._metadata import Metadata
22
+ from dataeval.utils.data._selection import Select
23
+ from dataeval.utils.data._split import split_dataset
24
+ from dataeval.utils.data._targets import Targets
25
+
26
+ from . import collate, datasets
@@ -0,0 +1,217 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar
6
+
7
+ from dataeval.typing import (
8
+ Array,
9
+ ArrayLike,
10
+ DatasetMetadata,
11
+ ImageClassificationDataset,
12
+ ObjectDetectionDataset,
13
+ )
14
+ from dataeval.utils._array import as_numpy
15
+
16
+
17
+ def _validate_data(
18
+ datum_type: Literal["ic", "od"],
19
+ images: Array | Sequence[Array],
20
+ labels: Sequence[int] | Sequence[Sequence[int]],
21
+ bboxes: Sequence[Sequence[Sequence[float]]] | None,
22
+ metadata: Sequence[dict[str, Any]] | None,
23
+ ) -> None:
24
+ # Validate inputs
25
+ dataset_len = len(images)
26
+
27
+ if not isinstance(images, (Sequence, Array)) or len(images[0].shape) != 3:
28
+ raise ValueError("Images must be a sequence or array of 3 dimensional arrays (H, W, C).")
29
+ if len(labels) != dataset_len:
30
+ raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
31
+ if bboxes is not None and len(bboxes) != dataset_len:
32
+ raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
33
+ if metadata is not None and len(metadata) != dataset_len:
34
+ raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
35
+
36
+ if datum_type == "ic":
37
+ if not isinstance(labels, Sequence) or not isinstance(labels[0], int):
38
+ raise TypeError("Labels must be a sequence of integers for image classification.")
39
+ elif datum_type == "od":
40
+ if not isinstance(labels, Sequence) or not isinstance(labels[0], Sequence) or not isinstance(labels[0][0], int):
41
+ raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
42
+ if (
43
+ bboxes is None
44
+ or not isinstance(bboxes, (Sequence, Array))
45
+ or not isinstance(bboxes[0], (Sequence, Array))
46
+ or not isinstance(bboxes[0][0], (Sequence, Array))
47
+ or not len(bboxes[0][0]) == 4
48
+ ):
49
+ raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
50
+
51
+
52
+ def _find_max(arr: ArrayLike) -> Any:
53
+ if isinstance(arr[0], (Iterable, Sequence, Array)):
54
+ return max([_find_max(x) for x in arr]) # type: ignore
55
+ else:
56
+ return max(arr)
57
+
58
+
59
+ _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
60
+
61
+
62
class BaseAnnotatedDataset(Generic[_TLabels]):
    """
    Common storage for in-memory annotated datasets.

    Holds the images, labels and optional per-image metadata, and derives the
    class-name lookup and dataset identifier exposed through ``metadata``.
    """

    def __init__(
        self,
        datum_type: Literal["ic", "od"],
        images: Array | Sequence[Array],
        labels: _TLabels,
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        if classes is None:
            # Without explicit names, every index up to the largest label is named after itself
            classes = [str(i) for i in range(_find_max(labels) + 1)]
        self._classes = classes
        self._index2label = {index: label for index, label in enumerate(self._classes)}
        self._images = images
        self._labels = labels
        self._metadata = metadata
        if name:
            self._id = name
        else:
            # Fall back to an identifier built from the dataset size, class count and type
            self._id = f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata holding the identifier and the index-to-label mapping."""
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __len__(self) -> int:
        """Number of images in the dataset."""
        return len(self._images)
85
+
86
+
87
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
    """
    In-memory image classification dataset built from user supplied data.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[int]
        One integer class label per image.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata.
    classes : Sequence[str] | None
        Optional class names indexed by label.
    name : str | None, default None
        Optional name applied to the dataset's class.
    """

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[int],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        # NOTE(review): ``name`` is not forwarded to the base initializer, so the
        # metadata id is always the auto-generated one - confirm this is intended.
        super().__init__("ic", images, labels, metadata, classes)
        if name is not None:
            self.__name__ = name
            # Rename a per-instance subclass rather than the shared class; renaming
            # the shared class would also rename every other dataset instance.
            self.__class__ = type(name, (self.__class__,), {"__qualname__": name})

    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
        """Return the image, its one-hot encoded label and its metadata (``{}`` if none was given)."""
        one_hot = [0.0] * len(self._index2label)
        one_hot[self._labels[idx]] = 1.0
        return (
            self._images[idx],
            as_numpy(one_hot),
            self._metadata[idx] if self._metadata is not None else {},
        )
110
+
111
+
112
class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
    """
    In-memory object detection dataset built from user supplied data.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[Sequence[int]]
        One sequence of integer class labels per image.
    bboxes : Sequence[Sequence[Sequence[float]]]
        One sequence of bounding boxes per image.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-image metadata.
    classes : Sequence[str] | None
        Optional class names indexed by label.
    name : str | None, default None
        Optional name applied to the dataset's class.
    """

    class ObjectDetectionTarget:
        """Labels, boxes and scores for one image; every score is fixed at 1.0."""

        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]]) -> None:
            self._labels = labels
            self._bboxes = bboxes
            self._scores = [1.0] * len(labels)

        @property
        def labels(self) -> Sequence[int]:
            return self._labels

        @property
        def boxes(self) -> Sequence[Sequence[float]]:
            return self._bboxes

        @property
        def scores(self) -> Sequence[float]:
            return self._scores

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[Sequence[int]],
        bboxes: Sequence[Sequence[Sequence[float]]],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        # NOTE(review): ``name`` is not forwarded to the base initializer, so the
        # metadata id is always the auto-generated one - confirm this is intended.
        super().__init__("od", images, labels, metadata, classes)
        if name is not None:
            self.__name__ = name
            # Rename a per-instance subclass rather than the shared class; renaming
            # the shared class would also rename every other dataset instance.
            self.__class__ = type(name, (self.__class__,), {"__qualname__": name})
        self._bboxes = bboxes

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata holding the identifier and the index-to-label mapping."""
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
        """Return the image, its detection target and its metadata (``{}`` if none was given)."""
        return (
            self._images[idx],
            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx]),
            self._metadata[idx] if self._metadata is not None else {},
        )
157
+
158
+
159
def to_image_classification_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[int],
    metadata: Sequence[dict[str, Any]] | None = None,
    classes: Sequence[str] | None = None,
    name: str | None = None,
) -> ImageClassificationDataset:
    """
    Helper function to create custom ImageClassificationDataset classes.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[int]
        The labels to use in the dataset.
    metadata : Sequence[dict[str, Any]] | None, default None
        The metadata to use in the dataset.
    classes : Sequence[str] | None, default None
        The classes to use in the dataset.
    name : str | None, default None
        Name given to the class of the returned dataset.

    Returns
    -------
    ImageClassificationDataset

    Raises
    ------
    ValueError
        If the images are malformed or the input lengths do not match.
    TypeError
        If the labels are not a sequence of integers.
    """
    # Fail early with a descriptive error before the dataset is constructed
    _validate_data("ic", images, labels, None, metadata)
    return CustomImageClassificationDataset(images, labels, metadata, classes, name)
186
+
187
+
188
def to_object_detection_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[Sequence[int]],
    bboxes: Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | None = None,
    classes: Sequence[str] | None = None,
    name: str | None = None,
) -> ObjectDetectionDataset:
    """
    Helper function to create custom ObjectDetectionDataset classes.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[Sequence[int]]
        The labels to use in the dataset.
    bboxes : Sequence[Sequence[Sequence[float]]]
        The bounding boxes (x0,y0,x1,y1) to use in the dataset.
    metadata : Sequence[dict[str, Any]] | None, default None
        The metadata to use in the dataset.
    classes : Sequence[str] | None, default None
        The classes to use in the dataset.
    name : str | None, default None
        Name given to the class of the returned dataset.

    Returns
    -------
    ObjectDetectionDataset

    Raises
    ------
    ValueError
        If the images are malformed or the input lengths do not match.
    TypeError
        If the labels or bounding boxes do not have the expected structure.
    """
    # Fail early with a descriptive error before the dataset is constructed
    _validate_data("od", images, labels, bboxes, metadata)
    return CustomObjectDetectionDataset(images, labels, bboxes, metadata, classes, name)
@@ -0,0 +1,104 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import math
6
+ from typing import Any, Iterator, Sequence
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader, Subset
10
+ from tqdm import tqdm
11
+
12
+ from dataeval.config import get_device
13
+ from dataeval.typing import Array, Dataset
14
+ from dataeval.utils.torch.models import SupportsEncode
15
+
16
+
17
class Embeddings:
    """
    Collection of image embeddings from a dataset.

    Embeddings are accessed by index or slice and are only loaded on-demand.

    Parameters
    ----------
    dataset : ImageClassificationDataset or ObjectDetectionDataset
        Dataset to access original images from.
    batch_size : int
        Batch size to use when encoding images.
    indices : Sequence[int] or None, default None
        Indices into the dataset. NOTE(review): the value is stored but
        indexing, iteration and ``len`` operate on the full dataset -
        confirm the intended use.
    model : torch.nn.Module, optional
        Model to use for encoding images. ``torch.nn.Flatten()`` is used
        when no model is given.
    device : torch.device, optional
        Device to use for encoding images.
    verbose : bool, optional
        Whether to print progress bar when encoding images.
    """

    device: torch.device
    batch_size: int
    verbose: bool

    def __init__(
        self,
        dataset: Dataset[tuple[Array, Any, Any]],
        batch_size: int,
        indices: Sequence[int] | None = None,
        model: torch.nn.Module | None = None,
        device: torch.device | str | None = None,
        verbose: bool = False,
    ) -> None:
        self.device = get_device(device)
        self.batch_size = batch_size
        self.verbose = verbose

        self._dataset = dataset
        self._indices = indices if indices is not None else range(len(dataset))
        model = torch.nn.Flatten() if model is None else model
        self._model = model.to(self.device).eval()
        # Prefer a dedicated ``encode`` method when the model provides one
        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
        # Keep only the image from each (image, target, metadata) datum
        self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]

    def to_tensor(self) -> torch.Tensor:
        """
        Converts entire dataset to embeddings.

        Warning
        -------
        Will process the entire dataset in batches and return
        embeddings as a single Tensor in memory.

        Returns
        -------
        torch.Tensor
        """
        return self[:]

    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
        """Yield the embeddings of the requested dataset indices, one batch at a time."""
        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
        batches = (
            tqdm(dataloader, total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
            if self.verbose
            else dataloader
        )
        for images in batches:
            # Reduce overhead cost by not tracking tensor gradients
            with torch.no_grad():
                embeddings = self._encoder(torch.stack(images).to(self.device))
            yield embeddings

    def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
        """
        Return the embeddings selected by ``key``.

        An integer returns the embedding of a single image; a slice or a list
        of integers returns the stacked embeddings of the selected images.

        Raises
        ------
        TypeError
            If ``key`` is not an int, slice or list.
        """
        if isinstance(key, list):
            return torch.vstack(list(self._batch(key))).to(self.device)
        if isinstance(key, slice):
            return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
        if isinstance(key, int):
            # Encode as a batch of one so the result matches the rows returned by
            # slices and iteration, and skip gradient tracking as the batched path does
            with torch.no_grad():
                image = torch.as_tensor(self._dataset[key][0]).to(self.device)
                return self._encoder(image.unsqueeze(0))[0]
        raise TypeError("Invalid argument type.")

    def __iter__(self) -> Iterator[torch.Tensor]:
        # process in batches while yielding individual embeddings
        for batch in self._batch(range(len(self._dataset))):
            yield from batch

    def __len__(self) -> int:
        return len(self._dataset)
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
6
+
7
+ from dataeval.typing import Dataset
8
+
9
T = TypeVar("T")


class Images(Generic[T]):
    """
    Collection of image data from a dataset.

    Images are accessed by index or slice and are only loaded on-demand.

    Parameters
    ----------
    dataset : Dataset[tuple[T, ...]] or Dataset[T]
        Dataset to access images from.
    """

    def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
        # Inspect the first datum only when there is one so empty datasets can be wrapped
        self._is_tuple_datum = len(dataset) > 0 and isinstance(dataset[0], tuple)
        self._dataset = dataset

    def to_list(self) -> Sequence[T]:
        """
        Converts entire dataset to a sequence of images.

        Warning
        -------
        Will load the entire dataset and return the images as a
        single sequence of images in memory.

        Returns
        -------
        list[T]
        """
        return self[:]

    def _image_at(self, index: int) -> T:
        """Load one datum and return its image (the first element for tuple datums)."""
        datum = self._dataset[index]
        return cast(T, datum[0] if self._is_tuple_datum else datum)

    @overload
    def __getitem__(self, key: int, /) -> T: ...
    @overload
    def __getitem__(self, key: slice, /) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
        """
        Return a single image for an integer key or a list of images for a slice.

        Raises
        ------
        TypeError
            If ``key`` is neither an int nor a slice.
        """
        if isinstance(key, slice):
            # Resolve the slice against the dataset length to get concrete indices
            return [self._image_at(k) for k in range(len(self._dataset))[key]]
        if isinstance(key, int):
            return self._image_at(key)
        raise TypeError(f"Key must be integers or slices, not {type(key)}")

    def __iter__(self) -> Iterator[T]:
        for i in range(len(self._dataset)):
            yield self[i]

    def __len__(self) -> int:
        return len(self._dataset)