dataeval 0.74.1__py3-none-any.whl → 0.75.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. dataeval/__init__.py +33 -10
  2. dataeval/detectors/__init__.py +2 -2
  3. dataeval/detectors/drift/__init__.py +14 -12
  4. dataeval/detectors/drift/base.py +1 -1
  5. dataeval/detectors/drift/cvm.py +1 -1
  6. dataeval/detectors/drift/ks.py +1 -1
  7. dataeval/detectors/drift/mmd.py +6 -5
  8. dataeval/detectors/drift/torch.py +12 -12
  9. dataeval/detectors/drift/uncertainty.py +3 -2
  10. dataeval/detectors/linters/__init__.py +4 -4
  11. dataeval/detectors/linters/clusterer.py +2 -7
  12. dataeval/detectors/linters/duplicates.py +6 -10
  13. dataeval/detectors/linters/outliers.py +4 -2
  14. dataeval/detectors/ood/__init__.py +3 -10
  15. dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
  16. dataeval/detectors/ood/base.py +64 -161
  17. dataeval/detectors/ood/metadata_ks_compare.py +34 -42
  18. dataeval/detectors/ood/metadata_least_likely.py +3 -3
  19. dataeval/detectors/ood/metadata_ood_mi.py +6 -5
  20. dataeval/detectors/ood/mixin.py +146 -0
  21. dataeval/detectors/ood/output.py +63 -0
  22. dataeval/interop.py +16 -3
  23. dataeval/log.py +18 -0
  24. dataeval/metrics/__init__.py +2 -2
  25. dataeval/metrics/bias/__init__.py +9 -12
  26. dataeval/metrics/bias/balance.py +10 -8
  27. dataeval/metrics/bias/coverage.py +52 -4
  28. dataeval/metrics/bias/diversity.py +42 -14
  29. dataeval/metrics/bias/parity.py +15 -12
  30. dataeval/metrics/estimators/__init__.py +2 -2
  31. dataeval/metrics/estimators/ber.py +3 -1
  32. dataeval/metrics/estimators/divergence.py +1 -1
  33. dataeval/metrics/estimators/uap.py +1 -1
  34. dataeval/metrics/stats/__init__.py +18 -18
  35. dataeval/metrics/stats/base.py +4 -4
  36. dataeval/metrics/stats/boxratiostats.py +8 -9
  37. dataeval/metrics/stats/datasetstats.py +10 -14
  38. dataeval/metrics/stats/dimensionstats.py +4 -4
  39. dataeval/metrics/stats/hashstats.py +12 -8
  40. dataeval/metrics/stats/labelstats.py +5 -5
  41. dataeval/metrics/stats/pixelstats.py +4 -9
  42. dataeval/metrics/stats/visualstats.py +4 -9
  43. dataeval/output.py +1 -1
  44. dataeval/utils/__init__.py +4 -13
  45. dataeval/utils/dataset/__init__.py +7 -0
  46. dataeval/utils/{torch → dataset}/datasets.py +2 -0
  47. dataeval/utils/dataset/read.py +63 -0
  48. dataeval/utils/dataset/split.py +527 -0
  49. dataeval/utils/image.py +2 -2
  50. dataeval/utils/metadata.py +310 -5
  51. dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +1 -104
  52. dataeval/utils/torch/__init__.py +2 -17
  53. dataeval/utils/torch/gmm.py +29 -6
  54. dataeval/utils/torch/{utils.py → internal.py} +82 -58
  55. dataeval/utils/torch/models.py +10 -8
  56. dataeval/utils/torch/trainer.py +6 -85
  57. dataeval/workflows/__init__.py +2 -5
  58. dataeval/workflows/sufficiency.py +16 -6
  59. dataeval-0.75.0.dist-info/METADATA +136 -0
  60. dataeval-0.75.0.dist-info/RECORD +67 -0
  61. dataeval/detectors/ood/base_torch.py +0 -109
  62. dataeval/metrics/bias/metadata_preprocessing.py +0 -285
  63. dataeval/utils/gmm.py +0 -26
  64. dataeval/utils/split_dataset.py +0 -492
  65. dataeval-0.74.1.dist-info/METADATA +0 -120
  66. dataeval-0.74.1.dist-info/RECORD +0 -65
  67. {dataeval-0.74.1.dist-info → dataeval-0.75.0.dist-info}/LICENSE.txt +0 -0
  68. {dataeval-0.74.1.dist-info → dataeval-0.75.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,527 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from typing import Any, Iterator, NamedTuple, Protocol
8
+
9
+ import numpy as np
10
+ from numpy.typing import NDArray
11
+ from sklearn.cluster import KMeans
12
+ from sklearn.metrics import silhouette_score
13
+ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
14
+ from sklearn.utils.multiclass import type_of_target
15
+
16
+ from dataeval.output import Output, set_metadata
17
+
18
+
19
class TrainValSplit(NamedTuple):
    """
    Pair of index arrays describing one training/validation partition.

    Attributes
    ----------
    train : NDArray[np.intp]
        Indices assigned to the training portion
    val : NDArray[np.intp]
        Indices assigned to the validation portion
    """

    train: NDArray[np.intp]  # indices used for training
    val: NDArray[np.intp]  # indices used for validation
24
+
25
+
26
@dataclass(frozen=True)
class SplitDatasetOutput(Output):
    """
    Immutable result of a dataset split: held-out test indices plus per-fold
    train/validation index pairs.

    Attributes
    ----------
    test : NDArray[np.intp]
        Indices reserved for the test set
    folds : list[TrainValSplit]
        One entry per fold, each holding that fold's train and validation indices
    """

    test: NDArray[np.intp]
    folds: list[TrainValSplit]
41
+
42
+
43
class KFoldSplitter(Protocol):
    """
    Structural type for the scikit-learn k-fold splitter classes used in this module.

    A matching class is constructed with ``n_splits`` and exposes ``split(X, y, groups)``,
    which yields pairs of (train, test) index arrays.
    """

    def __init__(self, n_splits: int): ...

    def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...
48
+
49
+
50
# Lookup of scikit-learn splitter class keyed by (grouped, stratified):
#   grouped    -> group ids were supplied, so no group may appear in both train and val
#   stratified -> class proportions should be preserved within each fold
KFOLD_GROUP_STRATIFIED_MAP: dict[tuple[bool, bool], type[KFoldSplitter]] = {
    (False, False): KFold,  # plain k-fold
    (False, True): StratifiedKFold,  # class-balanced folds
    (True, False): GroupKFold,  # group-disjoint folds
    (True, True): StratifiedGroupKFold,  # group-disjoint and class-balanced folds
}
56
+
57
+
58
def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: float) -> float:
    """
    Convert the requested validation setting into a fraction of the whole dataset.

    With a single fold the caller-supplied ``val_frac`` is used. With several folds, each fold's
    validation share is ``1 / num_folds`` and ``val_frac`` must be left at 0.0. Either way the
    share is scaled by the portion of data that remains after the test split.

    Parameters
    ----------
    num_folds : int
        Number of [train, val] cross-validation folds to generate (must be >= 1)
    test_frac : float
        Fraction of the data extracted for testing before the folds are created
    val_frac : float
        Fraction of the non-test data to use for validation. Required (non-zero) only when
        a single [train, val] split is requested.

    Returns
    -------
    float
        ``base * (1 / num_folds) * (1 - test_frac)``, where ``base`` is ``val_frac`` for a
        single fold and 1.0 otherwise

    Raises
    ------
    ValueError
        If ``num_folds`` is less than 1
    ValueError
        If ``test_frac`` lies outside [0.0, 1.0]
    ValueError
        If ``num_folds`` is not 1 and ``val_frac`` is non-zero
    ValueError
        If ``num_folds`` is 1 and ``val_frac`` is 0.0
    ValueError
        If ``num_folds`` is 1 and ``val_frac`` lies outside [0.0, 1.0]
    """
    if num_folds < 1:
        raise ValueError(f"Number of folds must be greater than or equal to 1, got {num_folds}")
    if test_frac < 0.0 or test_frac > 1.0:
        raise ValueError(f"test_frac out of bounds. Must be between 0.0 and 1.0, got {test_frac}")

    if num_folds != 1:
        # Multiple folds: each fold's validation share is implied by the fold count,
        # so an explicit val_frac is rejected.
        if val_frac != 0.0:
            raise ValueError("Can only specify val_frac when num_folds equals 1")
        fold_share = 1.0
    else:
        if val_frac == 0.0:
            raise ValueError("If num_folds equals 1, must assign a value to val_frac")
        if val_frac < 0.0 or val_frac > 1.0:
            raise ValueError(f"val_frac out of bounds. Must be between 0.0 and 1.0, got {val_frac}")
        fold_share = val_frac

    # Every factor lies in [0, 1], so the product does too
    return fold_share * (1.0 / num_folds) * (1.0 - test_frac)
109
+
110
+
111
def _validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
    """
    Reject label arrays that cannot be divided into the requested partitions.

    Parameters
    ----------
    labels : NDArray of ints
        Class label of every sample in the dataset
    total_partitions : int
        Number of [train, val, test] partitions requested

    Raises
    ------
    ValueError
        If there are not strictly more labels than partitions.
    ValueError
        If scikit-learn's ``type_of_target`` reports the labels as "continuous", which
        cannot be used for stratification.
    """
    n_labels = len(labels)
    if n_labels <= total_partitions:
        raise ValueError(
            "Total number of labels must be greater than the total number of partitions. "
            f"Got {n_labels} labels and {total_partitions} total [train, val, test] partitions."
        )

    target_kind = type_of_target(labels)
    if target_kind == "continuous":
        raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")
140
+
141
+
142
def is_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> bool:
    """
    Decide whether every class occurs often enough to appear in each partition.

    Parameters
    ----------
    labels : NDArray of ints
        Class label of every sample in the dataset
    num_partitions : int
        Total number of [train, val, test] partitions requested

    Returns
    -------
    bool
        True when the rarest class has at least ``num_partitions`` samples, otherwise False

    Warns
    -----
    UserWarning
        Emitted, and False returned, when the rarest class has fewer samples than
        ``num_partitions``.
    """
    _, class_counts = np.unique(labels, return_counts=True)
    rarest = class_counts.min()
    if rarest >= num_partitions:
        return True

    warnings.warn(
        f"Unable to stratify due to label frequency. The lowest label count ({rarest}) is fewer "
        f"than the total number of partitions ({num_partitions}) requested.",
        UserWarning,
    )
    return False
175
+
176
+
177
def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
    """
    Check whether the data can be partitioned so that no group spans two partitions.

    A False result tells the caller to fall back to ungrouped partitioning.

    Parameters
    ----------
    group_ids : NDArray of ints
        The id of the group each sample at the corresponding index belongs to
    num_partitions : int
        Total number of train, val, and test splits requested

    Returns
    -------
    bool
        True if the dataset can be grouped by the given group ids, else False

    Warns
    -----
    UserWarning
        If every sample shares a single group id (nothing to separate), or if there are
        fewer unique group ids than ``num_partitions``.
    """
    num_unique_groups = len(np.unique(group_ids))

    # A single group cannot be separated at all. Warn instead of falling back silently,
    # since the caller explicitly asked for grouping.
    if num_unique_groups == 1:
        warnings.warn(
            "Only one unique group found. Reverting to ungrouped partitioning",
            UserWarning,
        )
        return False

    if num_unique_groups < num_partitions:
        warnings.warn(
            f"Groups must be greater than num partitions. Got {num_unique_groups} and {num_partitions}. "
            "Reverting to ungrouped partitioning",
            UserWarning,
        )
        return False
    return True
214
+
215
+
216
def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
    """
    Find bins of continuous data by iteratively applying k-means clustering, and keeping the
    clustering with the highest silhouette score.

    Cluster counts from 2 to 19 are tried, capped at ``len(array) - 1`` for small inputs.
    A clustering replaces the current best only if its silhouette score beats the running
    best. The running best starts at 0.60 for 1-D input and 0.50 otherwise, so if no
    clustering clears that bar every sample is placed in bin 0.

    Parameters
    ----------
    array : NDArray
        Continuous data to bin. 1-D input is treated as a single feature column.

    Returns
    -------
    NDArray[np.intp]
        Bin number assigned to each sample by the best k-means clustering
        (all zeros when no clustering scored above the starting threshold).
    """
    if array.ndim == 1:
        array = array.reshape([-1, 1])
        best_score = 0.60
    else:
        best_score = 0.50
    n_samples = len(array)
    bin_index = np.zeros(n_samples, dtype=np.intp)
    # KMeans requires n_clusters <= n_samples and silhouette_score requires between 2 and
    # n_samples - 1 distinct labels, so never request more than n_samples - 1 clusters.
    for k in range(2, min(20, n_samples)):
        clusterer = KMeans(n_clusters=k)
        cluster_labels = clusterer.fit_predict(array)
        # Heavily duplicated data can collapse into a single cluster, for which the
        # silhouette score is undefined (silhouette_score raises).
        if len(np.unique(cluster_labels)) < 2:
            continue
        score = silhouette_score(array, cluster_labels, sample_size=25_000)
        if score > best_score:
            best_score = score
            bin_index = cluster_labels.astype(np.intp)
    return bin_index
246
+
247
+
248
def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.intp]:
    """
    Returns individual group numbers based on a subset of metadata defined by group_names

    Samples receive the same group id exactly when all selected metadata features are equal.
    Features that scikit-learn's ``type_of_target`` reports as continuous are first binned
    with ``bin_kmeans``.

    Parameters
    ----------
    metadata : dict
        Dictionary mapping metadata keys to per-sample values
    group_names : list
        Keys of ``metadata`` to consider for dataset grouping; keys absent from
        ``metadata`` are ignored
    num_samples : int
        Number of labels. Used to ensure agreement between input data/labels and metadata entries.

    Raises
    ------
    ValueError
        Raised if a selected metadata entry doesn't have the same length as num_samples

    Returns
    -------
    NDArray[np.intp]
        One group identifier per sample. All zeros when none of ``group_names`` is
        present in ``metadata``.
    """
    features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
    if not features2group:
        return np.zeros(num_samples, dtype=np.intp)
    for name, feature in features2group.items():
        if len(feature) != num_samples:
            raise ValueError(
                f"Feature length does not match number of labels. "
                f"Got {len(feature)} features and {num_samples} samples"
            )

        # Discretize continuous features so that nearby values share a group
        if type_of_target(feature) == "continuous":
            features2group[name] = bin_kmeans(feature)
    binned_features = np.stack(list(features2group.values()), axis=1)
    _, group_ids = np.unique(binned_features, axis=0, return_inverse=True)
    # NumPy 2.0.0 could return a non-1-D inverse when ``axis`` is given; flatten so callers
    # always receive exactly one id per sample.
    return np.asarray(group_ids, dtype=np.intp).reshape(-1)
286
+
287
+
288
def make_splits(
    index: NDArray[np.intp],
    labels: NDArray[np.intp],
    n_folds: int,
    groups: NDArray[np.intp] | None,
    stratified: bool,
) -> list[TrainValSplit]:
    """
    Split data into n_folds partitions of training and validation data.

    The scikit-learn splitter is looked up in KFOLD_GROUP_STRATIFIED_MAP by whether
    ``groups`` is given and whether ``stratified`` is set.

    Parameters
    ----------
    index : NDArray of ints
        index corresponding to each label (passed to the splitter as ``X``)
    labels : NDArray of ints
        classification labels
    n_folds : int
        number of [train, val] folds
    groups : NDArray of ints or None
        group index for grouped partitions. Grouped partitions are split such that no group id is
        present in both a training and validation split.
    stratified : bool
        If True, maintain dataset class balance within each [train, val] split

    Returns
    -------
    split_defs : list[TrainValSplit]
        One TrainValSplit per fold. The train and val arrays are positions into
        ``index``/``labels`` as yielded by the splitter, not values taken from ``index``.

    Warns
    -----
    UserWarning
        If, after three attempts, no fold has every class present in both its train and
        val parts.
    """
    split_defs: list[TrainValSplit] = []
    n_labels = len(np.unique(labels))
    splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, stratified)](n_folds)
    good = False
    attempts = 0
    # NOTE(review): the splitter is constructed without shuffle/random_state, so each
    # attempt likely yields identical folds; confirm whether the retry loop is still useful.
    while not good and attempts < 3:
        attempts += 1
        splits = splitter.split(index, labels, groups)
        split_defs.clear()
        for train_idx, eval_idx in splits:
            t = np.atleast_1d(train_idx).astype(np.intp)
            v = np.atleast_1d(eval_idx).astype(np.intp)
            # An attempt is "good" once ANY fold has every class in both its train and val parts
            good = good or (len(np.unique(labels[t])) == n_labels and len(np.unique(labels[v])) == n_labels)
            split_defs.append(TrainValSplit(t, v))
    if not good and attempts == 3:
        warnings.warn("Unable to create a good split definition, not all classes are represented in each split.")
    return split_defs
336
+
337
+
338
def find_best_split(
    labels: NDArray[np.intp], split_defs: list[TrainValSplit], stratified: bool, split_frac: float
) -> TrainValSplit:
    """
    Finds the split that most closely satisfies a criterion determined by the arguments passed.

    - If ``stratified`` is True, returns the split whose train and val class frequencies
      deviate least (summed absolute difference) from the overall class frequencies.
    - Otherwise, if ``split_frac <= 2/3``, returns the split whose val fraction is closest
      to ``split_frac``.
    - Otherwise, returns the split whose val fraction is closest to ``1 - split_frac``.

    Parameters
    ----------
    labels : np.ndarray
        Integer labels upon which splits are (optionally) stratified. They are counted with
        ``np.bincount`` when stratifying, so they must be non-negative in that case.
    split_defs : list of TrainValSplits
        Candidate splits; each split's train and val arrays are positions into ``labels``
    stratified : bool
        If True, maintain dataset class balance within each [train, val] split
    split_frac : float
        Desired fraction of the dataset sequestered for evaluation

    Returns
    -------
    TrainValSplit
        The element of ``split_defs`` minimizing the selected criterion (the first one on ties)
    """

    def freq(arr: NDArray[Any], minlength: int = 0) -> NDArray[np.floating[Any]]:
        # Relative frequency of each non-negative integer value in arr
        counts = np.bincount(arr, minlength=minlength)
        return counts / np.sum(counts)

    def weight(arr: NDArray[Any], class_freq: NDArray[Any]) -> np.float64:
        # Total absolute deviation of arr's class frequencies from class_freq
        return np.sum(np.abs(freq(arr, len(class_freq)) - class_freq))

    def split_ratio(split: TrainValSplit) -> np.float64:
        # Fraction of the split's samples that fall in the val part
        return np.float64(len(split.val) / (len(split.val) + len(split.train)))

    def split_diff(split: TrainValSplit) -> np.float64:
        return abs(split_frac - split_ratio(split))

    def split_inv_diff(split: TrainValSplit) -> np.float64:
        return abs(1 - split_frac - split_ratio(split))

    # Selects minimization function based on inputs
    if stratified:
        # The overall class frequencies are identical for every candidate, so compute them
        # once here rather than on every key evaluation.
        overall_freq = freq(labels)

        def class_freq_diff(split: TrainValSplit) -> np.float64:
            return weight(labels[split.train], overall_freq) + weight(labels[split.val], overall_freq)

        key_func = class_freq_diff
    elif split_frac <= 2 / 3:
        key_func = split_diff
    else:
        key_func = split_inv_diff

    return min(split_defs, key=key_func)
393
+
394
+
395
def single_split(
    index: NDArray[np.intp],
    labels: NDArray[np.intp],
    split_frac: float,
    groups: NDArray[np.intp] | None = None,
    stratified: bool = False,
) -> TrainValSplit:
    """
    Handles the special case where only 1 partition of the data is desired (such as when
    generating the test holdout split). In this case, the desired fraction of the data to be
    partitioned into the test data must be specified, and a single [train, val] pair is returned.

    Candidate folds are generated with a k-fold splitter whose fold count is derived from
    ``split_frac``, and the best-matching fold is returned.

    Parameters
    ----------
    index : NDArray of ints
        Input Dataset index corresponding to each label
    labels : NDArray of ints
        Labels upon which splits are (optionally) stratified
    split_frac : float
        Fraction of incoming data to be set aside for evaluation
    groups : NDArray of ints, Optional
        Group_ids (same shape as labels) for optional group partitioning
    stratified : bool, default False
        Generates stratified splits if true (recommended)

    Returns
    -------
    TrainValSplit
        Positions into ``index``/``labels`` of the training part and of the evaluation part,
        the latter holding approximately ``split_frac`` of the data
    """
    _, label_counts = np.unique(labels, return_counts=True)
    # Upper bound on the fold count: the rarest class's count (so each fold can hold one of
    # its samples), and for grouped splitting the number of groups, because group splitters
    # cannot build more folds than there are groups.
    max_folds = label_counts.min()
    if groups is not None:
        max_folds = min(max_folds, np.unique(groups).shape[0])
    # Every scikit-learn k-fold splitter requires at least 2 folds
    min_folds = 2

    # A k-fold validation fold holds about 1/n_folds of the data. For small fractions target
    # split_frac directly; for large ones (> 2/3) target the complement (1 - split_frac) and
    # swap the chosen fold's train/val roles afterwards.
    invert = split_frac > 2 / 3
    divisor = 1 - split_frac - 1e-06 if invert else split_frac + 1e-06
    n_folds = round(min(max(1 / divisor, min_folds), max_folds))  # Clips value between min_folds and max_folds
    # max_folds can be 1 (a class with a single sample); a 1-fold split is always rejected
    n_folds = max(n_folds, min_folds)

    split_candidates = make_splits(index, labels, n_folds, groups, stratified)
    best = find_best_split(labels, split_candidates, stratified, split_frac)
    # In the inverted case the chosen fold's *train* part is the one sized ~split_frac
    return TrainValSplit(best.val, best.train) if invert else best
434
+
435
+
436
@set_metadata
def split_dataset(
    labels: list[int] | NDArray[np.intp],
    num_folds: int = 1,
    stratify: bool = False,
    split_on: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    test_frac: float = 0.0,
    val_frac: float = 0.0,
) -> SplitDatasetOutput:
    """
    Top level splitting function. Returns a dataclass containing a list of train and validation indices.
    Indices for a test holdout may also be optionally included

    Parameters
    ----------
    labels : list or NDArray of ints
        Classification Labels used to generate splits. Determines the size of the dataset
    num_folds : int, default 1
        Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0
    stratify : bool, default False
        If true, dataset is split such that the class distribution of the entire dataset is
        preserved within each [train, val] partition, which is generally recommended.
        Stratification is turned off, with a UserWarning, when the rarest class has fewer
        samples than the number of partitions.
    split_on : list or None, default None
        Keys of the metadata dictionary upon which to group the dataset.
        A grouped partition is divided such that no group is present within both the training and
        validation set. Split_on groups should be selected to mitigate validation bias
    metadata : dict or None, default None
        Dict containing data for potential dataset grouping. See split_on above
    test_frac : float, default 0.0
        Fraction of data to be optionally held out for test set
    val_frac : float, default 0.0
        Fraction of the non-test (training) data to be set aside for validation in the case where
        a single [train, val] split is desired. Must be 0.0 when num_folds > 1.

    Returns
    -------
    split_defs : SplitDatasetOutput
        Output class containing a list of indices of training
        and validation data for each fold and optional test indices

    Raises
    ------
    ValueError
        Raised for invalid num_folds, test_frac or val_frac, when there are not more labels than
        partitions, or when the labels are continuous
    TypeError
        Raised if split_on is passed, but metadata is None or empty

    Note
    ----
    When specifying groups and/or stratification, ratios for test and validation splits can vary
    as the stratification and grouping take higher priority than the percentages
    """
    # Validates the fold/fraction arguments. The returned value is the validation share of the
    # *whole* dataset; it is only used below to detect a test split that consumes all the data.
    val_frac_of_total = calculate_validation_fraction(num_folds, test_frac, val_frac)
    total_partitions = num_folds + 1 if test_frac else num_folds

    if isinstance(labels, list):
        labels = np.array(labels, dtype=np.intp)

    label_length: int = len(labels)

    _validate_labels(labels, total_partitions)
    # Short-circuit so the stratification check (and its warning) only runs when requested
    stratify = stratify and is_stratifiable(labels, total_partitions)
    groups = None
    if split_on:
        if metadata is None or metadata == {}:
            raise TypeError("If split_on is specified, metadata must also be provided, got None")
        possible_groups = get_group_ids(metadata, split_on, label_length)
        # Count an extra partition for validation unless test_frac == 1.0 leaves no data for it
        group_partitions = total_partitions + 1 if val_frac_of_total else total_partitions
        if is_groupable(possible_groups, group_partitions):
            groups = possible_groups

    test_indices: NDArray[np.intp]
    index = np.arange(label_length)

    tv_indices, test_indices = (
        single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
        if test_frac
        else (index, np.array([], dtype=np.intp))
    )

    tv_labels = labels[tv_indices]
    tv_groups = groups[tv_indices] if groups is not None else None

    if num_folds == 1:
        # single_split only sees the non-test data, and val_frac is documented as a fraction of
        # exactly that data, so pass the caller's value instead of the whole-dataset share
        # (which would apply the (1 - test_frac) scaling twice).
        tv_splits = [single_split(tv_indices, tv_labels, val_frac, tv_groups, stratify)]
    else:
        tv_splits = make_splits(tv_indices, tv_labels, num_folds, tv_groups, stratify)

    # Splits hold positions into the train/val subset; map them back to dataset indices
    folds: list[TrainValSplit] = [TrainValSplit(tv_indices[split.train], tv_indices[split.val]) for split in tv_splits]

    return SplitDatasetOutput(test_indices, folds)
dataeval/utils/image.py CHANGED
@@ -63,8 +63,8 @@ def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[np.uint8]:
63
63
  """
64
64
  Returns the image filtered using a 3x3 edge detection kernel:
65
65
  [[ -1, -1, -1 ],
66
- [ -1, 8, -1 ],
67
- [ -1, -1, -1 ]]
66
+ [ -1, 8, -1 ],
67
+ [ -1, -1, -1 ]]
68
68
  """
69
69
  edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
70
70
  np.clip(edges, 0, 255, edges)