dataeval 0.83.0__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,14 +3,14 @@ from __future__ import annotations
  __all__ = []

  import math
- from typing import Any, Iterator, Sequence
+ from typing import Any, Iterator, Sequence, cast

  import torch
  from torch.utils.data import DataLoader, Subset
  from tqdm import tqdm

  from dataeval.config import DeviceLike, get_device
- from dataeval.typing import Array, Dataset
+ from dataeval.typing import Array, Dataset, Transform
  from dataeval.utils.torch.models import SupportsEncode


@@ -26,11 +26,15 @@ class Embeddings:
  Dataset to access original images from.
  batch_size : int
  Batch size to use when encoding images.
+ transforms : Transform or Sequence[Transform] or None, default None
+ Transforms to apply to images before encoding.
  model : torch.nn.Module or None, default None
  Model to use for encoding images.
  device : DeviceLike or None, default None
  The hardware device to use if specified, otherwise uses the DataEval
  default or torch default.
+ cache : bool, default False
+ Whether to cache the embeddings in memory.
  verbose : bool, default False
  Whether to print progress bar when encoding images.
  """
@@ -43,19 +47,27 @@ class Embeddings:
  self,
  dataset: Dataset[tuple[Array, Any, Any]],
  batch_size: int,
+ transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
  model: torch.nn.Module | None = None,
  device: DeviceLike | None = None,
+ cache: bool = False,
  verbose: bool = False,
  ) -> None:
  self.device = get_device(device)
- self.batch_size = batch_size
+ self.cache = cache
+ self.batch_size = batch_size if batch_size > 0 else 1
  self.verbose = verbose

  self._dataset = dataset
+ self._length = len(dataset)
  model = torch.nn.Flatten() if model is None else model
+ self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
  self._model = model.to(self.device).eval()
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
  self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
+ self._cached_idx = set()
+ self._embeddings: torch.Tensor = torch.empty(())
+ self._shallow: bool = False

  def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
  """
@@ -79,30 +91,96 @@ class Embeddings:
  else:
  return self[:]

- # Reduce overhead cost by not tracking tensor gradients
- @torch.no_grad
+ @classmethod
+ def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
+ """
+ Instantiates a shallow Embeddings object using an array.
+
+ Parameters
+ ----------
+ array : Array
+ The array to convert to embeddings.
+ device : DeviceLike or None, default None
+ The hardware device to use if specified, otherwise uses the DataEval
+ default or torch default.
+
+ Returns
+ -------
+ Embeddings
+
+ Example
+ -------
+ >>> import numpy as np
+ >>> from dataeval.utils.data._embeddings import Embeddings
+ >>> array = np.random.randn(100, 3, 224, 224)
+ >>> embeddings = Embeddings.from_array(array)
+ >>> print(embeddings.to_tensor().shape)
+ torch.Size([100, 3, 224, 224])
+ """
+ embeddings = Embeddings([], 0, None, None, device, True, False)
+ embeddings._length = len(array)
+ embeddings._cached_idx = set(range(len(array)))
+ embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+ embeddings._shallow = True
+ return embeddings
+
+ def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+ if self._transforms:
+ images = [transform(image) for transform in self._transforms for image in images]
+ return self._encoder(torch.stack(images).to(self.device))
+
+ @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
  def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
- # manual batching
- dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
- for i, images in (
- tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
- if self.verbose
- else enumerate(dataloader)
- ):
- embeddings = self._encoder(torch.stack(images).to(self.device))
- yield embeddings
+ dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
+ total_batches = math.ceil(len(indices) / self.batch_size)
+
+ # If not caching, process all indices normally
+ if not self.cache:
+ for images in tqdm(
+ DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+ total=total_batches,
+ desc="Batch embedding",
+ disable=not self.verbose,
+ ):
+ yield self._encode(images)
+ return
+
+ # If caching, process each batch of indices at a time, preserving original order
+ for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+ batch = indices[i : i + self.batch_size]
+ uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+ if uncached:
+ # Process uncached indices as a single batch
+ for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+ embeddings = self._encode(images)
+
+ if not self._embeddings.shape:
+ full_shape = (len(self._dataset), *embeddings.shape[1:])
+ self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+ self._embeddings[uncached] = embeddings
+ self._cached_idx.update(uncached)
+
+ yield self._embeddings[batch]

  def __getitem__(self, key: int | slice, /) -> torch.Tensor:
- if isinstance(key, slice):
- return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
- elif isinstance(key, int):
- return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
- raise TypeError("Invalid argument type.")
+ if not isinstance(key, slice) and not hasattr(key, "__int__"):
+ raise TypeError("Invalid argument type.")
+
+ if self._shallow:
+ if not self._embeddings.shape:
+ raise ValueError("Embeddings not initialized.")
+ return self._embeddings[key]
+
+ indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
+ result = torch.vstack(list(self._batch(indices))).to(self.device)
+ return result.squeeze(0) if len(indices) == 1 else result

  def __iter__(self) -> Iterator[torch.Tensor]:
  # process in batches while yielding individual embeddings
- for batch in self._batch(range(len(self._dataset))):
+ for batch in self._batch(range(self._length)):
  yield from batch

  def __len__(self) -> int:
- return len(self._dataset)
+ return self._length
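
The cached branch of `_batch` above lazily fills one preallocated tensor keyed by dataset index, so repeated reads of the same indices skip re-encoding. The following standalone sketch (not DataEval's API; `encode_batch` and the dataset are assumptions) isolates that index-keyed caching pattern:

    # Standalone sketch of lazy, index-keyed embedding caching.
    import torch

    class CachedEncoder:
        def __init__(self, dataset, encode_batch):
            self._dataset = dataset            # any sequence of tensors
            self._encode_batch = encode_batch  # callable: list[Tensor] -> Tensor of shape (n, d)
            self._cached_idx: set[int] = set()
            self._cache = torch.empty(())      # 0-dim sentinel: "not allocated yet"

        def get(self, indices: list[int]) -> torch.Tensor:
            uncached = [i for i in indices if i not in self._cached_idx]
            if uncached:
                encoded = self._encode_batch([self._dataset[i] for i in uncached])
                if not self._cache.shape:      # allocate the full-size buffer on first encode
                    self._cache = torch.empty((len(self._dataset), *encoded.shape[1:]), dtype=encoded.dtype)
                self._cache[uncached] = encoded
                self._cached_idx.update(uncached)
            return self._cache[indices]        # previously seen rows are returned without re-encoding

A second `get` call with overlapping indices only encodes the indices it has not seen before, which is the behavior the `cache=True` path provides for embeddings.
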
@@ -2,11 +2,15 @@ from __future__ import annotations

  __all__ = []

- from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
+ from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload

- from dataeval.typing import Dataset
+ from dataeval.typing import Array, Dataset
+ from dataeval.utils._array import as_numpy, channels_first_to_last

- T = TypeVar("T")
+ if TYPE_CHECKING:
+ from matplotlib.figure import Figure
+
+ T = TypeVar("T", bound=Array)


  class Images(Generic[T]):
@@ -21,7 +25,10 @@ class Images(Generic[T]):
  Dataset to access images from.
  """

- def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
+ def __init__(
+ self,
+ dataset: Dataset[tuple[T, Any, Any] | T],
+ ) -> None:
  self._is_tuple_datum = isinstance(dataset[0], tuple)
  self._dataset = dataset

@@ -40,25 +47,41 @@ class Images(Generic[T]):
  """
  return self[:]

+ def plot(
+ self,
+ indices: Sequence[int],
+ images_per_row: int = 3,
+ figsize: tuple[int, int] = (10, 10),
+ ) -> Figure:
+ import matplotlib.pyplot as plt
+
+ num_images = len(indices)
+ num_rows = (num_images + images_per_row - 1) // images_per_row
+ fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
+ for i, ax in enumerate(axes.flatten()):
+ image = channels_first_to_last(as_numpy(self[i]))
+ ax.imshow(image)
+ ax.axis("off")
+ plt.tight_layout()
+ return fig
+
  @overload
  def __getitem__(self, key: int, /) -> T: ...
  @overload
  def __getitem__(self, key: slice, /) -> Sequence[T]: ...

  def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+ if isinstance(key, slice):
+ return [self._get_image(k) for k in range(len(self._dataset))[key]]
+ elif hasattr(key, "__int__"):
+ return self._get_image(int(key))
+ raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+ def _get_image(self, index: int) -> T:
  if self._is_tuple_datum:
- dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
- if isinstance(key, slice):
- return [dataset[k][0] for k in range(len(self._dataset))[key]]
- elif isinstance(key, int):
- return dataset[key][0]
+ return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
  else:
- dataset = cast(Dataset[T], self._dataset)
- if isinstance(key, slice):
- return [dataset[k] for k in range(len(self._dataset))[key]]
- elif isinstance(key, int):
- return dataset[key]
- raise TypeError(f"Key must be integers or slices, not {type(key)}")
+ return cast(Dataset[T], self._dataset)[index]

  def __iter__(self) -> Iterator[T]:
  for i in range(len(self._dataset)):
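
`Images` gains a small matplotlib grid helper. Note that, as written above, the loop indexes `self[i]` by grid position, so it draws the first images of the dataset and uses `indices` only to size the grid. A hedged usage sketch; `dataset` is a placeholder and the import path is an assumption, since this diff does not show where `Images` is exported:

    # Hedged sketch -- `dataset` is a hypothetical dataset of channels-first image arrays,
    # and the import path is an assumption (not shown in this diff).
    from dataeval.utils.data import Images

    images = Images(dataset)
    fig = images.plot(indices=[0, 1, 2], images_per_row=3, figsize=(9, 3))
    fig.savefig("image_grid.png")
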
@@ -5,7 +5,7 @@ __all__ = []
  from enum import IntEnum
  from typing import Generic, Iterator, Sequence, TypeVar

- from dataeval.typing import AnnotatedDataset, DatasetMetadata, Transform
+ from dataeval.typing import AnnotatedDataset, DatasetMetadata

  _TDatum = TypeVar("_TDatum")

@@ -35,8 +35,6 @@ class Select(AnnotatedDataset[_TDatum]):
  The dataset to wrap.
  selections : Selection or list[Selection], optional
  The selection criteria to apply to the dataset.
- transforms : Transform or list[Transform], optional
- The transforms to apply to the dataset.

  Examples
  --------
@@ -70,16 +68,12 @@ class Select(AnnotatedDataset[_TDatum]):
  self,
  dataset: AnnotatedDataset[_TDatum],
  selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
- transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
  ) -> None:
  self.__dict__.update(dataset.__dict__)
  self._dataset = dataset
  self._size_limit = len(dataset)
  self._selection = list(range(self._size_limit))
  self._selections = self._sort(selections)
- self._transforms = (
- [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
- )

  # Ensure metadata is populated correctly as DatasetMetadata TypedDict
  _metadata = getattr(dataset, "metadata", {})
@@ -98,8 +92,7 @@ class Select(AnnotatedDataset[_TDatum]):
  title = f"{self.__class__.__name__} Dataset"
  sep = "-" * len(title)
  selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
- transforms = f"Transforms: [{', '.join([str(t) for t in self._transforms])}]"
- return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
+ return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"

  def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
  if not selections:
@@ -117,13 +110,8 @@ class Select(AnnotatedDataset[_TDatum]):
  selection(self)
  self._selection = self._selection[: self._size_limit]

- def _transform(self, datum: _TDatum) -> _TDatum:
- for t in self._transforms:
- datum = t(datum)
- return datum
-
  def __getitem__(self, index: int) -> _TDatum:
- return self._transform(self._dataset[self._selection[index]])
+ return self._dataset[self._selection[index]]

  def __iter__(self) -> Iterator[_TDatum]:
  for i in range(len(self)):
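
`Select` now only filters and reorders data; it no longer applies transforms. One plausible migration, sketched below with placeholder objects (`dataset`, `my_selection`, and `normalize` are illustrative, not taken from this diff), is to keep selection criteria on `Select` and move per-image transforms to the `transforms` argument that `Embeddings` gained in this release:

    # Hedged migration sketch -- `dataset`, `my_selection`, and `normalize` are placeholders.
    from dataeval.utils.data._embeddings import Embeddings

    selected = Select(dataset, selections=[my_selection])                     # selection criteria only
    embeddings = Embeddings(selected, batch_size=32, transforms=[normalize])  # transform at encode time
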
@@ -2,19 +2,22 @@ from __future__ import annotations

  __all__ = []

+ import logging
  import warnings
- from typing import Any, Iterator, Protocol
+ from typing import Any, Iterator, Protocol, Sequence

  import numpy as np
  from numpy.typing import NDArray
- from sklearn.cluster import KMeans
- from sklearn.metrics import silhouette_score
  from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
  from sklearn.utils.multiclass import type_of_target

- from dataeval.config import get_seed
+ from dataeval.config import EPSILON
  from dataeval.outputs._base import set_metadata
  from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
+ from dataeval.typing import AnnotatedDataset
+ from dataeval.utils.data._metadata import Metadata
+
+ _logger = logging.getLogger(__name__)


  class KFoldSplitter(Protocol):
@@ -85,7 +88,7 @@ def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: fl
  return val_base * (1.0 / num_folds) * (1.0 - test_frac)


- def _validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
+ def validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
  """
  Check to make sure there is more input data than the total number of partitions requested

@@ -116,7 +119,7 @@ def _validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
  raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")


- def is_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> bool:
+ def validate_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> None:
  """
  Check if the dataset can be stratified by class label over the given number of partitions

@@ -132,26 +135,23 @@ def is_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> bool:
  bool
  True if dataset can be stratified else False

- Warns
- -----
- UserWarning
- Warns user if the dataset cannot be stratified due to the total number of [train, val, test]
+ Raises
+ ------
+ ValueError
+ If the dataset cannot be stratified due to the total number of [train, val, test]
  partitions exceeding the number of instances of the rarest class label.
  """

  # Get the minimum count of all labels
  lowest_label_count = np.unique(labels, return_counts=True)[1].min()
  if lowest_label_count < num_partitions:
- warnings.warn(
+ raise ValueError(
  f"Unable to stratify due to label frequency. The lowest label count ({lowest_label_count}) is fewer "
- f"than the total number of partitions ({num_partitions}) requested.",
- UserWarning,
+ f"than the total number of partitions ({num_partitions}) requested."
  )
- return False
- return True


- def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
+ def validate_groupable(groups: NDArray[np.intp], num_partitions: int) -> None:
  """
  Warns user if the number of unique group_ids is incompatible with a grouped partition containing
  num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
@@ -159,7 +159,7 @@ def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:

  Parameters
  ----------
- group_ids : NDArray of ints
+ groups : NDArray of ints
  The id of the group each sample at the corresponding index belongs to
  num_partitions : int
  Total number of train, val, and test splits requested
@@ -169,60 +169,24 @@ def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
  bool
  True if the dataset can be grouped by the given group ids else False

- Warns
- -----
- UserWarning
- Warns if there are fewer groups than the requested number of partitions plus one
+ Raises
+ ------
+ ValueError
+ If there is only one unique group.
+ ValueError
+ If there are fewer groups than the requested number of partitions plus one
  """

- num_unique_groups = len(np.unique(group_ids))
+ num_unique_groups = len(np.unique(groups))
  # Cannot separate if only one group exists
  if num_unique_groups == 1:
- return False
+ raise ValueError(f"Unique groups ({num_unique_groups}) must be greater than 1.")

  if num_unique_groups < num_partitions:
- warnings.warn(
- f"Groups must be greater than num partitions. Got {num_unique_groups} and {num_partitions}. "
- "Reverting to ungrouped partitioning",
- UserWarning,
- )
- return False
- return True
-
-
- def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
- """
- Find bins of continuous data by iteratively applying k-means clustering, and keeping the
- clustering with the highest silhouette score.
-
- Parameters
- ----------
- array : NDArray
- continuous data to bin
+ raise ValueError(f"Unique groups ({num_unique_groups}) must be greater than num partitions ({num_partitions}).")

- Returns
- -------
- NDArray[int]:
- bin numbers assigned by the kmeans best clusterer.
- """

- if array.ndim == 1:
- array = array.reshape([-1, 1])
- best_score = 0.60
- else:
- best_score = 0.50
- bin_index = np.zeros(len(array), dtype=np.intp)
- for k in range(2, 20):
- clusterer = KMeans(n_clusters=k, random_state=get_seed())
- cluster_labels = clusterer.fit_predict(array)
- score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
- if score > best_score:
- best_score = score
- bin_index = cluster_labels.astype(np.intp)
- return bin_index
-
-
- def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.intp]:
+ def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np.intp] | None:
  """
  Returns individual group numbers based on a subset of metadata defined by groupnames

@@ -232,32 +196,20 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
  dictionary containing all metadata
  groupnames : list
  which groups from the metadata dictionary to consider for dataset grouping
- num_samples : int
- number of labels. Used to ensure agreement between input data/labels and metadata entries.
-
- Raises
- ------
- IndexError
- raised if an entry in the metadata dictionary doesn't have the same length as num_samples

  Returns
  -------
  np.ndarray
  group identifiers from metadata
  """
- features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
- if not features2group:
- return np.zeros(num_samples, dtype=np.intp)
- for name, feature in features2group.items():
- if len(feature) != num_samples:
- raise ValueError(
- f"Feature length does not match number of labels. Got {len(feature)} features and {num_samples} samples"
- )
-
- if type_of_target(feature) == "continuous":
- features2group[name] = bin_kmeans(feature)
- binned_features = np.stack(list(features2group.values()), axis=1)
- _, group_ids = np.unique(binned_features, axis=0, return_inverse=True)
+ # get only the factors that are present in the metadata
+ if split_on is None:
+ return None
+
+ split_set = set(split_on)
+ indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
+ binned_features = metadata.discrete_data[:, indices]
+ group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
  return group_ids
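
The rewritten `get_groups` treats each row of the selected discrete metadata columns as a key and lets `np.unique(..., axis=0, return_inverse=True)` number the distinct rows, so every sample receives the id of its unique factor combination. A self-contained illustration of that numbering:

    import numpy as np

    # Two discrete factors for six samples (e.g. collection site and sensor id).
    binned_features = np.array(
        [
            [0, 1],
            [0, 1],
            [1, 0],
            [0, 1],
            [1, 1],
            [1, 0],
        ]
    )
    # return_inverse maps each row to the index of its unique row, i.e. a group id per sample.
    group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
    print(group_ids)  # [0 0 1 0 2 1]
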
@@ -294,10 +246,18 @@ def make_splits(
  split_defs: list[TrainValSplit] = []
  n_labels = len(np.unique(labels))
  splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, stratified)](n_folds)
+ _logger.log(logging.DEBUG, f"splitter={splitter.__class__.__name__}(n_splits={n_folds})")
  good = False
  attempts = 0
  while not good and attempts < 3:
  attempts += 1
+ _logger.log(
+ logging.DEBUG,
+ f"attempt={attempts}: splitter.split("
+ + f"index=arr(len={len(index)}, unique={np.unique(index)}), "
+ + f"labels=arr(len={len(index)}, unique={np.unique(index)}), "
+ + ("groups=None" if groups is None else f"groups=arr(len={len(groups)}, unique={np.unique(groups)}))"),
+ )
  splits = splitter.split(index, labels, groups)
  split_defs.clear()
  for train_idx, eval_idx in splits:
@@ -341,20 +301,20 @@ def find_best_split(
  counts = np.bincount(arr, minlength=minlength)
  return counts / np.sum(counts)

- def weight(arr: NDArray, class_freq: NDArray) -> np.float64:
- return np.sum(np.abs(freq(arr, len(class_freq)) - class_freq))
+ def weight(arr: NDArray, class_freq: NDArray) -> float:
+ return float(np.sum(np.abs(freq(arr, len(class_freq)) - class_freq)))

- def class_freq_diff(split: TrainValSplit) -> np.float64:
+ def class_freq_diff(split: TrainValSplit) -> float:
  class_freq = freq(labels)
  return weight(labels[split.train], class_freq) + weight(labels[split.val], class_freq)

- def split_ratio(split: TrainValSplit) -> np.float64:
- return np.float64(len(split.val) / (len(split.val) + len(split.train)))
+ def split_ratio(split: TrainValSplit) -> float:
+ return len(split.val) / (len(split.val) + len(split.train))

- def split_diff(split: TrainValSplit) -> np.float64:
+ def split_diff(split: TrainValSplit) -> float:
  return abs(split_frac - split_ratio(split))

- def split_inv_diff(split: TrainValSplit) -> np.float64:
+ def split_inv_diff(split: TrainValSplit) -> float:
  return abs(1 - split_frac - split_ratio(split))

  # Selects minimization function based on inputs
@@ -399,11 +359,12 @@ def single_split(
  Indices of data partitioned for training and evaluation
  """

- _, label_counts = np.unique(labels, return_counts=True)
- max_folds = label_counts.min()
- min_folds = np.unique(groups).shape[0] if groups is not None else 2
- divisor = split_frac + 1e-06 if split_frac <= 2 / 3 else 1 - split_frac - 1e-06
- n_folds = round(min(max(1 / divisor, min_folds), max_folds))  # Clips value between min_folds and max_folds
+ unique_groups = 2 if groups is None else len(np.unique(groups))
+ max_folds = min(min(np.unique(labels, return_counts=True)[1]), unique_groups) if stratified else unique_groups
+
+ divisor = split_frac if split_frac <= 2 / 3 else 1 - split_frac
+ n_folds = min(max(round(1 / (divisor + EPSILON)), 2), max_folds)  # Clips value between 2 and max_folds
+ _logger.log(logging.DEBUG, f"n_folds={n_folds} clipped between[2, {max_folds}]")

  split_candidates = make_splits(index, labels, n_folds, groups, stratified)
  return find_best_split(labels, split_candidates, stratified, split_frac)
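
The new fold count is just the rounded reciprocal of the requested split fraction (mirrored when the fraction exceeds 2/3), clipped to [2, max_folds]. A quick check of the arithmetic; EPSILON comes from dataeval.config and its exact value is not shown in this diff, so a tiny stand-in is used:

    EPSILON = 1e-12  # stand-in for dataeval.config.EPSILON

    def n_folds(split_frac: float, max_folds: int) -> int:
        divisor = split_frac if split_frac <= 2 / 3 else 1 - split_frac
        return min(max(round(1 / (divisor + EPSILON)), 2), max_folds)

    print(n_folds(0.2, max_folds=10))  # 5 -> a 20% holdout is one fold out of five
    print(n_folds(0.8, max_folds=10))  # 5 -> mirrored: divisor becomes 1 - 0.8 = 0.2
    print(n_folds(0.5, max_folds=3))   # 2 -> round(1 / 0.5) = 2, already within [2, 3]
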
@@ -411,22 +372,20 @@ def single_split(

  @set_metadata
  def split_dataset(
- labels: list[int] | NDArray[np.intp],
+ dataset: AnnotatedDataset[Any] | Metadata,
  num_folds: int = 1,
  stratify: bool = False,
- split_on: list[str] | None = None,
- metadata: dict[str, Any] | None = None,
+ split_on: Sequence[str] | None = None,
  test_frac: float = 0.0,
  val_frac: float = 0.0,
  ) -> SplitDatasetOutput:
  """
- Top level splitting function. Returns a dataclass containing a list of train and validation indices.
- Indices for a test holdout may also be optionally included
+ Dataset splitting function. Returns a dataclass containing a list of train and validation indices.

  Parameters
  ----------
- labels : list or NDArray of ints
- Classification Labels used to generate splits. Determines the size of the dataset
+ dataset : AnnotatedDataset or Metadata
+ Dataset to split.
  num_folds : int, default 1
  Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0
  stratify : bool, default False
@@ -436,8 +395,6 @@ def split_dataset(
  Keys of the metadata dictionary upon which to group the dataset.
  A grouped partition is divided such that no group is present within both the training and
  validation set. Split_on groups should be selected to mitigate validation bias
- metadata : dict or None, default None
- Dict containing data for potential dataset grouping. See split_on above
  test_frac : float, default 0.0
  Fraction of data to be optionally held out for test set
  val_frac : float, default 0.0
@@ -450,13 +407,8 @@ def split_dataset(
  Output class containing a list of indices of training
  and validation data for each fold and optional test indices

- Raises
- ------
- TypeError
- Raised if split_on is passed, but metadata is None or empty
-
- Note
- ----
+ Notes
+ -----
  When specifying groups and/or stratification, ratios for test and validation splits can vary
  as the stratification and grouping take higher priority than the percentages
  """
@@ -464,30 +416,25 @@ def split_dataset(
  val_frac = calculate_validation_fraction(num_folds, test_frac, val_frac)
  total_partitions = num_folds + 1 if test_frac else num_folds

- if isinstance(labels, list):
- labels = np.array(labels, dtype=np.intp)
+ metadata = dataset if isinstance(dataset, Metadata) else Metadata(dataset)
+ labels = metadata.class_labels

- label_length: int = len(labels)
+ validate_labels(labels, total_partitions)
+ if stratify:
+ validate_stratifiable(labels, total_partitions)

- _validate_labels(labels, total_partitions)
- stratify &= is_stratifiable(labels, total_partitions)
- groups = None
- if split_on:
- if metadata is None or metadata == {}:
- raise TypeError("If split_on is specified, metadata must also be provided, got None")
- possible_groups = get_group_ids(metadata, split_on, label_length)
+ groups = get_groups(metadata, split_on)
+ if groups is not None:
  # Accounts for a test set that is 100 % of the data
  group_partitions = total_partitions + 1 if val_frac else total_partitions
- if is_groupable(possible_groups, group_partitions):
- groups = possible_groups
+ validate_groupable(groups, group_partitions)

- index = np.arange(label_length)
+ index = np.arange(len(labels))

- tvs = (
- single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
- if test_frac
- else TrainValSplit(index, np.array([], dtype=np.intp))
- )
+ if test_frac:
+ tvs = single_split(index, labels, test_frac, groups, stratify)
+ else:
+ tvs = TrainValSplit(index, np.array([], dtype=np.intp))

  tv_labels = labels[tvs.train]
  tv_groups = groups[tvs.train] if groups is not None else None
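
With this change `split_dataset` consumes the annotated dataset (or a prebuilt `Metadata`) directly and derives labels and grouping factors itself, instead of taking a labels array plus a metadata dict. A hedged usage sketch; `train_dataset` is a placeholder and the import path is an assumption, since this diff shows only the module body:

    # Hedged sketch -- `train_dataset` is a hypothetical AnnotatedDataset; the import path is assumed.
    from dataeval.utils.data import split_dataset

    splits = split_dataset(
        train_dataset,         # AnnotatedDataset or a prebuilt Metadata object
        num_folds=5,
        stratify=True,         # now raises ValueError (rather than warning) if labels are too rare
        split_on=["site_id"],  # grouping factors looked up in the dataset's discrete metadata
        test_frac=0.2,
    )
    # `splits` is a SplitDatasetOutput: per-fold train/val indices plus optional test indices.
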