dataeval 0.74.1.tar.gz → 0.74.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {dataeval-0.74.1 → dataeval-0.74.2}/PKG-INFO +1 -1
  2. {dataeval-0.74.1 → dataeval-0.74.2}/pyproject.toml +1 -1
  3. dataeval-0.74.2/src/dataeval/__init__.py +36 -0
  4. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/interop.py +14 -2
  5. dataeval-0.74.2/src/dataeval/logging.py +16 -0
  6. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/output.py +1 -1
  7. dataeval-0.74.2/src/dataeval/utils/split_dataset.py +519 -0
  8. dataeval-0.74.1/src/dataeval/__init__.py +0 -17
  9. dataeval-0.74.1/src/dataeval/utils/split_dataset.py +0 -492
  10. {dataeval-0.74.1 → dataeval-0.74.2}/LICENSE.txt +0 -0
  11. {dataeval-0.74.1 → dataeval-0.74.2}/README.md +0 -0
  12. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/__init__.py +0 -0
  13. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/__init__.py +0 -0
  14. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/base.py +0 -0
  15. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/cvm.py +0 -0
  16. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/ks.py +0 -0
  17. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/mmd.py +0 -0
  18. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/torch.py +0 -0
  19. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/uncertainty.py +0 -0
  20. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/drift/updates.py +0 -0
  21. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/linters/__init__.py +0 -0
  22. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/linters/clusterer.py +0 -0
  23. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/linters/duplicates.py +0 -0
  24. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/linters/merged_stats.py +0 -0
  25. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/linters/outliers.py +0 -0
  26. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/__init__.py +0 -0
  27. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/ae_torch.py +0 -0
  28. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/base.py +0 -0
  29. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/base_torch.py +0 -0
  30. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/metadata_ks_compare.py +0 -0
  31. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/metadata_least_likely.py +0 -0
  32. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/detectors/ood/metadata_ood_mi.py +0 -0
  33. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/__init__.py +0 -0
  34. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/__init__.py +0 -0
  35. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/balance.py +0 -0
  36. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/coverage.py +0 -0
  37. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/diversity.py +0 -0
  38. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/metadata_preprocessing.py +0 -0
  39. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/metadata_utils.py +0 -0
  40. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/bias/parity.py +0 -0
  41. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/estimators/__init__.py +0 -0
  42. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/estimators/ber.py +0 -0
  43. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/estimators/divergence.py +0 -0
  44. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/estimators/uap.py +0 -0
  45. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/__init__.py +0 -0
  46. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/base.py +0 -0
  47. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/boxratiostats.py +0 -0
  48. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/datasetstats.py +0 -0
  49. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/dimensionstats.py +0 -0
  50. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/hashstats.py +0 -0
  51. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/labelstats.py +0 -0
  52. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/pixelstats.py +0 -0
  53. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/metrics/stats/visualstats.py +0 -0
  54. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/py.typed +0 -0
  55. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/__init__.py +0 -0
  56. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/gmm.py +0 -0
  57. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/image.py +0 -0
  58. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/metadata.py +0 -0
  59. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/shared.py +0 -0
  60. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/__init__.py +0 -0
  61. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/blocks.py +0 -0
  62. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/datasets.py +0 -0
  63. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/gmm.py +0 -0
  64. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/models.py +0 -0
  65. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/trainer.py +0 -0
  66. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/utils/torch/utils.py +0 -0
  67. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/workflows/__init__.py +0 -0
  68. {dataeval-0.74.1 → dataeval-0.74.2}/src/dataeval/workflows/sufficiency.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.74.1
+Version: 0.74.2
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.74.1" # dynamic
+version = "0.74.2" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"
@@ -0,0 +1,36 @@
+__version__ = "0.74.2"
+
+import logging
+from importlib.util import find_spec
+
+logging.getLogger(__name__).addHandler(logging.NullHandler())
+
+
+def log_stderr(level: int = logging.DEBUG) -> None:
+    """
+    Helper for quickly adding a StreamHandler to the logger. Useful for
+    debugging.
+    """
+    import logging
+
+    logger = logging.getLogger(__name__)
+    handler = logging.StreamHandler()
+    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
+    logger.addHandler(handler)
+    logger.setLevel(level)
+    logger.debug("Added a stderr logging handler to logger: %s", __name__)
+
+
+_IS_TORCH_AVAILABLE = find_spec("torch") is not None
+_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
+
+del find_spec
+
+from dataeval import detectors, metrics  # noqa: E402
+
+__all__ = ["log_stderr", "detectors", "metrics"]
+
+if _IS_TORCH_AVAILABLE:
+    from dataeval import utils, workflows
+
+    __all__ += ["utils", "workflows"]
@@ -1,23 +1,31 @@
 from __future__ import annotations
 
+from types import ModuleType
+
+from dataeval.logging import LogMessage
+
 __all__ = ["as_numpy", "to_numpy", "to_numpy_iter"]
 
+import logging
 from importlib import import_module
 from typing import Any, Iterable, Iterator
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
+_logger = logging.getLogger(__name__)
+
 _MODULE_CACHE = {}
 
 
-def _try_import(module_name):
+def _try_import(module_name) -> ModuleType | None:
     if module_name in _MODULE_CACHE:
         return _MODULE_CACHE[module_name]
 
     try:
         module = import_module(module_name)
     except ImportError:  # pragma: no cover - covered by test_mindeps.py
+        _logger.log(logging.INFO, f"Unable to import {module_name}.")
         module = None
 
     _MODULE_CACHE[module_name] = module
@@ -40,12 +48,16 @@ def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
     if array.__class__.__module__.startswith("tensorflow"):
         tf = _try_import("tensorflow")
         if tf and tf.is_tensor(array):
+            _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
             return array.numpy().copy() if copy else array.numpy()  # type: ignore
 
     if array.__class__.__module__.startswith("torch"):
         torch = _try_import("torch")
         if torch and isinstance(array, torch.Tensor):
-            return array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
+            _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
+            numpy = array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
+            _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(numpy)}"))
+            return numpy
 
     return np.array(array) if copy else np.asarray(array)
 
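
With a handler attached, the new INFO and DEBUG records in to_numpy surface during tensor conversion. A hedged sketch (assumes torch is installed; to_numpy is exported by dataeval.interop per the __all__ above):

    import logging

    import torch

    import dataeval
    from dataeval.interop import to_numpy

    dataeval.log_stderr(logging.DEBUG)

    # Emits "Converting PyTorch array to NumPy array." at INFO, plus the
    # deferred DEBUG record showing the tensor and the resulting ndarray.
    arr = to_numpy(torch.ones(2, 2))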
@@ -0,0 +1,16 @@
+from typing import Callable
+
+
+class LogMessage:
+    """
+    Deferred message callback for logging expensive messages.
+    """
+
+    def __init__(self, fn: Callable[..., str]):
+        self._fn = fn
+        self._str = None
+
+    def __str__(self) -> str:
+        if self._str is None:
+            self._str = self._fn()
+        return self._str
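
LogMessage exists because the stdlib logging module only calls str() on a message object when a handler actually formats the record, and logger.debug() short-circuits entirely when DEBUG is disabled, so wrapping a costly f-string in a callable defers (and caches) the work. An illustrative sketch (the payload variable is hypothetical):

    import logging

    from dataeval.logging import LogMessage

    logger = logging.getLogger("dataeval")

    payload = list(range(1_000_000))
    # The lambda (and the costly str(payload)) only runs if a handler formats
    # this record; the result is cached in LogMessage._str for reuse.
    logger.debug(LogMessage(lambda: f"payload -> {str(payload)[:60]}..."))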
@@ -65,7 +65,7 @@ R = TypeVar("R", bound=Output)
 
 
 def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
-    """Decorator to stamp OutputMetadata classes with runtime metadata"""
+    """Decorator to stamp Output classes with runtime metadata"""
 
     if fn is None:
         return partial(set_metadata, state=state)  # type: ignore
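
The docstring rename tracks the Output base class; decorator usage is unchanged. A brief sketch mirroring the pattern that split_dataset uses later in this diff (MyOutput and my_metric are hypothetical names):

    from dataclasses import dataclass

    from dataeval.output import Output, set_metadata


    @dataclass(frozen=True)
    class MyOutput(Output):
        value: float


    @set_metadata
    def my_metric() -> MyOutput:
        # set_metadata stamps the returned Output with runtime metadata
        return MyOutput(value=1.0)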
@@ -0,0 +1,519 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from dataeval.output import Output, set_metadata
+
+__all__ = ["split_dataset", "SplitDatasetOutput"]
+
+import warnings
+from typing import Any, Iterator, NamedTuple, Protocol
+
+import numpy as np
+from numpy.typing import NDArray
+from sklearn.cluster import KMeans
+from sklearn.metrics import silhouette_score
+from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
+from sklearn.utils.multiclass import type_of_target
+
+
+class TrainValSplit(NamedTuple):
+    """Tuple containing train and validation indices"""
+
+    train: NDArray[np.int_]
+    val: NDArray[np.int_]
+
+
+@dataclass(frozen=True)
+class SplitDatasetOutput(Output):
+    """Output class containing test indices and a list of TrainValSplits"""
+
+    test: NDArray[np.int_]
+    folds: list[TrainValSplit]
+
+
+class KFoldSplitter(Protocol):
+    """Protocol covering sklearn KFold variant splitters"""
+
+    def __init__(self, n_splits: int): ...
+    def split(self, X: Any, y: Any, groups: Any) -> Iterator[tuple[NDArray[Any], NDArray[Any]]]: ...
+
+
+KFOLD_GROUP_STRATIFIED_MAP: dict[tuple[bool, bool], type[KFoldSplitter]] = {
+    (False, False): KFold,
+    (False, True): StratifiedKFold,
+    (True, False): GroupKFold,
+    (True, True): StratifiedGroupKFold,
+}
+
+
+def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: float) -> float:
+    """
+    Calculate possible validation fraction based on the number of folds and test fraction.
+
+    Parameters
+    ----------
+    num_folds : int
+        number of train and validation cross-validation folds to generate
+    test_frac : float
+        The fraction of the data to extract for testing before folds are created
+    val_frac : float
+        The validation split will contain (val_frac * 100)% of any data not already allocated to the test set.
+        Only required if requesting a single [train, val] split.
+
+    Raises
+    ------
+    ValueError
+        When number of folds requested is less than 1
+    ValueError
+        When the test fraction is not within 0.0 and 1.0 inclusively
+    ValueError
+        When more than one fold and the validation fraction are both requested
+    ValueError
+        When number of folds equals one but the validation fraction is 0.0
+    ValueError
+        When the validation fraction is not within 0.0 and 1.0 inclusively
+
+    Returns
+    -------
+    float
+        The updated validation fraction of the remaining data after the testing fraction is removed
+    """
+    if num_folds < 1:
+        raise ValueError(f"Number of folds must be greater than or equal to 1, got {num_folds}")
+    if test_frac < 0.0 or test_frac > 1.0:
+        raise ValueError(f"test_frac out of bounds. Must be between 0.0 and 1.0, got {test_frac}")
+
+    # val base is a variable placeholder so val_frac can be ignored if num_folds != 1
+    val_base: float = 1.0
+    if num_folds == 1:
+        if val_frac == 0.0:
+            raise ValueError("If num_folds equals 1, must assign a value to val_frac")
+        if val_frac < 0.0 or val_frac > 1.0:
+            raise ValueError(f"val_frac out of bounds. Must be between 0.0 and 1.0, got {val_frac}")
+        val_base = val_frac
+    # num folds must be >1 in this case
+    elif val_frac != 0.0:
+        raise ValueError("Can only specify val_frac when num_folds equals 1")
+
+    # This value is mathematically bound between 0-1 inclusive
+    return val_base * (1.0 / num_folds) * (1.0 - test_frac)
+
+
+def _validate_labels(labels: NDArray[np.int_], total_partitions: int) -> None:
+    """
+    Check to make sure there is more input data than the total number of partitions requested
+
+    Parameters
+    ----------
+    labels : np.ndarray of ints
+        All class labels from the input dataset
+    total_partitions : int
+        Number of [train, val, test] splits requested
+
+    Raises
+    ------
+    ValueError
+        When more partitions are requested than number of labels.
+    ValueError
+        When the labels are considered continuous by Scikit-Learn. This does not necessarily
+        mean that floats are not accepted as a label format. Rather, this implies that
+        there are too many unique values in the set relative to its cardinality.
+    """
+
+    if len(labels) <= total_partitions:
+        raise ValueError(
+            "Total number of labels must be greater than the total number of partitions. "
+            f"Got {len(labels)} labels and {total_partitions} total [train, val, test] partitions."
+        )
+
+    if type_of_target(labels) == "continuous":
+        raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")
+
+
+def is_stratifiable(labels: NDArray[np.int_], num_partitions: int) -> bool:
+    """
+    Check if the dataset can be stratified by class label over the given number of partitions
+
+    Parameters
+    ----------
+    labels : NDArray of ints
+        All class labels of the input dataset
+    num_partitions : int
+        Total number of [train, val, test] splits requested
+
+    Returns
+    -------
+    bool
+        True if dataset can be stratified else False
+
+    Warns
+    -----
+    UserWarning
+        Warns user if the dataset cannot be stratified due to the total number of [train, val, test]
+        partitions exceeding the number of instances of the rarest class label.
+    """
+
+    # Get the minimum count of all labels
+    lowest_label_count = np.unique(labels, return_counts=True)[1].min()
+    if lowest_label_count < num_partitions:
+        warnings.warn(
+            f"Unable to stratify due to label frequency. The lowest label count ({lowest_label_count}) is fewer "
+            f"than the total number of partitions ({num_partitions}) requested.",
+            UserWarning,
+        )
+        return False
+    return True
+
+
+def is_groupable(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
+    """
+    Warns user if the number of unique group_ids is incompatible with a grouped partition containing
+    num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
+    group the input data.
+
+    Parameters
+    ----------
+    group_ids : NDArray of ints
+        The id of the group each sample at the corresponding index belongs to
+    num_partitions : int
+        Total number of train, val, and test splits requested
+
+    Returns
+    -------
+    bool
+        True if the dataset can be grouped by the given group ids else False
+
+    Warns
+    -----
+    UserWarning
+        Warns if there are fewer groups than the requested number of partitions plus one
+    """
+
+    num_unique_groups = len(np.unique(group_ids))
+    # Cannot separate if only one group exists
+    if num_unique_groups == 1:
+        return False
+
+    if num_unique_groups < num_partitions:
+        warnings.warn(
+            f"Groups must be greater than num partitions. Got {num_unique_groups} and {num_partitions}. "
+            "Reverting to ungrouped partitioning",
+            UserWarning,
+        )
+        return False
+    return True
+
+
+def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
+    """
+    Find bins of continuous data by iteratively applying k-means clustering, and keeping the
+    clustering with the highest silhouette score.
+
+    Parameters
+    ----------
+    array : NDArray
+        continuous data to bin
+
+    Returns
+    -------
+    NDArray[int]:
+        bin numbers assigned by the kmeans best clusterer.
+    """
+
+    if array.ndim == 1:
+        array = array.reshape([-1, 1])
+        best_score = 0.60
+    else:
+        best_score = 0.50
+    bin_index = np.zeros(len(array), dtype=np.int_)
+    for k in range(2, 20):
+        clusterer = KMeans(n_clusters=k)
+        cluster_labels = clusterer.fit_predict(array)
+        score = silhouette_score(array, cluster_labels, sample_size=25_000)
+        if score > best_score:
+            best_score = score
+            bin_index = cluster_labels.astype(np.int_)
+    return bin_index
+
+
+def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.int_]:
+    """
+    Returns individual group numbers based on a subset of metadata defined by group_names
+
+    Parameters
+    ----------
+    metadata : dict
+        dictionary containing all metadata
+    group_names : list
+        which groups from the metadata dictionary to consider for dataset grouping
+    num_samples : int
+        number of labels. Used to ensure agreement between input data/labels and metadata entries.
+
+    Raises
+    ------
+    ValueError
+        raised if an entry in the metadata dictionary doesn't have the same length as num_samples
+
+    Returns
+    -------
+    np.ndarray
+        group identifiers from metadata
+    """
+    features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
+    if not features2group:
+        return np.zeros(num_samples, dtype=np.int_)
+    for name, feature in features2group.items():
+        if len(feature) != num_samples:
+            raise ValueError(
+                f"Feature length does not match number of labels. "
+                f"Got {len(feature)} features and {num_samples} samples"
+            )
+
+        if type_of_target(feature) == "continuous":
+            features2group[name] = bin_kmeans(feature)
+    binned_features = np.stack(list(features2group.values()), axis=1)
+    _, group_ids = np.unique(binned_features, axis=0, return_inverse=True)
+    return group_ids
+
+
+def make_splits(
+    index: NDArray[np.int_],
+    labels: NDArray[np.int_],
+    n_folds: int,
+    groups: NDArray[np.int_] | None,
+    stratified: bool,
+) -> list[TrainValSplit]:
+    """
+    Split data into n_folds partitions of training and validation data.
+
+    Parameters
+    ----------
+    index : NDArray of ints
+        index corresponding to each label
+    labels : NDArray of ints
+        classification labels
+    n_folds : int
+        number of [train, val] folds
+    groups : NDArray of ints or None
+        group index for grouped partitions. Grouped partitions are split such that no group id is
+        present in both a training and validation split.
+    stratified : bool
+        If True, maintain dataset class balance within each [train, val] split
+
+    Returns
+    -------
+    split_defs : list[TrainValSplit]
+        List of TrainValSplits, which specify train index and validation index.
+    """
+    split_defs: list[TrainValSplit] = []
+    n_labels = len(np.unique(labels))
+    splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, stratified)](n_folds)
+    good = False
+    attempts = 0
+    while not good and attempts < 3:
+        attempts += 1
+        splits = splitter.split(index, labels, groups)
+        split_defs.clear()
+        for train_idx, eval_idx in splits:
+            # test_ratio = len(eval_idx) / len(index)
+            t = np.atleast_1d(train_idx).astype(np.int_)
+            v = np.atleast_1d(eval_idx).astype(np.int_)
+            good = good or (len(np.unique(labels[t])) == n_labels and len(np.unique(labels[v])) == n_labels)
+            split_defs.append(TrainValSplit(t, v))
+    if not good and attempts == 3:
+        warnings.warn("Unable to create a good split definition, not all classes are represented in each split.")
+    return split_defs
+
+
+def find_best_split(
+    labels: NDArray[np.int_], split_defs: list[TrainValSplit], stratified: bool, split_frac: float
+) -> TrainValSplit:
+    """
+    Finds the split that most closely satisfies a criterion determined by the arguments passed.
+    If stratified is True, returns the split whose class balance most closely resembles the overall
+    class balance. If false, returns the split with the size closest to the desired split_frac
+
+    Parameters
+    ----------
+    labels : np.ndarray
+        Labels upon which splits are (optionally) stratified
+    split_defs : list of TrainValSplits
+        Specifies the train index, validation index
+    stratified : bool
+        If True, maintain dataset class balance within each [train, val] split
+    split_frac : float
+        Desired fraction of the dataset sequestered for evaluation
+
+    Returns
+    -------
+    TrainValSplit
+        Indices of data partitioned for training and evaluation
+    """
+
+    # Minimization functions and helpers
+    def freq(arr: NDArray[Any], minlength: int = 0) -> NDArray[np.floating[Any]]:
+        counts = np.bincount(arr, minlength=minlength)
+        return counts / np.sum(counts)
+
+    def weight(arr: NDArray, class_freq: NDArray) -> np.float64:
+        return np.sum(np.abs(freq(arr, len(class_freq)) - class_freq))
+
+    def class_freq_diff(split: TrainValSplit) -> np.float64:
+        class_freq = freq(labels)
+        return weight(labels[split.train], class_freq) + weight(labels[split.val], class_freq)
+
+    def split_ratio(split: TrainValSplit) -> np.float64:
+        return np.float64(len(split.val) / (len(split.val) + len(split.train)))
+
+    def split_diff(split: TrainValSplit) -> np.float64:
+        return abs(split_frac - split_ratio(split))
+
+    def split_inv_diff(split: TrainValSplit) -> np.float64:
+        return abs(1 - split_frac - split_ratio(split))
+
+    # Selects minimization function based on inputs
+    if stratified:
+        key_func = class_freq_diff
+    elif split_frac <= 2 / 3:
+        key_func = split_diff
+    else:
+        key_func = split_inv_diff
+
+    return min(split_defs, key=key_func)
+
+
+def single_split(
+    index: NDArray[np.int_],
+    labels: NDArray[np.int_],
+    split_frac: float,
+    groups: NDArray[np.int_] | None = None,
+    stratified: bool = False,
+) -> TrainValSplit:
+    """
+    Handles the special case where only 1 partition of the data is desired (such as when
+    generating the test holdout split). In this case, the desired fraction of the data to be
+    partitioned into the test data must be specified, and a single [train, val] pair is returned.
+
+    Parameters
+    ----------
+    index : NDArray of ints
+        Input Dataset index corresponding to each label
+    labels : NDArray of ints
+        Labels upon which splits are (optionally) stratified
+    split_frac : float
+        Fraction of incoming data to be set aside for evaluation
+    groups : NDArray of ints, Optional
+        Group_ids (same shape as labels) for optional group partitioning
+    stratified : bool, default False
+        Generates stratified splits if true (recommended)
+
+    Returns
+    -------
+    TrainValSplit
+        Indices of data partitioned for training and evaluation
+    """
+
+    _, label_counts = np.unique(labels, return_counts=True)
+    max_folds = label_counts.min()
+    min_folds = np.unique(groups).shape[0] if groups is not None else 2
+    divisor = split_frac + 1e-06 if split_frac <= 2 / 3 else 1 - split_frac - 1e-06
+    n_folds = round(min(max(1 / divisor, min_folds), max_folds))  # Clips value between min_folds and max_folds
+
+    split_candidates = make_splits(index, labels, n_folds, groups, stratified)
+    return find_best_split(labels, split_candidates, stratified, split_frac)
+
+
+@set_metadata
+def split_dataset(
+    labels: list[int] | NDArray[np.int_],
+    num_folds: int = 1,
+    stratify: bool = False,
+    split_on: list[str] | None = None,
+    metadata: dict[str, Any] | None = None,
+    test_frac: float = 0.0,
+    val_frac: float = 0.0,
+) -> SplitDatasetOutput:
+    """
+    Top level splitting function. Returns a dataclass containing a list of train and validation indices.
+    Indices for a test holdout may also be optionally included
+
+    Parameters
+    ----------
+    labels : list or NDArray of ints
+        Classification labels used to generate splits. Determines the size of the dataset
+    num_folds : int, default 1
+        Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0
+    stratify : bool, default False
+        If true, dataset is split such that the class distribution of the entire dataset is
+        preserved within each [train, val] partition, which is generally recommended.
+    split_on : list or None, default None
+        Keys of the metadata dictionary upon which to group the dataset.
+        A grouped partition is divided such that no group is present within both the training and
+        validation set. split_on groups should be selected to mitigate validation bias
+    metadata : dict or None, default None
+        Dict containing data for potential dataset grouping. See split_on above
+    test_frac : float, default 0.0
+        Fraction of data to be optionally held out for test set
+    val_frac : float, default 0.0
+        Fraction of training data to be set aside for validation in the case where a single
+        [train, val] split is desired
+
+    Returns
+    -------
+    split_defs : SplitDatasetOutput
+        Output class containing a list of indices of training
+        and validation data for each fold and optional test indices
+
+    Raises
+    ------
+    TypeError
+        Raised if split_on is passed, but metadata is None or empty
+
+    Note
+    ----
+    When specifying groups and/or stratification, ratios for test and validation splits can vary
+    as the stratification and grouping take higher priority than the percentages
+    """
+
+    val_frac = calculate_validation_fraction(num_folds, test_frac, val_frac)
+    total_partitions = num_folds + 1 if test_frac else num_folds
+
+    if isinstance(labels, list):
+        labels = np.array(labels, dtype=np.int_)
+
+    label_length: int = len(labels)
+
+    _validate_labels(labels, total_partitions)
+    stratify &= is_stratifiable(labels, total_partitions)
+    groups = None
+    if split_on:
+        if metadata is None or metadata == {}:
+            raise TypeError("If split_on is specified, metadata must also be provided, got None")
+        possible_groups = get_group_ids(metadata, split_on, label_length)
+        # Accounts for a test set that is 100% of the data
+        group_partitions = total_partitions + 1 if val_frac else total_partitions
+        if is_groupable(possible_groups, group_partitions):
+            groups = possible_groups
+
+    test_indices: NDArray[np.int_]
+    index = np.arange(label_length)
+
+    tv_indices, test_indices = (
+        single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
+        if test_frac
+        else (index, np.array([], dtype=np.int_))
+    )
+
+    tv_labels = labels[tv_indices]
+    tv_groups = groups[tv_indices] if groups is not None else None
+
+    if num_folds == 1:
+        tv_splits = [single_split(tv_indices, tv_labels, val_frac, tv_groups, stratify)]
+    else:
+        tv_splits = make_splits(tv_indices, tv_labels, num_folds, tv_groups, stratify)
+
+    folds: list[TrainValSplit] = [TrainValSplit(tv_indices[split.train], tv_indices[split.val]) for split in tv_splits]
+
+    return SplitDatasetOutput(test_indices, folds)
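
End-to-end, the rewritten split_dataset returns a SplitDatasetOutput whose index arrays can feed plain array slicing or torch Subsets. A hedged usage sketch (assumes scikit-learn is installed; the label values are illustrative):

    import numpy as np

    from dataeval.utils.split_dataset import split_dataset

    rng = np.random.default_rng(0)
    labels = rng.integers(0, 3, size=300)

    # Five stratified [train, val] folds plus a 20% test holdout.
    result = split_dataset(labels, num_folds=5, stratify=True, test_frac=0.2)

    print(f"test size: {len(result.test)}")
    for i, fold in enumerate(result.folds):
        assert set(fold.train).isdisjoint(fold.val)  # train/val never overlap
        print(f"fold {i}: train={len(fold.train)} val={len(fold.val)}")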
@@ -1,17 +0,0 @@
-__version__ = "0.74.1"
-
-from importlib.util import find_spec
-
-_IS_TORCH_AVAILABLE = find_spec("torch") is not None
-_IS_TORCHVISION_AVAILABLE = find_spec("torchvision") is not None
-
-del find_spec
-
-from dataeval import detectors, metrics  # noqa: E402
-
-__all__ = ["detectors", "metrics"]
-
-if _IS_TORCH_AVAILABLE:
-    from dataeval import utils, workflows
-
-    __all__ += ["utils", "workflows"]