dataeval 0.74.0__py3-none-any.whl → 0.74.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. dataeval/__init__.py +23 -10
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/linters/clusterer.py +3 -3
  6. dataeval/detectors/linters/duplicates.py +4 -4
  7. dataeval/detectors/linters/outliers.py +4 -4
  8. dataeval/detectors/ood/__init__.py +5 -12
  9. dataeval/detectors/ood/base.py +5 -5
  10. dataeval/detectors/ood/metadata_ks_compare.py +12 -13
  11. dataeval/interop.py +15 -3
  12. dataeval/logging.py +16 -0
  13. dataeval/metrics/bias/balance.py +3 -3
  14. dataeval/metrics/bias/coverage.py +3 -3
  15. dataeval/metrics/bias/diversity.py +3 -3
  16. dataeval/metrics/bias/metadata_preprocessing.py +3 -3
  17. dataeval/metrics/bias/parity.py +4 -4
  18. dataeval/metrics/estimators/ber.py +3 -3
  19. dataeval/metrics/estimators/divergence.py +3 -3
  20. dataeval/metrics/estimators/uap.py +3 -3
  21. dataeval/metrics/stats/base.py +2 -2
  22. dataeval/metrics/stats/boxratiostats.py +1 -1
  23. dataeval/metrics/stats/datasetstats.py +6 -6
  24. dataeval/metrics/stats/dimensionstats.py +1 -1
  25. dataeval/metrics/stats/hashstats.py +1 -1
  26. dataeval/metrics/stats/labelstats.py +3 -3
  27. dataeval/metrics/stats/pixelstats.py +1 -1
  28. dataeval/metrics/stats/visualstats.py +1 -1
  29. dataeval/output.py +81 -57
  30. dataeval/utils/__init__.py +1 -7
  31. dataeval/utils/split_dataset.py +306 -279
  32. dataeval/workflows/sufficiency.py +4 -4
  33. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/METADATA +3 -8
  34. dataeval-0.74.2.dist-info/RECORD +66 -0
  35. dataeval/detectors/ood/ae.py +0 -76
  36. dataeval/detectors/ood/aegmm.py +0 -67
  37. dataeval/detectors/ood/base_tf.py +0 -109
  38. dataeval/detectors/ood/llr.py +0 -302
  39. dataeval/detectors/ood/vae.py +0 -98
  40. dataeval/detectors/ood/vaegmm.py +0 -76
  41. dataeval/utils/lazy.py +0 -26
  42. dataeval/utils/tensorflow/__init__.py +0 -19
  43. dataeval/utils/tensorflow/_internal/gmm.py +0 -103
  44. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  45. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  46. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  47. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  48. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  49. dataeval-0.74.0.dist-info/RECORD +0 -79
  50. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/LICENSE.txt +0 -0
  51. {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/WHEEL +0 -0
@@ -1,9 +1,13 @@
  from __future__ import annotations
 
- __all__ = ["split_dataset"]
+ from dataclasses import dataclass
+
+ from dataeval.output import Output, set_metadata
+
+ __all__ = ["split_dataset", "SplitDatasetOutput"]
 
  import warnings
- from typing import Any
+ from typing import Any, Iterator, NamedTuple, Protocol
 
  import numpy as np
  from numpy.typing import NDArray
@@ -13,128 +17,156 @@ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, Str
  from sklearn.utils.multiclass import type_of_target
 
 
- def validate_test_val(num_folds: int, test_frac: float | None, val_frac: float | None) -> tuple[float, float]:
-     """Check input fractions to ensure unambiguous splitting arguments are passed return calculated
-     test and validation fractions.
+ class TrainValSplit(NamedTuple):
+     """Tuple containing train and validation indices"""
+
+     train: NDArray[np.int_]
+     val: NDArray[np.int_]
+
+
+ @dataclass(frozen=True)
+ class SplitDatasetOutput(Output):
+     """Output class containing test indices and a list of TrainValSplits"""
+
+     test: NDArray[np.int_]
+     folds: list[TrainValSplit]
+
 
+ class KFoldSplitter(Protocol):
+     """Protocol covering sklearn KFold variant splitters"""
+
+     def __init__(self, n_splits: int): ...
+     def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...
+
+
+ KFOLD_GROUP_STRATIFIED_MAP: dict[tuple[bool, bool], type[KFoldSplitter]] = {
+     (False, False): KFold,
+     (False, True): StratifiedKFold,
+     (True, False): GroupKFold,
+     (True, True): StratifiedGroupKFold,
+ }
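The four `(groups present, stratified)` combinations now resolve to a splitter class through a single dictionary lookup instead of nested conditionals. A minimal sketch of the lookup, assuming only the mapping above (the inputs are illustrative):

```python
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold

# Mirrors the mapping introduced above: (groups?, stratified?) -> splitter class
KFOLD_GROUP_STRATIFIED_MAP = {
    (False, False): KFold,
    (False, True): StratifiedKFold,
    (True, False): GroupKFold,
    (True, True): StratifiedGroupKFold,
}

groups = None                     # hypothetical: no grouping metadata supplied
splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, True)](2)
print(type(splitter).__name__)    # StratifiedKFold
```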
+
+
+ def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: float) -> float:
+     """
+     Calculate possible validation fraction based on the number of folds and test fraction.
 
      Parameters
      ----------
      num_folds : int
-         number of [train, val] cross-validation folds to generate
-     test_frac : float, optional
-         If specified, also generate a test set containing (test_frac*100)% of the data
-     val_frac : float, optional
-         Only specify if requesting a single [train, val] split. The validation split will
-         contain (val_frac*100)% of any data not already allocated to the test set
+         number of train and validation cross-validation folds to generate
+     test_frac : float
+         The fraction of the data to extract for testing before folds are created
+     val_frac : float
+         The validation split will contain (val_frac * 100)% of any data not already allocated to the test set.
+         Only required if requesting a single [train, val] split.
 
      Raises
      ------
-     UnboundLocalError
-         Raised if more than one fold AND the fraction of data to be used for validation are
-         both requested. In this case, val_frac is ambiguous, since the validation fraction must be
-         by definition 1/num_folds
      ValueError
-         Raised if num_folds is 1 (or left blank) AND val_frac is unspecified. When only 1 fold is
-         requested, we need to know how much of the data should be allocated for validation.
+         When number of folds requested is less than 1
      ValueError
-         Raised if the total fraction of data used for evaluation (val + test) meets or exceeds 1.0
+         When the test fraction is not within 0.0 and 1.0 inclusively
+     ValueError
+         When more than one fold and the validation fraction are both requested
+     ValueError
+         When number of folds equals one but the validation fraction is 0.0
+     ValueError
+         When the validation fraction is not within 0.0 and 1.0 inclusively
 
      Returns
      -------
-     tuple[float, float]
-         Tuple of the validated and calculated values as appropriate for test and validation fractions
+     float
+         The updated validation fraction of the remaining data after the testing fraction is removed
+     """
+     if num_folds < 1:
+         raise ValueError(f"Number of folds must be greater than or equal to 1, got {num_folds}")
+     if test_frac < 0.0 or test_frac > 1.0:
+         raise ValueError(f"test_frac out of bounds. Must be between 0.0 and 1.0, got {test_frac}")
+
+     # val base is a variable placeholder so val_frac can be ignored if num_folds != 1
+     val_base: float = 1.0
+     if num_folds == 1:
+         if val_frac == 0.0:
+             raise ValueError("If num_folds equals 1, must assign a value to val_frac")
+         if val_frac < 0.0 or val_frac > 1.0:
+             raise ValueError(f"val_frac out of bounds. Must be between 0.0 and 1.0, got {val_frac}")
+         val_base = val_frac
+     # num folds must be >1 in this case
+     elif val_frac != 0.0:
+         raise ValueError("Can only specify val_frac when num_folds equals 1")
+
+     # This value is mathematically bound between 0-1 inclusive
+     return val_base * (1.0 / num_folds) * (1.0 - test_frac)
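A quick worked check of the returned expression `val_base * (1.0 / num_folds) * (1.0 - test_frac)`, with hypothetical inputs and assuming the function above:

```python
# Single split: 40% of the 75% remaining after a 25% test holdout -> 0.3 of all data
print(calculate_validation_fraction(num_folds=1, test_frac=0.25, val_frac=0.4))  # ~0.3
# Five folds: each validation fold is 1/5 of the 80% left after a 20% test holdout
print(calculate_validation_fraction(num_folds=5, test_frac=0.2, val_frac=0.0))   # ~0.16
```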
+
+
+ def _validate_labels(labels: NDArray[np.int_], total_partitions: int) -> None:
      """
-     if (num_folds > 1) and (val_frac is not None):
-         raise ValueError("If specifying val_frac, num_folds must be None or 1")
-     if (num_folds == 1) and (val_frac is None):
-         raise ValueError("If num_folds is None or 1, must assign a value to val_frac")
-     t_frac = 0.0 if test_frac is None else test_frac
-     v_frac = 1.0 / num_folds * (1.0 - t_frac) if val_frac is None else val_frac * (1.0 - t_frac)
-     if (t_frac + v_frac) >= 1.0:
-         raise ValueError(f"val_frac + test_frac must be less that 1.0, currently {v_frac+t_frac}")
-     return t_frac, v_frac
-
-
- def check_labels(
-     labels: list[int] | NDArray[np.int_], total_partitions: int
- ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
-     """Check to make sure there are more input data than the total number of partitions requested
-     Also converts labels to a numpy array, if it isn't already
+     Check to make sure there is more input data than the total number of partitions requested
 
      Parameters
      ----------
-     labels : list or np.ndarray
-         all class labels from the input dataset
+     labels : np.ndarray of ints
+         All class labels from the input dataset
      total_partitions : int
-         number of train-val splits requested (+1 if a test holdout is specified)
+         Number of [train, val, test] splits requested
 
      Raises
      ------
-     IndexError
-         Raised if more partitions are requested than number of labels. This is exceedingly rare and
-         usually means you've specified some argument incorrectly.
      ValueError
-         Raised if the labels are considered continuous by Scikit-Learn. This does not necessarily
-         mean that floats are not accepted as a label format. Rather, this exception implies that
-         there are too many unique values in the set relative to it's cardinality.
-
-     Returns
-     -------
-     index : np.ndarray
-         Integer index generated based on the total number of labels
-     labels : np.ndarray
-         labels, converted to an ndarray if passed as a list.
+         When more partitions are requested than number of labels.
+     ValueError
+         When the labels are considered continuous by Scikit-Learn. This does not necessarily
+         mean that floats are not accepted as a label format. Rather, this implies that
+         there are too many unique values in the set relative to its cardinality.
      """
+
      if len(labels) <= total_partitions:
-         raise IndexError(f"""
-             Total number of labels must greater than the number of total partitions.
-             Got {len(labels)} labels and {total_partitions} total train/val/test partitions.""")
-     if isinstance(labels, list):
-         labels = np.array(labels)
+         raise ValueError(
+             "Total number of labels must be greater than the total number of partitions. "
+             f"Got {len(labels)} labels and {total_partitions} total [train, val, test] partitions."
+         )
+
      if type_of_target(labels) == "continuous":
-         raise ValueError("Detected continuous labels, labels must be discrete for proper stratification")
-     index = np.arange(len(labels))
-     return index, labels
+         raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")
 
 
- def check_stratifiable(labels: NDArray[np.int_], total_partitions: int) -> bool:
+ def is_stratifiable(labels: NDArray[np.int_], num_partitions: int) -> bool:
      """
-     Very basic check to see if dataset can be stratified by class label. This is not a
-     comprehensive test, as factors such as grouping also affect the ability to stratify by label
+     Check if the dataset can be stratified by class label over the given number of partitions
 
      Parameters
      ----------
-     labels : list or np.ndarray
-         all class labels from the input dataset
-     total_partitions : int
-         number of train-val splits requested (+1 if a test holdout is specified)
+     labels : NDArray of ints
+         All class labels of the input dataset
+     num_partitions : int
+         Total number of [train, val, test] splits requested
+
+     Returns
+     -------
+     bool
+         True if dataset can be stratified else False
 
      Warns
      -----
      UserWarning
-         Warns user if the dataset cannot be stratified due to the number of total (train, val, test)
+         Warns user if the dataset cannot be stratified due to the total number of [train, val, test]
          partitions exceeding the number of instances of the rarest class label.
-
-     Returns
-     -------
-     stratifiable : bool
-         True if dataset can be stratified according to the criteria above.
      """
 
-     stratifiable = True
-     _, label_counts = np.unique(labels, return_counts=True)
-     rarest_label_count = label_counts.min()
-     if rarest_label_count < total_partitions:
-         warnings.warn(f"""
-             Unable to stratify due to label frequency. The rarest label occurs {rarest_label_count},
-             which is fewer than the total number of partitions requested. Setting stratify flag to
-             false.""")
-         stratifiable = False
-     return stratifiable
+     # Get the minimum count of all labels
+     lowest_label_count = np.unique(labels, return_counts=True)[1].min()
+     if lowest_label_count < num_partitions:
+         warnings.warn(
+             f"Unable to stratify due to label frequency. The lowest label count ({lowest_label_count}) is fewer "
+             f"than the total number of partitions ({num_partitions}) requested.",
+             UserWarning,
+         )
+         return False
+     return True
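For intuition, a hedged example of the renamed check (invented labels; `is_stratifiable` is the function above): a class with only two instances cannot appear in four partitions, so the check warns and returns False, and `split_dataset` clears its stratify flag.

```python
import numpy as np

labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])      # rarest class: 2 instances
print(is_stratifiable(labels, num_partitions=4))  # False, with a UserWarning
print(is_stratifiable(labels, num_partitions=2))  # True
```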
 
 
- def check_groups(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
+ def is_groupable(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
      """
      Warns user if the number of unique group_ids is incompatible with a grouped partition containing
      num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
@@ -142,34 +174,35 @@ def check_groups(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
 
      Parameters
      ----------
-     group_ids : np.ndarray
-         Identifies the group to which a sample at the same index belongs.
+     group_ids : NDArray of ints
+         The id of the group each sample at the corresponding index belongs to
      num_partitions : int
-         How many total (train, val) folds will be generated (+1 if also specifying a test fold).
+         Total number of train, val, and test splits requested
+
+     Returns
+     -------
+     bool
+         True if the dataset can be grouped by the given group ids else False
 
      Warns
      -----
      UserWarning
-         Warns if there are fewer groups than the minimum required to successfully partition the data
-         into num_partitions. The minimum is defined as the number of partitions requested plus one.
-
-     Returns
-     -------
-     groupable : bool
-         True if dataset can be grouped by the given group ids, given the criteria above.
+         Warns if there are fewer groups than the requested number of partitions plus one
      """
 
-     groupable = True
      num_unique_groups = len(np.unique(group_ids))
-     min_unique_groups = num_partitions + 1
-     if num_unique_groups < min_unique_groups:
-         warnings.warn(f"""
-             {min_unique_groups} unique groups required for {num_partitions} partitions.
-             Found {num_unique_groups} instead. Reverting to ungrouped partitioning""")
-         groupable = False
-     else:
-         groupable = True
-     return groupable
+     # Cannot separate if only one group exists
+     if num_unique_groups == 1:
+         return False
+
+     if num_unique_groups < num_partitions:
+         warnings.warn(
+             f"Groups must be greater than num partitions. Got {num_unique_groups} and {num_partitions}. "
+             "Reverting to ungrouped partitioning",
+             UserWarning,
+         )
+         return False
+     return True
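A similar hedged sketch for the grouping check (made-up group ids): one lone group can never be separated, and fewer groups than partitions triggers the fallback warning.

```python
import numpy as np

print(is_groupable(np.zeros(10, dtype=np.int_), num_partitions=2))   # False: single group
print(is_groupable(np.array([0, 1, 0, 1, 2, 2]), num_partitions=5))  # False, with a UserWarning
print(is_groupable(np.array([0, 1, 0, 1, 2, 2]), num_partitions=3))  # True
```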
 
 
  def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
@@ -179,14 +212,15 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
 
      Parameters
      ----------
-     array : np.ndarray
+     array : NDArray
          continuous data to bin
 
      Returns
      -------
-     np.ndarray[int]: bin numbers assigned by the kmeans best clusterer.
+     NDArray[int]:
+         bin numbers assigned by the kmeans best clusterer.
      """
-     array = np.array(array)
+
      if array.ndim == 1:
          array = array.reshape([-1, 1])
      best_score = 0.60
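Dropping the `np.array(array)` coercion means callers are now expected to pass an ndarray directly. A hedged usage sketch (synthetic data; the clustering loop below this hunk is unchanged):

```python
import numpy as np

rng = np.random.default_rng(0)
values = np.concatenate([rng.normal(0.0, 0.1, 50), rng.normal(5.0, 0.1, 50)])
bins = bin_kmeans(values)   # 1-D input is reshaped to a column internally
print(np.unique(bins))      # expected: one bin id per well-separated cluster, e.g. [0 1]
```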
@@ -203,28 +237,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
      return bin_index
 
 
- def angle2xy(angles: NDArray[Any]) -> NDArray[Any]:
-     """
-     Converts angle measurements to xy coordinates on the unit circle. Needed for binning angle data.
-
-     Parameters
-     ----------
-     angles : np.ndarray
-         angle data in either radians or degrees
-
-     Returns
-     -------
-     xy : np.ndarray
-         Nx2 array of xy coordinates for each angle (can be radians or degrees)
-     """
-     is_radians = ((angles >= -np.pi) & (angles <= 2 * np.pi)).all()
-     radians = angles if is_radians else np.pi / 180 * angles
-     xy = np.stack([np.cos(radians), np.sin(radians)], axis=1)
-     return xy
-
-
  def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.int_]:
-     """Returns individual group numbers based on a subset of metadata defined by groupnames
+     """
+     Returns individual group numbers based on a subset of metadata defined by groupnames
 
      Parameters
      ----------
@@ -242,7 +257,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
 
      Returns
      -------
-     group_ids : np.ndarray
+     np.ndarray
          group identifiers from metadata
      """
      features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
@@ -250,11 +265,12 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
          return np.zeros(num_samples, dtype=np.int_)
      for name, feature in features2group.items():
          if len(feature) != num_samples:
-             raise IndexError(f"""Feature length does not match number of labels.
-                 Got {len(feature)} features and {num_samples} samples""")
+             raise ValueError(
+                 f"Feature length does not match number of labels. "
+                 f"Got {len(feature)} features and {num_samples} samples"
+             )
+
          if type_of_target(feature) == "continuous":
-             if ("ANGLE" in name.upper()) or ("AZIMUTH" in name.upper()):
-                 feature = angle2xy(feature)
              features2group[name] = bin_kmeans(feature)
      binned_features = np.stack(list(features2group.values()), axis=1)
      _, group_ids = np.unique(binned_features, axis=0, return_inverse=True)
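With `angle2xy` removed, continuous features (including angles) now go straight to `bin_kmeans`; discrete features are used as-is. A hedged sketch of the grouping on purely discrete metadata (invented dict):

```python
import numpy as np

metadata = {"site": ["A", "A", "B", "B", "C"], "sensor": [1, 1, 1, 2, 2]}  # hypothetical
group_ids = get_group_ids(metadata, group_names=["site", "sensor"], num_samples=5)
print(group_ids)  # one id per unique (site, sensor) pair, e.g. [0 0 1 2 3]
```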
@@ -265,228 +281,239 @@ def make_splits(
      index: NDArray[np.int_],
      labels: NDArray[np.int_],
      n_folds: int,
-     groups: NDArray[np.int_] | None = None,
-     stratified: bool = False,
- ) -> list[dict[str, NDArray[np.int_]]]:
-     """Split data into n_folds partitions of training and validation data.
+     groups: NDArray[np.int_] | None,
+     stratified: bool,
+ ) -> list[TrainValSplit]:
+     """
+     Split data into n_folds partitions of training and validation data.
 
      Parameters
      ----------
-     index : np.ndarray
-         index corresponding to each label (see below)
-     labels : np.ndarray
+     index : NDArray of ints
+         index corresponding to each label
+     labels : NDArray of ints
          classification labels
      n_folds : int
-         number or train/val folds
-     groups : np.ndarray, Optional
+         number of [train, val] folds
+     groups : NDArray of ints or None
          group index for grouped partitions. Grouped partitions are split such that no group id is
          present in both a training and validation split.
-     stratified : bool, default=False
-         If True, maintain dataset class balance within each train/val split
+     stratified : bool
+         If True, maintain dataset class balance within each [train, val] split
 
      Returns
      -------
-     split_defs : list[dict]
-         list of dictionaries, which specifying train index, validation index, and the ratio of
+     split_defs : list[TrainValSplit]
+         List of TrainValSplits, which specify train index, validation index, and the ratio of
          validation to all data.
      """
-     split_defs = []
-     index = index.reshape([-1, 1])
-     if groups is not None:
-         splitter = StratifiedGroupKFold(n_folds) if stratified else GroupKFold(n_folds)
+     split_defs: list[TrainValSplit] = []
+     n_labels = len(np.unique(labels))
+     splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, stratified)](n_folds)
+     good = False
+     attempts = 0
+     while not good and attempts < 3:
+         attempts += 1
          splits = splitter.split(index, labels, groups)
-     else:
-         splitter = StratifiedKFold(n_folds) if stratified else KFold(n_folds)
-         splits = splitter.split(index, labels)
-     for train_idx, eval_idx in splits:
-         test_ratio = len(eval_idx) / index.shape[0]
-         split_defs.append(
-             {
-                 "train": train_idx.astype(np.int_),
-                 "eval": eval_idx.astype(np.int_),
-                 "eval_frac": test_ratio,
-             }
-         )
+         split_defs.clear()
+         for train_idx, eval_idx in splits:
+             # test_ratio = len(eval_idx) / len(index)
+             t = np.atleast_1d(train_idx).astype(np.int_)
+             v = np.atleast_1d(eval_idx).astype(np.int_)
+             good = good or (len(np.unique(labels[t])) == n_labels and len(np.unique(labels[v])) == n_labels)
+             split_defs.append(TrainValSplit(t, v))
+     if not good and attempts == 3:
+         warnings.warn("Unable to create a good split definition, not all classes are represented in each split.")
      return split_defs
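The rewritten loop re-runs the splitter up to three times, accepting the result once at least one fold represents every class in both its train and val halves, and warning otherwise. A hedged sketch of a direct call (toy arrays; real callers go through `split_dataset`):

```python
import numpy as np

labels = np.array([0, 0, 1, 1, 0, 1, 0, 1])   # toy binary labels
index = np.arange(len(labels))
for fold in make_splits(index, labels, n_folds=2, groups=None, stratified=True):
    print(fold.train, fold.val)               # each fold is a TrainValSplit
```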
 
 
  def find_best_split(
-     labels: NDArray[np.int_], split_defs: list[dict[str, NDArray[np.int_]]], stratified: bool, eval_frac: float
- ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
-     """Finds the split that most closely satisfies a criterion determined by the arguments passed.
+     labels: NDArray[np.int_], split_defs: list[TrainValSplit], stratified: bool, split_frac: float
+ ) -> TrainValSplit:
+     """
+     Finds the split that most closely satisfies a criterion determined by the arguments passed.
      If stratified is True, returns the split whose class balance most closely resembles the overall
-     class balance. If false, returns the split with the size closest to the desired eval_frac
+     class balance. If false, returns the split with the size closest to the desired split_frac
 
      Parameters
      ----------
      labels : np.ndarray
          Labels upon which splits are (optionally) stratified
-     split_defs : list[dict]
-         List of dictionaries, which specifying train index, validation index, and the ratio of
-         validation to all data.
+     split_defs : list of TrainValSplits
+         Specifies the train index, validation index
      stratified : bool
-         If True, maintain dataset class balance within each train/val split
-     eval_frac : float
+         If True, maintain dataset class balance within each [train, val] split
+     split_frac : float
          Desired fraction of the dataset sequestered for evaluation
 
      Returns
      -------
-     train_index : np.ndarray
-         indices of data partitioned for training
-     eval_index : np.ndarray
-         indices of data partitioned for evaluation
+     TrainValSplit
+         Indices of data partitioned for training and evaluation
      """
 
-     def class_freq_diff(split):
-         train_labels = labels[split["train"]]
-         _, train_counts = np.unique(train_labels, return_counts=True)
-         train_freq = train_counts / train_counts.sum()
-         return np.square(train_freq - class_freq).sum()
+     # Minimization functions and helpers
+     def freq(arr: NDArray[Any], minlength: int = 0) -> NDArray[np.floating[Any]]:
+         counts = np.bincount(arr, minlength=minlength)
+         return counts / np.sum(counts)
+
+     def weight(arr: NDArray, class_freq: NDArray) -> np.float64:
+         return np.sum(np.abs(freq(arr, len(class_freq)) - class_freq))
+
+     def class_freq_diff(split: TrainValSplit) -> np.float64:
+         class_freq = freq(labels)
+         return weight(labels[split.train], class_freq) + weight(labels[split.val], class_freq)
+
+     def split_ratio(split: TrainValSplit) -> np.float64:
+         return np.float64(len(split.val) / (len(split.val) + len(split.train)))
 
+     def split_diff(split: TrainValSplit) -> np.float64:
+         return abs(split_frac - split_ratio(split))
+
+     def split_inv_diff(split: TrainValSplit) -> np.float64:
+         return abs(1 - split_frac - split_ratio(split))
+
+     # Selects minimization function based on inputs
      if stratified:
-         _, class_counts = np.unique(labels, return_counts=True)
-         class_freq = class_counts / class_counts.sum()
-         best_split = min(split_defs, key=class_freq_diff)
-         return best_split["train"], best_split["eval"]
-     elif eval_frac <= 2 / 3:
-         best_split = min(split_defs, key=lambda x: abs(eval_frac - x["eval_frac"]))  # type: ignore
-         return best_split["train"], best_split["eval"]
+         key_func = class_freq_diff
+     elif split_frac <= 2 / 3:
+         key_func = split_diff
      else:
-         best_split = min(split_defs, key=lambda x: abs(eval_frac - (1 - x["eval_frac"])))  # type: ignore
-         return best_split["eval"], best_split["train"]
+         key_func = split_inv_diff
+
+     return min(split_defs, key=key_func)
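The stratified criterion is now an L1 distance between fold and overall class frequencies, summed over train and val. A worked toy example (hypothetical numbers, using the helpers above):

```python
import numpy as np

labels = np.array([0, 0, 0, 1])  # overall class frequencies: [0.75, 0.25]
split = TrainValSplit(train=np.array([0, 1, 3]), val=np.array([2]))
# train freq [2/3, 1/3], val freq [1.0, 0.0]
# score = (|2/3 - 0.75| + |1/3 - 0.25|) + (|1.0 - 0.75| + |0.0 - 0.25|) ~= 0.667
print(find_best_split(labels, [split], stratified=True, split_frac=0.25))
```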
 
 
  def single_split(
      index: NDArray[np.int_],
      labels: NDArray[np.int_],
-     eval_frac: float,
+     split_frac: float,
      groups: NDArray[np.int_] | None = None,
      stratified: bool = False,
- ) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
-     """Handles the special case where only 1 partition of the data is desired (such as when
+ ) -> TrainValSplit:
+     """
+     Handles the special case where only 1 partition of the data is desired (such as when
      generating the test holdout split). In this case, the desired fraction of the data to be
-     partitioned into the test data must be specified, and a single [train, eval] pair are returned.
+     partitioned into the test data must be specified, and a single [train, val] pair is returned.
 
      Parameters
      ----------
-     index : np.ndarray
+     index : NDArray of ints
          Input Dataset index corresponding to each label
-     labels : np.ndarray
+     labels : NDArray of ints
          Labels upon which splits are (optionally) stratified
-     eval_frac : float
+     split_frac : float
          Fraction of incoming data to be set aside for evaluation
-     groups : np.ndarray, Optional
+     groups : NDArray of ints, Optional
          Group_ids (same shape as labels) for optional group partitioning
-     stratified : bool, default=False
+     stratified : bool, default False
          Generates stratified splits if true (recommended)
 
      Returns
      -------
-     train_index : np.ndarray
-         indices of data partitioned for training
-     eval_index : np.ndarray
-         indices of data partitioned for evaluation
+     TrainValSplit
+         Indices of data partitioned for training and evaluation
      """
-     if groups is not None:
-         n_unique_groups = np.unique(groups).shape[0]
-         _, label_counts = np.unique(labels, return_counts=True)
-         n_folds = min(n_unique_groups, label_counts.min())
-     elif eval_frac <= 2 / 3:
-         n_folds = max(2, int(round(1 / (eval_frac + 1e-6))))
-     else:
-         n_folds = max(2, int(round(1 / (1 - eval_frac - 1e-6))))
+
+     _, label_counts = np.unique(labels, return_counts=True)
+     max_folds = label_counts.min()
+     min_folds = np.unique(groups).shape[0] if groups is not None else 2
+     divisor = split_frac + 1e-06 if split_frac <= 2 / 3 else 1 - split_frac - 1e-06
+     n_folds = round(min(max(1 / divisor, min_folds), max_folds))  # Clips value between min_folds and max_folds
+
      split_candidates = make_splits(index, labels, n_folds, groups, stratified)
-     best_train, best_eval = find_best_split(labels, split_candidates, stratified, eval_frac)
-     return best_train, best_eval
+     return find_best_split(labels, split_candidates, stratified, split_frac)
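The three-branch fold selection collapses into one clipped expression. A worked check of the arithmetic (toy numbers only, assuming the expression above):

```python
split_frac, min_folds, max_folds = 0.2, 2, 8
divisor = split_frac + 1e-06 if split_frac <= 2 / 3 else 1 - split_frac - 1e-06
print(round(min(max(1 / divisor, min_folds), max_folds)))  # 5 (1 / 0.200001 ~= 5)

max_folds = 3  # rarest class has only 3 instances, so the fold count is clipped down
print(round(min(max(1 / divisor, min_folds), max_folds)))  # 3
```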
 
 
+ @set_metadata
  def split_dataset(
      labels: list[int] | NDArray[np.int_],
      num_folds: int = 1,
      stratify: bool = False,
      split_on: list[str] | None = None,
      metadata: dict[str, Any] | None = None,
-     test_frac: float | None = None,
-     val_frac: float | None = None,
- ) -> dict[str, dict[str, NDArray[np.int_]] | NDArray[np.int_]]:
-     """Top level splitting function. Returns a dict with each key-value pair containing
-     train and validation indices. Indices for a test holdout may also be optionally included
+     test_frac: float = 0.0,
+     val_frac: float = 0.0,
+ ) -> SplitDatasetOutput:
+     """
+     Top level splitting function. Returns a dataclass containing a list of train and validation indices.
+     Indices for a test holdout may also be optionally included
 
      Parameters
      ----------
-     labels : Union[list, np.ndarray]
+     labels : list or NDArray of ints
          Classification Labels used to generate splits. Determines the size of the dataset
-     num_folds : int, optional
-         Number of train/val folds. If None, returns a single train/val split, and val_frac must be
-         specified. Defaults to None.
-     stratify : bool, default=False
+     num_folds : int, default 1
+         Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0
+     stratify : bool, default False
          If true, dataset is split such that the class distribution of the entire dataset is
-         preserved within each train/val partition, which is generally recommended.
-     split_on : list, optional
-         Keys of the metadata dictionary which map to columns upon which to group the dataset.
+         preserved within each [train, val] partition, which is generally recommended.
+     split_on : list or None, default None
+         Keys of the metadata dictionary upon which to group the dataset.
          A grouped partition is divided such that no group is present within both the training and
-         validation set. Split_on groups should be selected to mitigate validation bias. Defaults to
-         None, in which groups will not be considered when partitioning the data.
-     metadata : dict, optional
-         Dict containing data for potential dataset grouping. See split_on above. Defaults to None.
-     test_frac : float, optional
-         Fraction of data to be optionally held out for test set. Defaults to None, in which no test
-         set is created.
-     val_frac : float, optional
+         validation set. Split_on groups should be selected to mitigate validation bias
+     metadata : dict or None, default None
+         Dict containing data for potential dataset grouping. See split_on above
+     test_frac : float, default 0.0
+         Fraction of data to be optionally held out for test set
+     val_frac : float, default 0.0
          Fraction of training data to be set aside for validation in the case where a single
-         train/val split is desired. Defaults to None.
+         [train, val] split is desired
+
+     Returns
+     -------
+     split_defs : SplitDatasetOutput
+         Output class containing a list of indices of training
+         and validation data for each fold and optional test indices
 
      Raises
      ------
-     UnboundLocalError
-         Raised if split_on is passed, but metadata is left as None. This is because split_on
-         defines the keys in which metadata dict must be indexed to determine the group index of the
-         data
+     TypeError
+         Raised if split_on is passed, but metadata is None or empty
 
-     Returns
-     -------
-     split_defs : dict
-         dictionary of folds, each containing indices of training and validation data.
-         ex.
-         {
-             "Fold_00": {
-                 "train": [1,2,3,5,6,7,9,10,11],
-                 "val": [0, 4, 8, 12]
-             },
-             "test": [13, 14, 15, 16]
-         }
+     Note
+     ----
+     When specifying groups and/or stratification, ratios for test and validation splits can vary
+     as the stratification and grouping take higher priority than the percentages
      """
 
-     test_frac, val_frac = validate_test_val(num_folds, test_frac, val_frac)
+     val_frac = calculate_validation_fraction(num_folds, test_frac, val_frac)
      total_partitions = num_folds + 1 if test_frac else num_folds
-     index, labels = check_labels(labels, total_partitions)
-     stratify &= check_stratifiable(labels, total_partitions)
+
+     if isinstance(labels, list):
+         labels = np.array(labels, dtype=np.int_)
+
+     label_length: int = len(labels)
+
+     _validate_labels(labels, total_partitions)
+     stratify &= is_stratifiable(labels, total_partitions)
+     groups = None
      if split_on:
-         if metadata is None:
-             raise UnboundLocalError("If split_on is specified, metadata must also be provided")
-         groups = get_group_ids(metadata, split_on, len(labels))
-         groupable = check_groups(groups, total_partitions)
-         if not groupable:
-             groups = None
-     else:
-         groups = None
-     split_defs: dict[str, dict[str, NDArray[np.int_]] | NDArray[np.int_]] = {}
-     if test_frac:
-         tv_idx, test_idx = single_split(index, labels, test_frac, groups, stratify)
-         tv_labels = labels[tv_idx]
-         tv_groups = groups[tv_idx] if groups is not None else None
-         split_defs["test"] = test_idx
-     else:
-         tv_idx = np.arange(len(labels)).reshape((-1, 1))
-         tv_labels = labels
-         tv_groups = groups
+         if metadata is None or metadata == {}:
+             raise TypeError("If split_on is specified, metadata must also be provided, got None")
+         possible_groups = get_group_ids(metadata, split_on, label_length)
+         # Accounts for a test set that is 100 % of the data
+         group_partitions = total_partitions + 1 if val_frac else total_partitions
+         if is_groupable(possible_groups, group_partitions):
+             groups = possible_groups
+
+     test_indices: NDArray[np.int_]
+     index = np.arange(label_length)
+
+     tv_indices, test_indices = (
+         single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
+         if test_frac
+         else (index, np.array([], dtype=np.int_))
+     )
+
+     tv_labels = labels[tv_indices]
+     tv_groups = groups[tv_indices] if groups is not None else None
+
      if num_folds == 1:
-         train_idx, val_idx = single_split(tv_idx, tv_labels, val_frac, tv_groups, stratify)
-         split_defs["fold_0"] = {"train": tv_idx[train_idx].squeeze(), "val": tv_idx[val_idx].squeeze()}
+         tv_splits = [single_split(tv_indices, tv_labels, val_frac, tv_groups, stratify)]
      else:
-         tv_splits = make_splits(tv_idx, tv_labels, num_folds, tv_groups, stratify)
-         for i, split in enumerate(tv_splits):
-             train_split = tv_idx[split["train"]]
-             val_split = tv_idx[split["eval"]]
-             split_defs[f"fold_{i}"] = {"train": train_split.squeeze(), "val": val_split.squeeze()}
-     return split_defs
+         tv_splits = make_splits(tv_indices, tv_labels, num_folds, tv_groups, stratify)
+
+     folds: list[TrainValSplit] = [TrainValSplit(tv_indices[split.train], tv_indices[split.val]) for split in tv_splits]
+
+     return SplitDatasetOutput(test_indices, folds)
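Taken together, 0.74.2 replaces the old nested-dict return with typed outputs. A hedged end-to-end sketch against the new signature (toy labels; exact index counts can vary per the Note above):

```python
import numpy as np
from dataeval.utils.split_dataset import split_dataset

labels = np.repeat(np.arange(4), 25)             # 100 toy samples, 4 balanced classes
out = split_dataset(labels, num_folds=3, stratify=True, test_frac=0.2)

print(out.test.shape)                            # roughly 20 held-out test indices
for i, fold in enumerate(out.folds):             # each fold is a TrainValSplit
    print(i, fold.train.shape, fold.val.shape)
```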