dataeval 0.83.0__py3-none-any.whl → 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +3 -3
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/stats/_base.py +26 -30
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +3 -3
- dataeval/utils/_array.py +26 -1
- dataeval/utils/data/_dataset.py +2 -0
- dataeval/utils/data/_embeddings.py +99 -21
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_selection.py +3 -15
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/metadata.py +1 -1
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/METADATA +1 -1
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/RECORD +21 -20
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
@@ -3,14 +3,14 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import math
|
6
|
-
from typing import Any, Iterator, Sequence
|
6
|
+
from typing import Any, Iterator, Sequence, cast
|
7
7
|
|
8
8
|
import torch
|
9
9
|
from torch.utils.data import DataLoader, Subset
|
10
10
|
from tqdm import tqdm
|
11
11
|
|
12
12
|
from dataeval.config import DeviceLike, get_device
|
13
|
-
from dataeval.typing import Array, Dataset
|
13
|
+
from dataeval.typing import Array, Dataset, Transform
|
14
14
|
from dataeval.utils.torch.models import SupportsEncode
|
15
15
|
|
16
16
|
|
@@ -26,11 +26,15 @@ class Embeddings:
|
|
26
26
|
Dataset to access original images from.
|
27
27
|
batch_size : int
|
28
28
|
Batch size to use when encoding images.
|
29
|
+
transforms : Transform or Sequence[Transform] or None, default None
|
30
|
+
Transforms to apply to images before encoding.
|
29
31
|
model : torch.nn.Module or None, default None
|
30
32
|
Model to use for encoding images.
|
31
33
|
device : DeviceLike or None, default None
|
32
34
|
The hardware device to use if specified, otherwise uses the DataEval
|
33
35
|
default or torch default.
|
36
|
+
cache : bool, default False
|
37
|
+
Whether to cache the embeddings in memory.
|
34
38
|
verbose : bool, default False
|
35
39
|
Whether to print progress bar when encoding images.
|
36
40
|
"""
|
@@ -43,19 +47,27 @@ class Embeddings:
|
|
43
47
|
self,
|
44
48
|
dataset: Dataset[tuple[Array, Any, Any]],
|
45
49
|
batch_size: int,
|
50
|
+
transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
|
46
51
|
model: torch.nn.Module | None = None,
|
47
52
|
device: DeviceLike | None = None,
|
53
|
+
cache: bool = False,
|
48
54
|
verbose: bool = False,
|
49
55
|
) -> None:
|
50
56
|
self.device = get_device(device)
|
51
|
-
self.
|
57
|
+
self.cache = cache
|
58
|
+
self.batch_size = batch_size if batch_size > 0 else 1
|
52
59
|
self.verbose = verbose
|
53
60
|
|
54
61
|
self._dataset = dataset
|
62
|
+
self._length = len(dataset)
|
55
63
|
model = torch.nn.Flatten() if model is None else model
|
64
|
+
self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
|
56
65
|
self._model = model.to(self.device).eval()
|
57
66
|
self._encoder = model.encode if isinstance(model, SupportsEncode) else model
|
58
67
|
self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
|
68
|
+
self._cached_idx = set()
|
69
|
+
self._embeddings: torch.Tensor = torch.empty(())
|
70
|
+
self._shallow: bool = False
|
59
71
|
|
60
72
|
def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
|
61
73
|
"""
|
@@ -79,30 +91,96 @@ class Embeddings:
|
|
79
91
|
else:
|
80
92
|
return self[:]
|
81
93
|
|
82
|
-
|
83
|
-
|
94
|
+
@classmethod
|
95
|
+
def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
|
96
|
+
"""
|
97
|
+
Instantiates a shallow Embeddings object using an array.
|
98
|
+
|
99
|
+
Parameters
|
100
|
+
----------
|
101
|
+
array : Array
|
102
|
+
The array to convert to embeddings.
|
103
|
+
device : DeviceLike or None, default None
|
104
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
105
|
+
default or torch default.
|
106
|
+
|
107
|
+
Returns
|
108
|
+
-------
|
109
|
+
Embeddings
|
110
|
+
|
111
|
+
Example
|
112
|
+
-------
|
113
|
+
>>> import numpy as np
|
114
|
+
>>> from dataeval.utils.data._embeddings import Embeddings
|
115
|
+
>>> array = np.random.randn(100, 3, 224, 224)
|
116
|
+
>>> embeddings = Embeddings.from_array(array)
|
117
|
+
>>> print(embeddings.to_tensor().shape)
|
118
|
+
torch.Size([100, 3, 224, 224])
|
119
|
+
"""
|
120
|
+
embeddings = Embeddings([], 0, None, None, device, True, False)
|
121
|
+
embeddings._length = len(array)
|
122
|
+
embeddings._cached_idx = set(range(len(array)))
|
123
|
+
embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
|
124
|
+
embeddings._shallow = True
|
125
|
+
return embeddings
|
126
|
+
|
127
|
+
def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
|
128
|
+
if self._transforms:
|
129
|
+
images = [transform(image) for transform in self._transforms for image in images]
|
130
|
+
return self._encoder(torch.stack(images).to(self.device))
|
131
|
+
|
132
|
+
@torch.no_grad() # Reduce overhead cost by not tracking tensor gradients
|
84
133
|
def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
134
|
+
dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
|
135
|
+
total_batches = math.ceil(len(indices) / self.batch_size)
|
136
|
+
|
137
|
+
# If not caching, process all indices normally
|
138
|
+
if not self.cache:
|
139
|
+
for images in tqdm(
|
140
|
+
DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
|
141
|
+
total=total_batches,
|
142
|
+
desc="Batch embedding",
|
143
|
+
disable=not self.verbose,
|
144
|
+
):
|
145
|
+
yield self._encode(images)
|
146
|
+
return
|
147
|
+
|
148
|
+
# If caching, process each batch of indices at a time, preserving original order
|
149
|
+
for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
|
150
|
+
batch = indices[i : i + self.batch_size]
|
151
|
+
uncached = [idx for idx in batch if idx not in self._cached_idx]
|
152
|
+
|
153
|
+
if uncached:
|
154
|
+
# Process uncached indices as a single batch
|
155
|
+
for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
|
156
|
+
embeddings = self._encode(images)
|
157
|
+
|
158
|
+
if not self._embeddings.shape:
|
159
|
+
full_shape = (len(self._dataset), *embeddings.shape[1:])
|
160
|
+
self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
|
161
|
+
|
162
|
+
self._embeddings[uncached] = embeddings
|
163
|
+
self._cached_idx.update(uncached)
|
164
|
+
|
165
|
+
yield self._embeddings[batch]
|
94
166
|
|
95
167
|
def __getitem__(self, key: int | slice, /) -> torch.Tensor:
|
96
|
-
if isinstance(key, slice):
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
168
|
+
if not isinstance(key, slice) and not hasattr(key, "__int__"):
|
169
|
+
raise TypeError("Invalid argument type.")
|
170
|
+
|
171
|
+
if self._shallow:
|
172
|
+
if not self._embeddings.shape:
|
173
|
+
raise ValueError("Embeddings not initialized.")
|
174
|
+
return self._embeddings[key]
|
175
|
+
|
176
|
+
indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
|
177
|
+
result = torch.vstack(list(self._batch(indices))).to(self.device)
|
178
|
+
return result.squeeze(0) if len(indices) == 1 else result
|
101
179
|
|
102
180
|
def __iter__(self) -> Iterator[torch.Tensor]:
|
103
181
|
# process in batches while yielding individual embeddings
|
104
|
-
for batch in self._batch(range(
|
182
|
+
for batch in self._batch(range(self._length)):
|
105
183
|
yield from batch
|
106
184
|
|
107
185
|
def __len__(self) -> int:
|
108
|
-
return
|
186
|
+
return self._length
|
dataeval/utils/data/_images.py
CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
|
5
|
+
from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
|
6
6
|
|
7
|
-
from dataeval.typing import Dataset
|
7
|
+
from dataeval.typing import Array, Dataset
|
8
|
+
from dataeval.utils._array import as_numpy, channels_first_to_last
|
8
9
|
|
9
|
-
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from matplotlib.figure import Figure
|
12
|
+
|
13
|
+
T = TypeVar("T", bound=Array)
|
10
14
|
|
11
15
|
|
12
16
|
class Images(Generic[T]):
|
@@ -21,7 +25,10 @@ class Images(Generic[T]):
|
|
21
25
|
Dataset to access images from.
|
22
26
|
"""
|
23
27
|
|
24
|
-
def __init__(
|
28
|
+
def __init__(
|
29
|
+
self,
|
30
|
+
dataset: Dataset[tuple[T, Any, Any] | T],
|
31
|
+
) -> None:
|
25
32
|
self._is_tuple_datum = isinstance(dataset[0], tuple)
|
26
33
|
self._dataset = dataset
|
27
34
|
|
@@ -40,25 +47,41 @@ class Images(Generic[T]):
|
|
40
47
|
"""
|
41
48
|
return self[:]
|
42
49
|
|
50
|
+
def plot(
|
51
|
+
self,
|
52
|
+
indices: Sequence[int],
|
53
|
+
images_per_row: int = 3,
|
54
|
+
figsize: tuple[int, int] = (10, 10),
|
55
|
+
) -> Figure:
|
56
|
+
import matplotlib.pyplot as plt
|
57
|
+
|
58
|
+
num_images = len(indices)
|
59
|
+
num_rows = (num_images + images_per_row - 1) // images_per_row
|
60
|
+
fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
|
61
|
+
for i, ax in enumerate(axes.flatten()):
|
62
|
+
image = channels_first_to_last(as_numpy(self[i]))
|
63
|
+
ax.imshow(image)
|
64
|
+
ax.axis("off")
|
65
|
+
plt.tight_layout()
|
66
|
+
return fig
|
67
|
+
|
43
68
|
@overload
|
44
69
|
def __getitem__(self, key: int, /) -> T: ...
|
45
70
|
@overload
|
46
71
|
def __getitem__(self, key: slice, /) -> Sequence[T]: ...
|
47
72
|
|
48
73
|
def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
|
74
|
+
if isinstance(key, slice):
|
75
|
+
return [self._get_image(k) for k in range(len(self._dataset))[key]]
|
76
|
+
elif hasattr(key, "__int__"):
|
77
|
+
return self._get_image(int(key))
|
78
|
+
raise TypeError(f"Key must be integers or slices, not {type(key)}")
|
79
|
+
|
80
|
+
def _get_image(self, index: int) -> T:
|
49
81
|
if self._is_tuple_datum:
|
50
|
-
|
51
|
-
if isinstance(key, slice):
|
52
|
-
return [dataset[k][0] for k in range(len(self._dataset))[key]]
|
53
|
-
elif isinstance(key, int):
|
54
|
-
return dataset[key][0]
|
82
|
+
return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
|
55
83
|
else:
|
56
|
-
|
57
|
-
if isinstance(key, slice):
|
58
|
-
return [dataset[k] for k in range(len(self._dataset))[key]]
|
59
|
-
elif isinstance(key, int):
|
60
|
-
return dataset[key]
|
61
|
-
raise TypeError(f"Key must be integers or slices, not {type(key)}")
|
84
|
+
return cast(Dataset[T], self._dataset)[index]
|
62
85
|
|
63
86
|
def __iter__(self) -> Iterator[T]:
|
64
87
|
for i in range(len(self._dataset)):
|
@@ -5,7 +5,7 @@ __all__ = []
|
|
5
5
|
from enum import IntEnum
|
6
6
|
from typing import Generic, Iterator, Sequence, TypeVar
|
7
7
|
|
8
|
-
from dataeval.typing import AnnotatedDataset, DatasetMetadata
|
8
|
+
from dataeval.typing import AnnotatedDataset, DatasetMetadata
|
9
9
|
|
10
10
|
_TDatum = TypeVar("_TDatum")
|
11
11
|
|
@@ -35,8 +35,6 @@ class Select(AnnotatedDataset[_TDatum]):
|
|
35
35
|
The dataset to wrap.
|
36
36
|
selections : Selection or list[Selection], optional
|
37
37
|
The selection criteria to apply to the dataset.
|
38
|
-
transforms : Transform or list[Transform], optional
|
39
|
-
The transforms to apply to the dataset.
|
40
38
|
|
41
39
|
Examples
|
42
40
|
--------
|
@@ -70,16 +68,12 @@ class Select(AnnotatedDataset[_TDatum]):
|
|
70
68
|
self,
|
71
69
|
dataset: AnnotatedDataset[_TDatum],
|
72
70
|
selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
|
73
|
-
transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
|
74
71
|
) -> None:
|
75
72
|
self.__dict__.update(dataset.__dict__)
|
76
73
|
self._dataset = dataset
|
77
74
|
self._size_limit = len(dataset)
|
78
75
|
self._selection = list(range(self._size_limit))
|
79
76
|
self._selections = self._sort(selections)
|
80
|
-
self._transforms = (
|
81
|
-
[] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
|
82
|
-
)
|
83
77
|
|
84
78
|
# Ensure metadata is populated correctly as DatasetMetadata TypedDict
|
85
79
|
_metadata = getattr(dataset, "metadata", {})
|
@@ -98,8 +92,7 @@ class Select(AnnotatedDataset[_TDatum]):
|
|
98
92
|
title = f"{self.__class__.__name__} Dataset"
|
99
93
|
sep = "-" * len(title)
|
100
94
|
selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
|
101
|
-
|
102
|
-
return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
|
95
|
+
return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
|
103
96
|
|
104
97
|
def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
|
105
98
|
if not selections:
|
@@ -117,13 +110,8 @@ class Select(AnnotatedDataset[_TDatum]):
|
|
117
110
|
selection(self)
|
118
111
|
self._selection = self._selection[: self._size_limit]
|
119
112
|
|
120
|
-
def _transform(self, datum: _TDatum) -> _TDatum:
|
121
|
-
for t in self._transforms:
|
122
|
-
datum = t(datum)
|
123
|
-
return datum
|
124
|
-
|
125
113
|
def __getitem__(self, index: int) -> _TDatum:
|
126
|
-
return self.
|
114
|
+
return self._dataset[self._selection[index]]
|
127
115
|
|
128
116
|
def __iter__(self) -> Iterator[_TDatum]:
|
129
117
|
for i in range(len(self)):
|
dataeval/utils/data/_split.py
CHANGED
@@ -2,19 +2,22 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import logging
|
5
6
|
import warnings
|
6
|
-
from typing import Any, Iterator, Protocol
|
7
|
+
from typing import Any, Iterator, Protocol, Sequence
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import NDArray
|
10
|
-
from sklearn.cluster import KMeans
|
11
|
-
from sklearn.metrics import silhouette_score
|
12
11
|
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
|
13
12
|
from sklearn.utils.multiclass import type_of_target
|
14
13
|
|
15
|
-
from dataeval.config import
|
14
|
+
from dataeval.config import EPSILON
|
16
15
|
from dataeval.outputs._base import set_metadata
|
17
16
|
from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
|
17
|
+
from dataeval.typing import AnnotatedDataset
|
18
|
+
from dataeval.utils.data._metadata import Metadata
|
19
|
+
|
20
|
+
_logger = logging.getLogger(__name__)
|
18
21
|
|
19
22
|
|
20
23
|
class KFoldSplitter(Protocol):
|
@@ -85,7 +88,7 @@ def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: fl
|
|
85
88
|
return val_base * (1.0 / num_folds) * (1.0 - test_frac)
|
86
89
|
|
87
90
|
|
88
|
-
def
|
91
|
+
def validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
|
89
92
|
"""
|
90
93
|
Check to make sure there is more input data than the total number of partitions requested
|
91
94
|
|
@@ -116,7 +119,7 @@ def _validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
|
|
116
119
|
raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")
|
117
120
|
|
118
121
|
|
119
|
-
def
|
122
|
+
def validate_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> None:
|
120
123
|
"""
|
121
124
|
Check if the dataset can be stratified by class label over the given number of partitions
|
122
125
|
|
@@ -132,26 +135,23 @@ def is_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> bool:
|
|
132
135
|
bool
|
133
136
|
True if dataset can be stratified else False
|
134
137
|
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
138
|
+
Raises
|
139
|
+
------
|
140
|
+
ValueError
|
141
|
+
If the dataset cannot be stratified due to the total number of [train, val, test]
|
139
142
|
partitions exceeding the number of instances of the rarest class label.
|
140
143
|
"""
|
141
144
|
|
142
145
|
# Get the minimum count of all labels
|
143
146
|
lowest_label_count = np.unique(labels, return_counts=True)[1].min()
|
144
147
|
if lowest_label_count < num_partitions:
|
145
|
-
|
148
|
+
raise ValueError(
|
146
149
|
f"Unable to stratify due to label frequency. The lowest label count ({lowest_label_count}) is fewer "
|
147
|
-
f"than the total number of partitions ({num_partitions}) requested."
|
148
|
-
UserWarning,
|
150
|
+
f"than the total number of partitions ({num_partitions}) requested."
|
149
151
|
)
|
150
|
-
return False
|
151
|
-
return True
|
152
152
|
|
153
153
|
|
154
|
-
def
|
154
|
+
def validate_groupable(groups: NDArray[np.intp], num_partitions: int) -> None:
|
155
155
|
"""
|
156
156
|
Warns user if the number of unique group_ids is incompatible with a grouped partition containing
|
157
157
|
num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
|
@@ -159,7 +159,7 @@ def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
|
|
159
159
|
|
160
160
|
Parameters
|
161
161
|
----------
|
162
|
-
|
162
|
+
groups : NDArray of ints
|
163
163
|
The id of the group each sample at the corresponding index belongs to
|
164
164
|
num_partitions : int
|
165
165
|
Total number of train, val, and test splits requested
|
@@ -169,60 +169,24 @@ def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
|
|
169
169
|
bool
|
170
170
|
True if the dataset can be grouped by the given group ids else False
|
171
171
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
172
|
+
Raises
|
173
|
+
------
|
174
|
+
ValueError
|
175
|
+
If there is only one unique group.
|
176
|
+
ValueError
|
177
|
+
If there are fewer groups than the requested number of partitions plus one
|
176
178
|
"""
|
177
179
|
|
178
|
-
num_unique_groups = len(np.unique(
|
180
|
+
num_unique_groups = len(np.unique(groups))
|
179
181
|
# Cannot separate if only one group exists
|
180
182
|
if num_unique_groups == 1:
|
181
|
-
|
183
|
+
raise ValueError(f"Unique groups ({num_unique_groups}) must be greater than 1.")
|
182
184
|
|
183
185
|
if num_unique_groups < num_partitions:
|
184
|
-
|
185
|
-
f"Groups must be greater than num partitions. Got {num_unique_groups} and {num_partitions}. "
|
186
|
-
"Reverting to ungrouped partitioning",
|
187
|
-
UserWarning,
|
188
|
-
)
|
189
|
-
return False
|
190
|
-
return True
|
191
|
-
|
192
|
-
|
193
|
-
def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
|
194
|
-
"""
|
195
|
-
Find bins of continuous data by iteratively applying k-means clustering, and keeping the
|
196
|
-
clustering with the highest silhouette score.
|
197
|
-
|
198
|
-
Parameters
|
199
|
-
----------
|
200
|
-
array : NDArray
|
201
|
-
continuous data to bin
|
186
|
+
raise ValueError(f"Unique groups ({num_unique_groups}) must be greater than num partitions ({num_partitions}).")
|
202
187
|
|
203
|
-
Returns
|
204
|
-
-------
|
205
|
-
NDArray[int]:
|
206
|
-
bin numbers assigned by the kmeans best clusterer.
|
207
|
-
"""
|
208
188
|
|
209
|
-
|
210
|
-
array = array.reshape([-1, 1])
|
211
|
-
best_score = 0.60
|
212
|
-
else:
|
213
|
-
best_score = 0.50
|
214
|
-
bin_index = np.zeros(len(array), dtype=np.intp)
|
215
|
-
for k in range(2, 20):
|
216
|
-
clusterer = KMeans(n_clusters=k, random_state=get_seed())
|
217
|
-
cluster_labels = clusterer.fit_predict(array)
|
218
|
-
score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
|
219
|
-
if score > best_score:
|
220
|
-
best_score = score
|
221
|
-
bin_index = cluster_labels.astype(np.intp)
|
222
|
-
return bin_index
|
223
|
-
|
224
|
-
|
225
|
-
def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.intp]:
|
189
|
+
def get_groups(metadata: Metadata, split_on: Sequence[str] | None) -> NDArray[np.intp] | None:
|
226
190
|
"""
|
227
191
|
Returns individual group numbers based on a subset of metadata defined by groupnames
|
228
192
|
|
@@ -232,32 +196,20 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
|
|
232
196
|
dictionary containing all metadata
|
233
197
|
groupnames : list
|
234
198
|
which groups from the metadata dictionary to consider for dataset grouping
|
235
|
-
num_samples : int
|
236
|
-
number of labels. Used to ensure agreement between input data/labels and metadata entries.
|
237
|
-
|
238
|
-
Raises
|
239
|
-
------
|
240
|
-
IndexError
|
241
|
-
raised if an entry in the metadata dictionary doesn't have the same length as num_samples
|
242
199
|
|
243
200
|
Returns
|
244
201
|
-------
|
245
202
|
np.ndarray
|
246
203
|
group identifiers from metadata
|
247
204
|
"""
|
248
|
-
|
249
|
-
if
|
250
|
-
return
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
if type_of_target(feature) == "continuous":
|
258
|
-
features2group[name] = bin_kmeans(feature)
|
259
|
-
binned_features = np.stack(list(features2group.values()), axis=1)
|
260
|
-
_, group_ids = np.unique(binned_features, axis=0, return_inverse=True)
|
205
|
+
# get only the factors that are present in the metadata
|
206
|
+
if split_on is None:
|
207
|
+
return None
|
208
|
+
|
209
|
+
split_set = set(split_on)
|
210
|
+
indices = [i for i, name in enumerate(metadata.discrete_factor_names) if name in split_set]
|
211
|
+
binned_features = metadata.discrete_data[:, indices]
|
212
|
+
group_ids = np.unique(binned_features, axis=0, return_inverse=True)[1]
|
261
213
|
return group_ids
|
262
214
|
|
263
215
|
|
@@ -294,10 +246,18 @@ def make_splits(
|
|
294
246
|
split_defs: list[TrainValSplit] = []
|
295
247
|
n_labels = len(np.unique(labels))
|
296
248
|
splitter = KFOLD_GROUP_STRATIFIED_MAP[(groups is not None, stratified)](n_folds)
|
249
|
+
_logger.log(logging.DEBUG, f"splitter={splitter.__class__.__name__}(n_splits={n_folds})")
|
297
250
|
good = False
|
298
251
|
attempts = 0
|
299
252
|
while not good and attempts < 3:
|
300
253
|
attempts += 1
|
254
|
+
_logger.log(
|
255
|
+
logging.DEBUG,
|
256
|
+
f"attempt={attempts}: splitter.split("
|
257
|
+
+ f"index=arr(len={len(index)}, unique={np.unique(index)}), "
|
258
|
+
+ f"labels=arr(len={len(index)}, unique={np.unique(index)}), "
|
259
|
+
+ ("groups=None" if groups is None else f"groups=arr(len={len(groups)}, unique={np.unique(groups)}))"),
|
260
|
+
)
|
301
261
|
splits = splitter.split(index, labels, groups)
|
302
262
|
split_defs.clear()
|
303
263
|
for train_idx, eval_idx in splits:
|
@@ -341,20 +301,20 @@ def find_best_split(
|
|
341
301
|
counts = np.bincount(arr, minlength=minlength)
|
342
302
|
return counts / np.sum(counts)
|
343
303
|
|
344
|
-
def weight(arr: NDArray, class_freq: NDArray) ->
|
345
|
-
return np.sum(np.abs(freq(arr, len(class_freq)) - class_freq))
|
304
|
+
def weight(arr: NDArray, class_freq: NDArray) -> float:
|
305
|
+
return float(np.sum(np.abs(freq(arr, len(class_freq)) - class_freq)))
|
346
306
|
|
347
|
-
def class_freq_diff(split: TrainValSplit) ->
|
307
|
+
def class_freq_diff(split: TrainValSplit) -> float:
|
348
308
|
class_freq = freq(labels)
|
349
309
|
return weight(labels[split.train], class_freq) + weight(labels[split.val], class_freq)
|
350
310
|
|
351
|
-
def split_ratio(split: TrainValSplit) ->
|
352
|
-
return
|
311
|
+
def split_ratio(split: TrainValSplit) -> float:
|
312
|
+
return len(split.val) / (len(split.val) + len(split.train))
|
353
313
|
|
354
|
-
def split_diff(split: TrainValSplit) ->
|
314
|
+
def split_diff(split: TrainValSplit) -> float:
|
355
315
|
return abs(split_frac - split_ratio(split))
|
356
316
|
|
357
|
-
def split_inv_diff(split: TrainValSplit) ->
|
317
|
+
def split_inv_diff(split: TrainValSplit) -> float:
|
358
318
|
return abs(1 - split_frac - split_ratio(split))
|
359
319
|
|
360
320
|
# Selects minimization function based on inputs
|
@@ -399,11 +359,12 @@ def single_split(
|
|
399
359
|
Indices of data partitioned for training and evaluation
|
400
360
|
"""
|
401
361
|
|
402
|
-
|
403
|
-
max_folds =
|
404
|
-
|
405
|
-
divisor = split_frac
|
406
|
-
n_folds =
|
362
|
+
unique_groups = 2 if groups is None else len(np.unique(groups))
|
363
|
+
max_folds = min(min(np.unique(labels, return_counts=True)[1]), unique_groups) if stratified else unique_groups
|
364
|
+
|
365
|
+
divisor = split_frac if split_frac <= 2 / 3 else 1 - split_frac
|
366
|
+
n_folds = min(max(round(1 / (divisor + EPSILON)), 2), max_folds) # Clips value between 2 and max_folds
|
367
|
+
_logger.log(logging.DEBUG, f"n_folds={n_folds} clipped between[2, {max_folds}]")
|
407
368
|
|
408
369
|
split_candidates = make_splits(index, labels, n_folds, groups, stratified)
|
409
370
|
return find_best_split(labels, split_candidates, stratified, split_frac)
|
@@ -411,22 +372,20 @@ def single_split(
|
|
411
372
|
|
412
373
|
@set_metadata
|
413
374
|
def split_dataset(
|
414
|
-
|
375
|
+
dataset: AnnotatedDataset[Any] | Metadata,
|
415
376
|
num_folds: int = 1,
|
416
377
|
stratify: bool = False,
|
417
|
-
split_on:
|
418
|
-
metadata: dict[str, Any] | None = None,
|
378
|
+
split_on: Sequence[str] | None = None,
|
419
379
|
test_frac: float = 0.0,
|
420
380
|
val_frac: float = 0.0,
|
421
381
|
) -> SplitDatasetOutput:
|
422
382
|
"""
|
423
|
-
|
424
|
-
Indices for a test holdout may also be optionally included
|
383
|
+
Dataset splitting function. Returns a dataclass containing a list of train and validation indices.
|
425
384
|
|
426
385
|
Parameters
|
427
386
|
----------
|
428
|
-
|
429
|
-
|
387
|
+
dataset : AnnotatedDataset or Metadata
|
388
|
+
Dataset to split.
|
430
389
|
num_folds : int, default 1
|
431
390
|
Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0
|
432
391
|
stratify : bool, default False
|
@@ -436,8 +395,6 @@ def split_dataset(
|
|
436
395
|
Keys of the metadata dictionary upon which to group the dataset.
|
437
396
|
A grouped partition is divided such that no group is present within both the training and
|
438
397
|
validation set. Split_on groups should be selected to mitigate validation bias
|
439
|
-
metadata : dict or None, default None
|
440
|
-
Dict containing data for potential dataset grouping. See split_on above
|
441
398
|
test_frac : float, default 0.0
|
442
399
|
Fraction of data to be optionally held out for test set
|
443
400
|
val_frac : float, default 0.0
|
@@ -450,13 +407,8 @@ def split_dataset(
|
|
450
407
|
Output class containing a list of indices of training
|
451
408
|
and validation data for each fold and optional test indices
|
452
409
|
|
453
|
-
|
454
|
-
|
455
|
-
TypeError
|
456
|
-
Raised if split_on is passed, but metadata is None or empty
|
457
|
-
|
458
|
-
Note
|
459
|
-
----
|
410
|
+
Notes
|
411
|
+
-----
|
460
412
|
When specifying groups and/or stratification, ratios for test and validation splits can vary
|
461
413
|
as the stratification and grouping take higher priority than the percentages
|
462
414
|
"""
|
@@ -464,30 +416,25 @@ def split_dataset(
|
|
464
416
|
val_frac = calculate_validation_fraction(num_folds, test_frac, val_frac)
|
465
417
|
total_partitions = num_folds + 1 if test_frac else num_folds
|
466
418
|
|
467
|
-
if isinstance(
|
468
|
-
|
419
|
+
metadata = dataset if isinstance(dataset, Metadata) else Metadata(dataset)
|
420
|
+
labels = metadata.class_labels
|
469
421
|
|
470
|
-
|
422
|
+
validate_labels(labels, total_partitions)
|
423
|
+
if stratify:
|
424
|
+
validate_stratifiable(labels, total_partitions)
|
471
425
|
|
472
|
-
|
473
|
-
|
474
|
-
groups = None
|
475
|
-
if split_on:
|
476
|
-
if metadata is None or metadata == {}:
|
477
|
-
raise TypeError("If split_on is specified, metadata must also be provided, got None")
|
478
|
-
possible_groups = get_group_ids(metadata, split_on, label_length)
|
426
|
+
groups = get_groups(metadata, split_on)
|
427
|
+
if groups is not None:
|
479
428
|
# Accounts for a test set that is 100 % of the data
|
480
429
|
group_partitions = total_partitions + 1 if val_frac else total_partitions
|
481
|
-
|
482
|
-
groups = possible_groups
|
430
|
+
validate_groupable(groups, group_partitions)
|
483
431
|
|
484
|
-
index = np.arange(
|
432
|
+
index = np.arange(len(labels))
|
485
433
|
|
486
|
-
|
487
|
-
single_split(index
|
488
|
-
|
489
|
-
|
490
|
-
)
|
434
|
+
if test_frac:
|
435
|
+
tvs = single_split(index, labels, test_frac, groups, stratify)
|
436
|
+
else:
|
437
|
+
tvs = TrainValSplit(index, np.array([], dtype=np.intp))
|
491
438
|
|
492
439
|
tv_labels = labels[tvs.train]
|
493
440
|
tv_groups = groups[tvs.train] if groups is not None else None
|