dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,189 @@
1
+ # Adapted from fast_hdbscan python module
2
+ # Original Authors: Leland McInnes <https://github.com/TutteInstitute/fast_hdbscan>
3
+ # Adapted for DataEval by Ryan Wood
4
+ # License: BSD 2-Clause
5
+
6
+ __all__ = []
7
+
8
+ import warnings
9
+
10
+ import numba
11
+ import numpy as np
12
+ from sklearn.neighbors import NearestNeighbors
13
+
14
+ with warnings.catch_warnings():
15
+ warnings.simplefilter("ignore", category=FutureWarning)
16
+ from fast_hdbscan.disjoint_set import ds_find, ds_rank_create
17
+
18
+
19
@numba.njit()
def _ds_union_by_rank(disjoint_set, point, nbr):
    """
    Merge the sets containing ``point`` and ``nbr`` using union by rank.

    Returns
    -------
    int
        0 when both elements already share a root (nothing is changed),
        1 when two distinct sets were merged.
    """
    point_root = ds_find(disjoint_set, point)
    nbr_root = ds_find(disjoint_set, nbr)

    # Already in the same set - no edge should be added by the caller
    if nbr_root == point_root:
        return 0

    # The root with the larger rank survives; ties keep the neighbor's root
    if disjoint_set.rank[nbr_root] < disjoint_set.rank[point_root]:
        winner, loser = point_root, nbr_root
    else:
        winner, loser = nbr_root, point_root

    disjoint_set.parent[loser] = winner
    # Equal ranks means the merged tree grew one level deeper
    if disjoint_set.rank[winner] == disjoint_set.rank[loser]:
        disjoint_set.rank[winner] += 1
    return 1
34
+
35
+
36
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
def _init_tree(n_neighbors, n_distance):
    """
    Start the edge list by linking each point to its entry in ``n_neighbors``.

    Parameters
    ----------
    n_neighbors : array
        Neighbor index for each point (indexed by point).
    n_distance : array
        Distance from each point to that neighbor.

    Returns
    -------
    tuple
        ``(tree, n_edges, disjoint_set, cluster_points)`` where ``tree`` holds
        ``(point, neighbor, distance)`` rows as float32, ``n_edges`` is the
        number of rows filled so far, ``disjoint_set`` is the union-find state
        and ``cluster_points`` is the current root of every point.
    """
    n_points = n_neighbors.size

    # One row per edge; a tree over n points has at most n - 1 edges
    tree = np.zeros((n_points - 1, 3), dtype=np.float32)
    disjoint_set = ds_rank_create(n_points)
    cluster_points = np.empty(n_points, dtype=np.uint32)

    n_edges = 0
    for i in range(n_points):
        nbr = n_neighbors[i]
        # Only record an edge when it joins two previously separate sets
        if _ds_union_by_rank(disjoint_set, i, nbr) == 1:
            dist = n_distance[i]
            tree[n_edges] = (np.float32(i), np.float32(nbr), dist)
            n_edges += 1

    # Snapshot which root each point belongs to after all merges
    for i in range(n_points):
        cluster_points[i] = ds_find(disjoint_set, i)

    return tree, n_edges, disjoint_set, cluster_points
56
+
57
+
58
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distance):
    """
    Add edges to ``tree`` from one more neighbor per point, shortest first.

    Candidate edges ``(point, n_neighbors[point])`` are visited in ascending
    order of ``n_distance`` and kept only when they join two separate sets.

    Parameters
    ----------
    tree : array
        Edge rows ``(point, neighbor, distance)``; written in place.
    int_tree : int
        Number of rows of ``tree`` already filled.
    disjoint_set
        Union-find state shared across calls; mutated in place.
    n_neighbors : array
        Candidate neighbor index for each point.
    n_distance : array
        Distance from each point to its candidate neighbor.

    Returns
    -------
    tuple
        ``(tree, int_tree, disjoint_set, cluster_points)`` with the updated
        fill count and the current root of every point.
    """
    n_points = n_neighbors.size
    cluster_points = np.empty(n_points, dtype=np.uint32)

    # Reorder points, neighbors and distances so the shortest edges come first
    order = np.argsort(n_distance)
    ordered_dist = n_distance[order]
    ordered_nbrs = n_neighbors[order]
    ordered_points = np.arange(n_points)[order]

    for i in range(n_points):
        point = ordered_points[i]
        nbr = ordered_nbrs[i]
        if _ds_union_by_rank(disjoint_set, point, nbr) == 1:
            dist = ordered_dist[i]
            tree[int_tree] = (np.float32(point), np.float32(nbr), dist)
            int_tree += 1

    # Snapshot which root each point belongs to after this round of merges
    for i in range(cluster_points.size):
        cluster_points[i] = ds_find(disjoint_set, i)

    return tree, int_tree, disjoint_set, cluster_points
80
+
81
+
82
@numba.njit(locals={"i": numba.types.uint32})
def _cluster_edges(tracker, last_idx, cluster_distances):
    """
    Pick, for every cluster id in ``tracker``, the points used to link clusters.

    Parameters
    ----------
    tracker : array
        Cluster id of every point.
    last_idx : int
        Last row of ``cluster_distances`` to include (rows ``0..last_idx``).
    cluster_distances : 2-D array
        Neighbor distances indexed as ``[neighbor_rank, point]``.

    Returns
    -------
    list of arrays
        One array of point indices per cluster id, in ``np.unique`` order.
    """
    cluster_ids = np.unique(tracker)
    edge_points = []
    for idx in range(cluster_ids.size):
        # Indices of the points belonging to this cluster
        cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
        cluster_size = cluster_points.size
        # Mean / std of all considered neighbor distances within the cluster
        cluster_mean = cluster_distances[: last_idx + 1, cluster_points].mean()
        cluster_std = cluster_distances[: last_idx + 1, cluster_points].std()
        # Distance cut-off: one standard deviation above the cluster mean
        threshold = cluster_mean + cluster_std
        points_mean = np.empty_like(cluster_points, dtype=np.float32)
        for i in range(cluster_size):
            # Mean neighbor distance of a single point
            points_mean[i] = cluster_distances[: last_idx + 1, cluster_points[i]].mean()
        # Points whose own mean distance exceeds the cut-off
        pts_to_add = cluster_points[np.nonzero(points_mean > threshold)[0]]
        # ``threshold`` is reused as a minimum count: 1% of the cluster when
        # floor(log10(size)) > 2 (size >= 1000), otherwise 10%, never below 10
        threshold = int(cluster_size * 0.01) if np.floor(np.log10(cluster_size)) > 2 else int(cluster_size * 0.1)
        threshold = max(10, threshold)
        if pts_to_add.size > threshold:
            edge_points.append(pts_to_add)
        else:
            # Too few far-out points: fall back to every point in the cluster
            edge_points.append(cluster_points)
    return edge_points
103
+
104
+
105
def _compute_nn(dataA, dataB, k):
    """
    Brute-force neighbor lookup of ``dataB`` rows against ``dataA``.

    ``k + 1`` neighbors are requested and the first returned column is
    discarded (presumably the self-match when ``dataB`` rows are also present
    in ``dataA`` - confirm against callers).

    Returns
    -------
    tuple
        ``(neighbors, distances)`` as int32 / float32 arrays holding columns
        ``1..k`` of the query result.
    """
    knn = NearestNeighbors(n_neighbors=k + 1, algorithm="brute")
    knn.fit(dataA)
    all_distances, all_neighbors = knn.kneighbors(dataB)

    # Skip column 0, keep the next k columns
    keep = slice(1, k + 1)
    return (
        np.array(all_neighbors[:, keep], dtype=np.int32),
        np.array(all_distances[:, keep], dtype=np.float32),
    )
110
+
111
+
112
def _calculate_cluster_neighbors(data, groups, point_array):
    """
    Rerun nearest neighbor between clusters.

    For every point in each group, finds the closest point that belongs to
    any *other* group.

    Parameters
    ----------
    data : array
        Sample data, indexed by point along the first axis.
    groups : sequence of arrays
        One array of point indices per cluster.
    point_array : array
        Only its ``size`` is used, to size the output arrays.

    Returns
    -------
    tuple
        ``(cluster_neighbors, cluster_nbr_distances)``: for each point, the
        index of and distance to its nearest point in another group. Points
        that are in no group keep neighbor ``0`` and distance ``inf``.
    """
    cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
    cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)

    for i, selectionA in enumerate(groups):
        selectionB = np.concatenate([arr for j, arr in enumerate(groups) if j != i])

        # The query points (this group) are not part of the fitted set (all
        # other groups), so the first result is already a genuine neighbor and
        # must not be skipped as a self-match. Requesting a single neighbor
        # also works when the other groups contain fewer than three points.
        knn = NearestNeighbors(n_neighbors=1, algorithm="brute").fit(data[selectionB])
        new_distances, new_neighbors = knn.kneighbors(data[selectionA])

        # Map positions within selectionB back to dataset-wide point indices
        cluster_neighbors[selectionA] = selectionB[new_neighbors[:, 0]]
        cluster_nbr_distances[selectionA] = new_distances[:, 0]

    return cluster_neighbors, cluster_nbr_distances
127
+
128
+
129
def minimum_spanning_tree(data, neighbors, distances):
    """
    Build a spanning-tree edge list from per-point neighbor lists.

    Parameters
    ----------
    data : array
        Sample data, indexed by point along the first axis.
    neighbors : 2-D array
        Neighbor indices, one row per point.
    distances : 2-D array
        Distances matching ``neighbors``.

    Returns
    -------
    array
        Edge rows ``(point, neighbor, distance)`` as produced by ``_init_tree``
        and extended by ``_update_tree_by_distance``.
    """
    # Work on (neighbor_rank, point) layout so each row is one round of neighbors
    k_neighbors = neighbors.T.astype(np.uint32).copy()
    k_distances = distances.T.astype(np.float32).copy()
    n_rounds, n_points = k_neighbors.shape

    # Row r records the cluster id of every point after round r; one spare row
    # is reserved for the final between-cluster round
    merge_tracker = np.full((n_rounds + 1, n_points), -1, dtype=np.int32)

    # Round 0 seeds the tree
    tree, int_tree, tree_disjoint_set, merge_tracker[0] = _init_tree(k_neighbors[0], k_distances[0])

    # Remaining rounds extend the tree until a round changes nothing
    last_idx = 0
    for step in range(1, n_rounds):
        tree, int_tree, tree_disjoint_set, merge_tracker[step] = _update_tree_by_distance(
            tree, int_tree, tree_disjoint_set, k_neighbors[step], k_distances[step]
        )
        if (merge_tracker[step] == merge_tracker[step - 1]).all():
            # No merges happened this round: the previous round is the last useful one
            last_idx = step - 1
            break
        last_idx = step

    # More than one cluster left: link them through their edge points
    if np.unique(merge_tracker[last_idx]).size > 1:
        edge_points = _cluster_edges(merge_tracker[last_idx], last_idx, k_distances)

        additional_neighbors, additional_distances = _calculate_cluster_neighbors(
            data, edge_points, merge_tracker[last_idx]
        )

        last_idx += 1
        tree, int_tree, tree_disjoint_set, merge_tracker[last_idx] = _update_tree_by_distance(
            tree, int_tree, tree_disjoint_set, additional_neighbors, additional_distances
        )

    return tree
169
+
170
+
171
def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
    """
    Return neighbor indices and distances for every row of ``data``.

    Uses ``pynndescent.NNDescent`` with the euclidean metric when it can be
    imported, otherwise falls back to a brute-force scikit-learn search.
    Other metrics could be supported in the future, see:
    https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524

    Parameters
    ----------
    data : np.ndarray
        Samples to search.
    k : int, default 10
        Number of neighbor columns to return per sample.

    Returns
    -------
    tuple
        ``(neighbors, distances)`` as int32 / float32 arrays holding columns
        ``1..k`` of the search result (column 0 is dropped - presumably the
        point itself).
    """
    try:
        from pynndescent import NNDescent

        # Build a larger graph than requested: 30 neighbors up to k = 20, k + 16 beyond
        graph_size = k + 16 if k > 20 else 30
        neighbors, distances = NNDescent(
            data,
            metric="euclidean",
            n_neighbors=graph_size,
        ).neighbor_graph
    except ImportError:
        knn = NearestNeighbors(n_neighbors=k + 1, algorithm="brute")
        distances, neighbors = knn.fit(data).kneighbors(data)

    keep = slice(1, k + 1)
    return (
        np.array(neighbors[:, keep], dtype=np.int32),
        np.array(distances[:, keep], dtype=np.float32),
    )
@@ -2,17 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, NamedTuple
5
+ from dataclasses import dataclass
6
+ from typing import Any
6
7
 
7
8
  import numpy as np
8
- from numpy.typing import ArrayLike, NDArray
9
+ from numpy.typing import NDArray
9
10
  from scipy.signal import convolve2d
10
11
 
11
12
  EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
12
13
  BIT_DEPTH = (1, 8, 12, 16, 32)
13
14
 
14
15
 
15
- class BitDepth(NamedTuple):
16
+ @dataclass
17
+ class BitDepth:
16
18
  depth: int
17
19
  pmin: float | int
18
20
  pmax: float | int
@@ -59,7 +61,7 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
59
61
  raise ValueError("Images must have 2 or more dimensions.")
60
62
 
61
63
 
62
- def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[np.uint8]:
64
+ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
63
65
  """
64
66
  Returns the image filtered using a 3x3 edge detection kernel:
65
67
  [[ -1, -1, -1 ],
@@ -0,0 +1,18 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from typing import Callable, TypeVar
5
+
6
+ if sys.version_info >= (3, 10):
7
+ from typing import ParamSpec
8
+ else:
9
+ from typing_extensions import ParamSpec
10
+
11
+ P = ParamSpec("P")
12
+ R = TypeVar("R")
13
+
14
+
15
def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
    """
    Look up a callable by name.

    Parameters
    ----------
    method_map : dict[str, Callable]
        Mapping of method names to callables.
    method : str
        Name of the method to retrieve.

    Returns
    -------
    Callable
        The callable registered under ``method``.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``method_map``.
    """
    if method in method_map:
        return method_map[method]
    raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
@@ -2,53 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import sys
6
- from typing import Any, Callable, Literal, TypeVar
5
+ from typing import Any, Literal
7
6
 
8
- import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
7
+ from numpy.typing import NDArray
10
8
  from scipy.sparse import csr_matrix
11
9
  from scipy.sparse.csgraph import minimum_spanning_tree as mst
12
10
  from scipy.spatial.distance import pdist, squareform
13
11
  from sklearn.neighbors import NearestNeighbors
14
12
 
15
- if sys.version_info >= (3, 10):
16
- from typing import ParamSpec
17
- else:
18
- from typing_extensions import ParamSpec
19
-
20
- from dataeval.interop import as_numpy
13
+ from dataeval.utils._array import flatten
21
14
 
22
15
  EPSILON = 1e-5
23
- HASH_SIZE = 8
24
- MAX_FACTOR = 4
25
-
26
-
27
- P = ParamSpec("P")
28
- R = TypeVar("R")
29
-
30
-
31
- def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
32
- if method not in method_map:
33
- raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
34
- return method_map[method]
35
-
36
-
37
- def flatten(array: ArrayLike) -> NDArray[Any]:
38
- """
39
- Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
40
-
41
- Parameters
42
- ----------
43
- X : NDArray, shape - (N, ... )
44
- Input array
45
-
46
- Returns
47
- -------
48
- NDArray, shape - (N, -1)
49
- """
50
- nparr = as_numpy(array)
51
- return nparr.reshape((nparr.shape[0], -1))
52
16
 
53
17
 
54
18
  def minimum_spanning_tree(X: NDArray[Any]) -> Any:
@@ -73,32 +37,6 @@ def minimum_spanning_tree(X: NDArray[Any]) -> Any:
73
37
  return mst(eudist_csr)
74
38
 
75
39
 
76
- def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
77
- """
78
- Returns the classes and counts of from an array of labels
79
-
80
- Parameters
81
- ----------
82
- label : NDArray
83
- Numpy labels array
84
-
85
- Returns
86
- -------
87
- Classes and counts
88
-
89
- Raises
90
- ------
91
- ValueError
92
- If the number of unique classes is less than 2
93
- """
94
- classes, counts = np.unique(labels, return_counts=True)
95
- M = len(classes)
96
- if M < 2:
97
- raise ValueError("Label vector contains less than 2 classes!")
98
- N = int(np.sum(counts))
99
- return M, N
100
-
101
-
102
40
  def compute_neighbors(
103
41
  A: NDArray[Any],
104
42
  B: NDArray[Any],
@@ -6,9 +6,9 @@ import contextlib
6
6
  from typing import Any
7
7
 
8
8
  import numpy as np
9
- from numpy.typing import ArrayLike
10
9
 
11
- from dataeval.interop import to_numpy
10
+ from dataeval.typing import ArrayLike
11
+ from dataeval.utils._array import to_numpy
12
12
 
13
13
  with contextlib.suppress(ImportError):
14
14
  from matplotlib.figure import Figure
@@ -171,7 +171,7 @@ def histogram_plot(
171
171
  data_dict,
172
172
  ):
173
173
  # Plot the histogram for the chosen metric
174
- ax.hist(data_dict[metric], bins=20, log=log)
174
+ ax.hist(data_dict[metric].astype(np.float64), bins=20, log=log)
175
175
 
176
176
  # Add labels to the histogram
177
177
  ax.set_title(metric)
@@ -229,7 +229,7 @@ def channel_histogram_plot(
229
229
  # Plot the histogram for the chosen metric
230
230
  data = data_dict[metric][ch_mask].reshape(-1, max_channels)
231
231
  ax.hist(
232
- data,
232
+ data.astype(np.float64),
233
233
  bins=20,
234
234
  density=True,
235
235
  log=log,
@@ -0,0 +1,22 @@
1
+ """Provides utility functions for interacting with Computer Vision datasets."""
2
+
3
+ __all__ = [
4
+ "collate",
5
+ "datasets",
6
+ "Embeddings",
7
+ "Images",
8
+ "Metadata",
9
+ "Select",
10
+ "SplitDatasetOutput",
11
+ "Targets",
12
+ "split_dataset",
13
+ ]
14
+
15
+ from dataeval.utils.data._embeddings import Embeddings
16
+ from dataeval.utils.data._images import Images
17
+ from dataeval.utils.data._metadata import Metadata
18
+ from dataeval.utils.data._selection import Select
19
+ from dataeval.utils.data._split import SplitDatasetOutput, split_dataset
20
+ from dataeval.utils.data._targets import Targets
21
+
22
+ from . import collate, datasets
@@ -0,0 +1,105 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import math
6
+ from typing import Any, Iterator, Sequence
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader, Subset
10
+ from tqdm import tqdm
11
+
12
+ from dataeval.config import get_device
13
+ from dataeval.typing import TArray
14
+ from dataeval.utils.data._types import Dataset
15
+ from dataeval.utils.torch.models import SupportsEncode
16
+
17
+
18
class Embeddings:
    """
    Collection of image embeddings from a dataset.

    Embeddings are accessed by index or slice and are only loaded on-demand.
    All encoding is performed with gradient tracking disabled.

    Parameters
    ----------
    dataset : ImageClassificationDataset or ObjectDetectionDataset
        Dataset to access original images from.
    batch_size : int
        Batch size to use when encoding images.
    indices : Sequence[int] or None, default None
        Indices into the dataset; defaults to every index of the dataset.
        NOTE(review): the value is stored but not read by any method of this
        class - confirm whether access is meant to be restricted to it.
    model : torch.nn.Module or None, default None
        Model to use for encoding images. ``torch.nn.Flatten`` is used when
        no model is given.
    device : torch.device or str or None, default None
        Device to use for encoding images, resolved through ``get_device``.
    verbose : bool, default False
        Whether to print progress bar when encoding images.
    """

    device: torch.device
    batch_size: int
    verbose: bool

    def __init__(
        self,
        dataset: Dataset[TArray, Any],
        batch_size: int,
        indices: Sequence[int] | None = None,
        model: torch.nn.Module | None = None,
        device: torch.device | str | None = None,
        verbose: bool = False,
    ) -> None:
        self.device = get_device(device)
        self.batch_size = batch_size
        self.verbose = verbose

        self._dataset = dataset
        self._indices = indices if indices is not None else range(len(dataset))
        model = torch.nn.Flatten() if model is None else model
        self._model = model.to(self.device).eval()
        # Prefer a dedicated encode() when the model provides one
        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
        # Each datum is a 3-tuple; only its first element (the image) is kept
        self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]

    def to_tensor(self) -> torch.Tensor:
        """
        Converts entire dataset to embeddings.

        Warning
        -------
        Will process the entire dataset in batches and return
        embeddings as a single Tensor in memory.

        Returns
        -------
        torch.Tensor
        """
        return self[:]

    # Reduce overhead cost by not tracking tensor gradients. The decorator is
    # called (with parentheses) so it also works on torch versions that do not
    # accept the bare ``@torch.no_grad`` form.
    @torch.no_grad()
    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
        """Yield one embedding tensor per batch of the given dataset indices."""
        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)
        batches = (
            tqdm(dataloader, total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
            if self.verbose
            else dataloader
        )
        for images in batches:
            yield self._encoder(torch.stack(images).to(self.device))

    def __getitem__(self, key: int | slice | list[int]) -> torch.Tensor:
        """
        Encode the image(s) selected by ``key``.

        Raises
        ------
        TypeError
            If ``key`` is not an int, slice or list.
        """
        if isinstance(key, list):
            return torch.vstack(list(self._batch(key))).to(self.device)
        if isinstance(key, slice):
            return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
        if isinstance(key, int):
            # Match the batched path: encode without tracking gradients
            with torch.no_grad():
                return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
        raise TypeError("Invalid argument type.")

    def __iter__(self) -> Iterator[torch.Tensor]:
        # process in batches while yielding individual embeddings
        for batch in self._batch(range(len(self._dataset))):
            yield from batch

    def __len__(self) -> int:
        return len(self._dataset)
@@ -0,0 +1,65 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Generic, Iterator, Sequence, overload
6
+
7
+ from dataeval.typing import TArray
8
+ from dataeval.utils.data._types import Dataset
9
+
10
+
11
class Images(Generic[TArray]):
    """
    Lazy view over the image component of a dataset.

    Each access reads the requested item(s) from the wrapped dataset and
    returns the first element of the datum; nothing is cached here.

    Parameters
    ----------
    dataset : ImageClassificationDataset or ObjectDetectionDataset
        Dataset to access images from.
    """

    def __init__(self, dataset: Dataset[TArray, Any]) -> None:
        self._dataset = dataset

    def to_list(self) -> Sequence[TArray]:
        """
        Load every image of the dataset.

        Warning
        -------
        Reads the entire dataset and holds all images in memory as a
        single sequence.

        Returns
        -------
        list[TArray]
        """
        return self[:]

    @overload
    def __getitem__(self, key: slice | list[int]) -> Sequence[TArray]: ...

    @overload
    def __getitem__(self, key: int) -> TArray: ...

    def __getitem__(self, key: int | slice | list[int]) -> Sequence[TArray] | TArray:
        # A slice is resolved against the dataset length, then handled like a list
        if isinstance(key, slice):
            positions = list(range(len(self._dataset))[key])
            return [self._dataset[pos][0] for pos in positions]
        if isinstance(key, list):
            return [self._dataset[pos][0] for pos in key]
        if isinstance(key, int):
            return self._dataset[key][0]
        raise TypeError("Invalid argument type.")

    def __iter__(self) -> Iterator[TArray]:
        # Images are read one at a time as the iterator is consumed
        yield from (self._dataset[pos][0] for pos in range(len(self._dataset)))

    def __len__(self) -> int:
        return len(self._dataset)