dataeval 0.86.8__py3-none-any.whl → 0.87.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/config.py +4 -19
- dataeval/data/_metadata.py +56 -27
- dataeval/data/_split.py +1 -1
- dataeval/data/selections/_classbalance.py +4 -3
- dataeval/data/selections/_classfilter.py +5 -5
- dataeval/data/selections/_indices.py +2 -2
- dataeval/data/selections/_prioritize.py +249 -29
- dataeval/data/selections/_reverse.py +1 -1
- dataeval/data/selections/_shuffle.py +2 -2
- dataeval/detectors/ood/__init__.py +2 -1
- dataeval/detectors/ood/base.py +38 -1
- dataeval/detectors/ood/knn.py +95 -0
- dataeval/metrics/bias/_balance.py +28 -21
- dataeval/metrics/bias/_diversity.py +4 -4
- dataeval/metrics/bias/_parity.py +2 -2
- dataeval/metrics/stats/_hashstats.py +19 -2
- dataeval/outputs/_workflows.py +20 -7
- dataeval/typing.py +14 -2
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_bin.py +7 -6
- dataeval/utils/data/__init__.py +2 -0
- dataeval/utils/data/_dataset.py +13 -6
- dataeval/utils/data/_validate.py +169 -0
- dataeval/workflows/sufficiency.py +53 -10
- {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/METADATA +5 -17
- {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/RECORD +30 -39
- dataeval/utils/datasets/__init__.py +0 -19
- dataeval/utils/datasets/_antiuav.py +0 -189
- dataeval/utils/datasets/_base.py +0 -262
- dataeval/utils/datasets/_cifar10.py +0 -201
- dataeval/utils/datasets/_fileio.py +0 -142
- dataeval/utils/datasets/_milco.py +0 -197
- dataeval/utils/datasets/_mixin.py +0 -54
- dataeval/utils/datasets/_mnist.py +0 -202
- dataeval/utils/datasets/_ships.py +0 -144
- dataeval/utils/datasets/_types.py +0 -48
- dataeval/utils/datasets/_voc.py +0 -583
- {dataeval-0.86.8.dist-info → dataeval-0.87.0.dist-info}/WHEEL +0 -0
- /dataeval-0.86.8.dist-info/licenses/LICENSE.txt → /dataeval-0.87.0.dist-info/licenses/LICENSE +0 -0
@@ -73,9 +73,9 @@ def balance(
     Return intra/interfactor balance (mutual information)

     >>> bal.factors
-    array([[1. , 0.
-    [0.
-    [0.015, 0.
+    array([[1. , 0. , 0.015],
+           [0. , 0.08 , 0.011],
+           [0.015, 0.011, 1.063]])

     Return classwise balance (mutual information) of factors with individual class_labels

@@ -95,32 +95,39 @@ def balance(

     num_neighbors = _validate_num_neighbors(num_neighbors)

-    data = metadata.discretized_data
     factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
     is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
     num_factors = len(factor_types)
     class_labels = metadata.class_labels

     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-
+
+    # Use numeric data for MI
+    data = np.hstack((class_labels[:, np.newaxis], metadata.digitized_data))
+
+    # Present discrete features composed of distinct values as continuous for `mutual_info_classif`
+    for i, factor_type in enumerate(factor_types):
+        if len(data) == len(np.unique(data[:, i])):
+            is_discrete[i] = False
+            factor_types[factor_type] = "continuous"
+
+    mutual_info_fn_map = {
+        "categorical": mutual_info_classif,
+        "discrete": mutual_info_classif,
+        "continuous": mutual_info_regression,
+    }

     for idx, factor_type in enumerate(factor_types.values()):
-
-
-
-
-
-
-
-
-
-
-            data,
-            data[:, idx],
-            discrete_features=is_discrete,  # type: ignore - sklearn function not typed
-            n_neighbors=num_neighbors,
-            random_state=get_seed(),
-        )
+        mi[idx, :] = mutual_info_fn_map[factor_type](
+            data,
+            data[:, idx],
+            discrete_features=is_discrete,
+            n_neighbors=num_neighbors,
+            random_state=get_seed(),
+        )
+
+    # Use binned data for classwise MI
+    data = np.hstack((class_labels[:, np.newaxis], metadata.binned_data))

     # Normalization via entropy
     bin_cnts = get_counts(data)
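The rewritten loop dispatches to an sklearn mutual-information estimator per factor type. A minimal standalone sketch of that dispatch on hypothetical toy data (random factors, n_neighbors=3; not DataEval's Metadata object):

    import numpy as np
    from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

    # Toy factor matrix: column 0 is the class label, column 1 a categorical
    # factor, column 2 a continuous factor (all values made up for illustration).
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 3, size=100)
    data = np.column_stack([labels, rng.integers(0, 4, size=100), rng.normal(size=100)])

    factor_types = ["categorical", "categorical", "continuous"]
    is_discrete = [t != "continuous" for t in factor_types]
    mi_fn = {"categorical": mutual_info_classif, "discrete": mutual_info_classif,
             "continuous": mutual_info_regression}

    mi = np.full((len(factor_types), len(factor_types)), np.nan)
    for idx, t in enumerate(factor_types):
        # Score column idx against every column, as in the new balance() loop
        mi[idx, :] = mi_fn[t](data, data[:, idx], discrete_features=is_discrete,
                              n_neighbors=3, random_state=0)
    print(np.round(mi, 3))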
dataeval/metrics/bias/_diversity.py CHANGED
@@ -162,12 +162,12 @@ def diversity(
         raise ValueError("No factors found in provided metadata.")

     diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-
+    binned_data = metadata.binned_data
     factor_names = metadata.factor_names
     class_lbl = metadata.class_labels

-
-    cnts = get_counts(
+    class_labels_with_binned_data = np.hstack((class_lbl[:, np.newaxis], binned_data))
+    cnts = get_counts(class_labels_with_binned_data)
     num_bins = np.bincount(np.nonzero(cnts)[1])
     diversity_index = diversity_fn(cnts, num_bins)

@@ -176,7 +176,7 @@ def diversity(
     classwise_div = np.full((len(u_classes), num_factors), np.nan)
     for idx, cls in enumerate(u_classes):
         subset_mask = class_lbl == cls
-        cls_cnts = get_counts(
+        cls_cnts = get_counts(binned_data[subset_mask], min_num_bins=cnts.shape[0])
         classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])

     return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
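diversity() reduces per-factor bin counts to an index through the selected method. As a rough illustration only (not necessarily DataEval's exact formula, which depends on `method`), a Shannon-evenness style index over one factor's bin counts looks like:

    import numpy as np

    def shannon_evenness(counts: np.ndarray, num_bins: int) -> float:
        # Normalized Shannon entropy of one factor's bin counts: 1.0 means
        # perfectly even coverage of the bins, 0.0 means a single bin is used.
        if num_bins <= 1:
            return 0.0
        p = counts / counts.sum()
        p = p[p > 0]
        return float(-(p * np.log(p)).sum() / np.log(num_bins))

    print(shannon_evenness(np.array([25, 25, 25, 25]), num_bins=4))  # 1.0
    print(shannon_evenness(np.array([97, 1, 1, 1]), num_bins=4))     # close to 0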
dataeval/metrics/bias/_parity.py CHANGED
@@ -245,10 +245,10 @@ def parity(metadata: Metadata) -> ParityOutput:
     if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

-    chi_scores = np.zeros(metadata.
+    chi_scores = np.zeros(metadata.binned_data.shape[1])
     p_values = np.zeros_like(chi_scores)
     insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
-    for i, col_data in enumerate(metadata.
+    for i, col_data in enumerate(metadata.binned_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
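Each iteration of the new loop pairs one binned factor column with the class labels and runs a chi-square test on the resulting contingency table. A small self-contained sketch of that step using scipy and toy arrays (values are hypothetical):

    import numpy as np
    from scipy.stats import chi2_contingency

    # Hypothetical binned factor values and class labels for 8 samples
    col_data = np.array([0, 0, 1, 1, 2, 2, 0, 1])
    class_labels = np.array([0, 1, 0, 1, 0, 1, 0, 1])

    # contingency[r, c] = number of samples with factor bin r and class c
    bins, classes = np.unique(col_data), np.unique(class_labels)
    contingency = np.array(
        [[np.sum((col_data == b) & (class_labels == c)) for c in classes] for b in bins]
    )

    chi2, p, dof, _ = chi2_contingency(contingency)
    print(chi2, p)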
dataeval/metrics/stats/_hashstats.py CHANGED
@@ -8,8 +8,9 @@ from typing import Any, Callable

 import numpy as np
 import xxhash as xxh
-from 
+from numpy.typing import NDArray
 from scipy.fftpack import dct
+from scipy.ndimage import zoom

 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import HashStatsOutput
@@ -18,10 +19,26 @@ from dataeval.typing import ArrayLike, Dataset
 from dataeval.utils._array import as_numpy
 from dataeval.utils._image import normalize_image_shape, rescale

+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+
 HASH_SIZE = 8
 MAX_FACTOR = 4


+def _resize(image: NDArray[np.uint8], resize_dim: int, use_pil: bool = True) -> NDArray[np.uint8]:
+    """Resizes a grayscale (HxW) 8-bit image using PIL or scipy.ndimage.zoom."""
+
+    # Use PIL if available, otherwise resize and resample with scipy.ndimage.zoom
+    if use_pil and Image is not None:
+        return np.array(Image.fromarray(image).resize((resize_dim, resize_dim), Image.Resampling.LANCZOS))
+
+    zoom_factors = (resize_dim / image.shape[0], resize_dim / image.shape[1])
+    return np.clip(zoom(image, zoom_factors, order=5, mode="reflect"), 0, 255, dtype=np.uint8)
+
+
 def pchash(image: ArrayLike) -> str:
     """
     Performs a perceptual hash on an image by resizing to a square NxN image
@@ -59,7 +76,7 @@ def pchash(image: ArrayLike) -> str:
     rescaled = rescale(normalized, 8).astype(np.uint8)

     # Resizes the image using the Lanczos algorithm to a square image
-    im = 
+    im = _resize(rescaled, resize_dim)

     # Performs discrete cosine transforms to compress the image information and takes the lowest frequency component
     transform = dct(dct(im.T).T)[:HASH_SIZE, :HASH_SIZE]
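The new `_resize` helper lets pchash fall back to scipy when Pillow is unavailable. A self-contained sketch of both resize paths on a random grayscale image; the filter choices mirror the diff, but treat the snippet as illustrative rather than DataEval's implementation:

    import numpy as np
    from scipy.ndimage import zoom

    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(64, 48), dtype=np.uint8)
    resize_dim = 32

    # scipy fallback: order-5 spline resampling, clipped back into the uint8 range
    factors = (resize_dim / image.shape[0], resize_dim / image.shape[1])
    resized = np.clip(zoom(image, factors, order=5, mode="reflect"), 0, 255).astype(np.uint8)

    # PIL path (only if Pillow is installed): Lanczos resampling to the same size
    try:
        from PIL import Image
        resized_pil = np.array(Image.fromarray(image).resize((resize_dim, resize_dim), Image.Resampling.LANCZOS))
    except ImportError:
        resized_pil = resized

    print(resized.shape, resized_pil.shape)  # (32, 32) (32, 32)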
dataeval/outputs/_workflows.py CHANGED
@@ -92,7 +92,7 @@ def plot_measure(
     return fig


-def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
+def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.int64]:
     """
     Inverse function for f_out()

@@ -106,13 +106,27 @@ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
     Returns
     -------
     NDArray
-
+        Sample size or -1 if unachievable for each data point
     """
-
-
+    with np.errstate(invalid="ignore"):
+        n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
+    unachievable_targets = np.isnan(n_i) | np.any(n_i > np.iinfo(np.int64).max)
+    if any(unachievable_targets):
+        with np.printoptions(suppress=True):
+            warnings.warn(
+                "Number of samples could not be determined for target(s): "
+                f"""{
+                    np.array2string(
+                        1 - y_i[unachievable_targets], separator=", ", formatter={"float": lambda x: f"{x}"}
+                    )
+                }""",
+                UserWarning,
+            )
+        n_i[unachievable_targets] = -1
+    return np.asarray(n_i, dtype=np.int64)


-def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.
+def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.int64]:
     """Inverse function for project_steps()

     Parameters
@@ -125,10 +139,9 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np
     Returns
     -------
     NDArray
-
+        Samples required or -1 if unachievable for each target value
     """
     steps = f_inv_out(1 - np.array(targets), params)
-    steps[np.isnan(steps)] = 0
     return np.ceil(steps)


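The new f_inv_out body inverts the fitted sufficiency curve. Assuming the forward model is f(n) = a * n^(-b) + c with params x = [a, b, c] (consistent with the expression above), the sample count for a target value y is n = ((y - c) / a)^(-1/b). A quick numeric check with hypothetical parameters:

    import numpy as np

    a, b, c = 0.9, 0.5, 0.05          # hypothetical fitted parameters [a, b, c]
    n = 400.0
    y = a * n ** (-b) + c             # forward: projected measure at n samples

    n_recovered = ((y - c) / a) ** (-1 / b)
    print(round(n_recovered))          # 400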
dataeval/typing.py CHANGED
@@ -3,11 +3,12 @@ Common type protocols used for interoperability with DataEval.
 """

 __all__ = [
+    "AnnotatedDataset",
     "Array",
     "ArrayLike",
     "Dataset",
-    "AnnotatedDataset",
     "DatasetMetadata",
+    "DeviceLike",
     "ImageClassificationDatum",
     "ImageClassificationDataset",
     "ObjectDetectionTarget",
@@ -21,9 +22,10 @@ __all__ = [


 import sys
-from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, runtime_checkable
+from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, Union, runtime_checkable

 import numpy.typing
+import torch
 from typing_extensions import NotRequired, ReadOnly, Required

 if sys.version_info >= (3, 10):
@@ -42,6 +44,16 @@ See Also
 """


+DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
+"""
+Type alias for a `Union` representing types that specify a torch.device.
+
+See Also
+--------
+`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
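DeviceLike is the new public alias for anything that can name a torch device. A hedged sketch of normalizing such values; the _to_device helper below is illustrative only, not part of DataEval's API:

    from typing import Union
    import torch

    DeviceLike = Union[int, str, tuple[str, int], torch.device]

    def _to_device(device: DeviceLike) -> torch.device:
        # Tuples like ("cuda", 0) map to torch.device("cuda", 0); ints and strings
        # are accepted directly by the torch.device constructor.
        if isinstance(device, tuple):
            return torch.device(*device)
        return torch.device(device)

    print(_to_device("cpu"), _to_device(("cpu", 0)))  # cpu cpu:0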
dataeval/utils/__init__.py CHANGED
dataeval/utils/_bin.py CHANGED
@@ -94,7 +94,7 @@ def bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
     return np.digitize(data, bin_edges)


-def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
+def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]] | None = None) -> bool:
     """
     Determines whether the data is continuous or discrete using the Wasserstein distance.

@@ -113,11 +113,12 @@ def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.numbe
     measured from a uniform distribution is greater or less than 0.054, respectively.
     """
     # Check if the metadata is image specific
-
-
-
-
-
+    if image_indices is not None:
+        _, data_indices_unsorted = np.unique(data, return_index=True)
+        if data_indices_unsorted.size == image_indices.size:
+            data_indices = np.sort(data_indices_unsorted)
+            if (data_indices == image_indices).all():
+                data = data[data_indices]

     n_examples = len(data)

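The restored branch only deduplicates per-image metadata when image_indices is supplied and lines up with the first occurrence of each value. A toy walkthrough of that check (values are hypothetical):

    import numpy as np

    # Hypothetical per-detection metadata where each image's value repeats
    # for every detection in that image.
    data = np.array([0.3, 0.3, 0.3, 0.7, 0.7, 0.9])
    image_indices = np.array([0, 3, 5])          # index of each image's first detection

    _, first_seen = np.unique(data, return_index=True)
    if first_seen.size == image_indices.size:
        first_seen = np.sort(first_seen)
        if (first_seen == image_indices).all():
            data = data[first_seen]              # collapse to one value per image

    print(data)  # [0.3 0.7 0.9]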
dataeval/utils/data/__init__.py CHANGED
@@ -2,10 +2,12 @@

 from dataeval.utils.data import collate, metadata
 from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
+from dataeval.utils.data._validate import validate_dataset

 __all__ = [
     "collate",
     "metadata",
     "to_image_classification_dataset",
     "to_object_detection_dataset",
+    "validate_dataset",
 ]
dataeval/utils/data/_dataset.py CHANGED
@@ -14,6 +14,10 @@ from dataeval.typing import (
 from dataeval.utils._array import as_numpy


+def _ensure_id(index: int, metadata: dict[str, Any]) -> dict[str, Any]:
+    return {"id": index, **metadata} if "id" not in metadata else metadata
+
+
 def _validate_data(
     datum_type: Literal["ic", "od"],
     images: Array | Sequence[Array],
@@ -128,16 +132,19 @@ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], Imag
         return (
             self._images[idx],
             as_numpy(one_hot),
-            self._metadata[idx] if self._metadata is not None else {},
+            _ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
         )


 class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
     class ObjectDetectionTarget:
-        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]]) -> None:
+        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]], class_count: int) -> None:
             self._labels = labels
             self._bboxes = bboxes
-
+            one_hot = [[0.0] * class_count] * len(labels)
+            for i, label in enumerate(labels):
+                one_hot[i][label] = 1.0
+            self._scores = one_hot

         @property
         def labels(self) -> Sequence[int]:
@@ -148,7 +155,7 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
             return self._bboxes

         @property
-        def scores(self) -> Sequence[float]:
+        def scores(self) -> Sequence[Sequence[float]]:
             return self._scores

     def __init__(
@@ -180,8 +187,8 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
     def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
         return (
             self._images[idx],
-            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx]),
-            self._metadata[idx] if self._metadata is not None else {},
+            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes)),
+            _ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
         )


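ObjectDetectionTarget instances produced by the converter now carry per-detection one-hot scores. A standalone sketch of that shape, with rows built independently; the helper name here is illustrative, not DataEval's:

    from typing import Sequence

    def one_hot_scores(labels: Sequence[int], class_count: int) -> list[list[float]]:
        # One row per detection, with a 1.0 in the column of that detection's label.
        scores = [[0.0] * class_count for _ in labels]
        for row, label in zip(scores, labels):
            row[label] = 1.0
        return scores

    print(one_hot_scores([2, 0], class_count=3))  # [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]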
dataeval/utils/data/_validate.py ADDED
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any, Literal, Sequence, Sized
+
+from dataeval.config import EPSILON
+from dataeval.typing import Array, ObjectDetectionTarget
+from dataeval.utils._array import as_numpy
+
+
+class ValidationMessages:
+    DATASET_SIZED = "Dataset must be sized."
+    DATASET_INDEXABLE = "Dataset must be indexable."
+    DATASET_NONEMPTY = "Dataset must be non-empty."
+    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
+    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
+    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
+    DATUM_TYPE = "Dataset datum must be a tuple."
+    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
+    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
+    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
+    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
+    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
+    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must be have 'boxes', 'labels' and 'scores'."
+    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
+    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xxyy format."
+    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
+    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
+    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
+    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
+
+
+def _validate_dataset_type(dataset: Any) -> list[str]:
+    issues = []
+    is_sized = isinstance(dataset, Sized)
+    is_indexable = hasattr(dataset, "__getitem__")
+    if not is_sized:
+        issues.append(ValidationMessages.DATASET_SIZED)
+    if not is_indexable:
+        issues.append(ValidationMessages.DATASET_INDEXABLE)
+    if is_sized and len(dataset) == 0:
+        issues.append(ValidationMessages.DATASET_NONEMPTY)
+    return issues
+
+
+def _validate_dataset_metadata(dataset: Any) -> list[str]:
+    issues = []
+    if not hasattr(dataset, "metadata"):
+        issues.append(ValidationMessages.DATASET_METADATA)
+    metadata = getattr(dataset, "metadata", None)
+    if not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATASET_METADATA_TYPE)
+    if not isinstance(metadata, dict) or "id" not in metadata:
+        issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
+    return issues
+
+
+def _validate_datum_type(datum: Any) -> list[str]:
+    issues = []
+    if not isinstance(datum, tuple):
+        issues.append(ValidationMessages.DATUM_TYPE)
+    if datum is None or isinstance(datum, Sized) and len(datum) != 3:
+        issues.append(ValidationMessages.DATUM_FORMAT)
+    return issues
+
+
+def _validate_datum_image(image: Any) -> list[str]:
+    issues = []
+    if not isinstance(image, Array) or len(image.shape) != 3:
+        issues.append(ValidationMessages.DATUM_IMAGE_TYPE)
+    if (
+        not isinstance(image, Array)
+        or len(image.shape) == 3
+        and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2])
+    ):
+        issues.append(ValidationMessages.DATUM_IMAGE_FORMAT)
+    return issues
+
+
+def _validate_datum_target_ic(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, Array) or len(target.shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)
+    if target is None or sum(target) > 1 + EPSILON or sum(target) < 1 - EPSILON:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
+    return issues
+
+
+def _validate_datum_target_od(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, ObjectDetectionTarget):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_TYPE)
+    od_target: ObjectDetectionTarget | None = target if isinstance(target, ObjectDetectionTarget) else None
+    if od_target is None or len(as_numpy(od_target.labels).shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
+    if (
+        od_target is None
+        or len(as_numpy(od_target.boxes).shape) != 2
+        or (len(as_numpy(od_target.boxes).shape) == 2 and as_numpy(od_target.boxes).shape[1] != 4)
+    ):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
+    if od_target is None or len(as_numpy(od_target.scores).shape) not in (1, 2):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
+    return issues
+
+
+def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
+    if isinstance(target, Array):
+        return "ic"
+    if isinstance(target, ObjectDetectionTarget):
+        return "od"
+    return "auto"
+
+
+def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
+    issues = []
+    target_type = _detect_target_type(target) if target_type == "auto" else target_type
+    if target_type == "ic":
+        issues.extend(_validate_datum_target_ic(target))
+    elif target_type == "od":
+        issues.extend(_validate_datum_target_od(target))
+    else:
+        issues.append(ValidationMessages.DATUM_TARGET_TYPE)
+    return issues
+
+
+def _validate_datum_metadata(metadata: Any) -> list[str]:
+    issues = []
+    if metadata is None or not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATUM_METADATA_TYPE)
+    if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
+        issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
+    return issues
+
+
+def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
+    """
+    Validate a dataset for compliance with MAITE protocol.
+
+    Parameters
+    ----------
+    dataset: Any
+        Dataset to validate.
+    dataset_type: "ic", "od", or "auto", default "auto"
+        Dataset type, if known.
+
+    Raises
+    ------
+    ValueError
+        Raises exception if dataset is invalid with a list of validation issues.
+    """
+    issues = []
+    issues.extend(_validate_dataset_type(dataset))
+    datum = None if issues else dataset[0]  # type: ignore
+    issues.extend(_validate_dataset_metadata(dataset))
+    issues.extend(_validate_datum_type(datum))
+
+    is_seq = isinstance(datum, Sequence)
+    datum_len = len(datum) if is_seq else 0
+    image = datum[0] if is_seq and datum_len > 0 else None
+    target = datum[1] if is_seq and datum_len > 1 else None
+    metadata = datum[2] if is_seq and datum_len > 2 else None
+    issues.extend(_validate_datum_image(image))
+    issues.extend(_validate_datum_target(target, dataset_type))
+    issues.extend(_validate_datum_metadata(metadata))
+
+    if issues:
+        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
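With the export added to dataeval.utils.data, the validator can be called directly. A short usage sketch; an empty list is deliberately invalid, so several categories of issue are reported at once:

    from dataeval.utils.data import validate_dataset

    try:
        validate_dataset([])   # sized and indexable, but empty and missing metadata
    except ValueError as err:
        print(err)             # "Dataset validation issues found:" followed by one line per problem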
dataeval/workflows/sufficiency.py CHANGED
@@ -57,22 +57,29 @@ class Sufficiency(Generic[T]):
     test_ds : torch.Dataset
         Data that will be used for every run's evaluation
     train_fn : Callable[[nn.Module, Dataset, Sequence[int]], None]
-        Function which takes a model
-        (torch.utils.data.Dataset), indices to train on and executes model
+        Function which takes a model, a dataset, and indices to train on and then executes model
         training against the data.
     eval_fn : Callable[[nn.Module, Dataset], Mapping[str, float | ArrayLike]]
-        Function which takes a model
-
-        values (Mapping[str, float]) which is used to assess model performance
+        Function which takes a model, a dataset and returns a dictionary of metric
+        values which is used to assess model performance
         given the model and data.
     runs : int, default 1
-        Number of models to
+        Number of models to train over the entire dataset.
     substeps : int, default 5
-
+        The number of steps that each model will be trained and evaluated on.
     train_kwargs : Mapping | None, default None
         Additional arguments required for custom training function
     eval_kwargs : Mapping | None, default None
         Additional arguments required for custom evaluation function
+
+    Warning
+    -------
+    Since each run is trained sequentially, increasing the parameter `runs` can significantly increase runtime.
+
+    Note
+    ----
+    Substeps is overridden by the parameter `eval_at` in :meth:`.Sufficiency.evaluate`
+
     """

     def __init__(
@@ -159,13 +166,22 @@ class Sufficiency(Generic[T]):
     @set_metadata(state=["runs", "substeps"])
     def evaluate(self, eval_at: int | Iterable[int] | None = None) -> SufficiencyOutput:
         """
-
+        Train and evaluate a model over multiple substeps
+
+        This function trains a model up to each step calculated from substeps. The model is then evaluated
+        at that step and trained from 0 to the next step. This repeats for all substeps. Once a model has been
+        trained and evaluated at all substeps, if runs is greater than one, the model weights are reset and
+        the process is repeated.
+
+        During each evaluation, the metrics returned as a dictionary by the given evaluation function are stored
+        and then averaged over when all runs are complete.

         Parameters
         ----------
         eval_at : int | Iterable[int] | None, default None
-            Specify this to collect
-
+            Specify this to collect metrics over a specific set of dataset lengths.
+            If `None`, evaluates at each step is calculated by
+            `np.geomspace` over the length of the dataset for self.substeps

         Returns
         -------
@@ -179,6 +195,8 @@ class Sufficiency(Generic[T]):

         Examples
         --------
+        Default runs and substeps
+
         >>> suff = Sufficiency(
         ...     model=model,
         ...     train_ds=train_ds,
@@ -190,6 +208,31 @@ class Sufficiency(Generic[T]):
         ... )
         >>> suff.evaluate()
         SufficiencyOutput(steps=array([ 1, 3, 10, 31, 100], dtype=uint32), measures={'test': array([1., 1., 1., 1., 1.])}, n_iter=1000)
+
+        Evaluate at a single value
+
+        >>> suff = Sufficiency(
+        ...     model=model,
+        ...     train_ds=train_ds,
+        ...     test_ds=test_ds,
+        ...     train_fn=train_fn,
+        ...     eval_fn=eval_fn,
+        ... )
+        >>> suff.evaluate(eval_at=50)
+        SufficiencyOutput(steps=array([50]), measures={'test': array([1.])}, n_iter=1000)
+
+        Evaluating at linear steps from 0-100 inclusive
+
+        >>> suff = Sufficiency(
+        ...     model=model,
+        ...     train_ds=train_ds,
+        ...     test_ds=test_ds,
+        ...     train_fn=train_fn,
+        ...     eval_fn=eval_fn,
+        ... )
+        >>> suff.evaluate(eval_at=np.arange(0, 101, 20))
+        SufficiencyOutput(steps=array([ 0, 20, 40, 60, 80, 100]), measures={'test': array([1., 1., 1., 1., 1., 1.])}, n_iter=1000)
+
         """  # noqa: E501
         if eval_at is not None:
             ranges = np.asarray(list(eval_at) if isinstance(eval_at, Iterable) else [eval_at])