dataeval 0.86.0__py3-none-any.whl → 0.86.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +48 -37
- dataeval/data/_selection.py +1 -2
- dataeval/data/_split.py +2 -3
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +2 -5
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_nml/_base.py +4 -2
- dataeval/detectors/drift/_nml/_chunk.py +11 -19
- dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
- dataeval/detectors/drift/_nml/_result.py +8 -9
- dataeval/detectors/drift/_nml/_thresholds.py +66 -77
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metrics/bias/_parity.py +10 -13
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +26 -28
- dataeval/outputs/_drift.py +1 -9
- dataeval/outputs/_linters.py +11 -11
- dataeval/outputs/_stats.py +82 -23
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +32 -20
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +19 -11
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +1 -1
- dataeval-0.86.1.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.86.0.dist-info/RECORD +0 -114
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/detectors/drift/_nml/_result.py
CHANGED

@@ -42,14 +42,13 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
         """Export results to pandas dataframe."""
         if multilevel:
             return self._data
-        else:
-            column_names = [
-                "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
-                for col in self._data.columns.values
-            ]
-            single_level_data = self._data.copy(deep=True)
-            single_level_data.columns = column_names
-            return single_level_data
+        column_names = [
+            "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
+            for col in self._data.columns.values
+        ]
+        single_level_data = self._data.copy(deep=True)
+        single_level_data.columns = column_names
+        return single_level_data
 
     def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
         """Returns filtered result metric data."""
@@ -67,7 +66,7 @@ class Abstract1DResult(AbstractResult, ABC):
     def __init__(self, results_data: pd.DataFrame) -> None:
         super().__init__(results_data)
 
-    def _filter(self, period: str, metrics=None) -> Self:
+    def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self:
         data = self._data
         if period != "all":
             data = self._data.loc[self._data.loc[:, ("chunk", "period")] == period, :]  # type: ignore | dataframe loc
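For reference, the flattening that `to_df(multilevel=False)` performs joins each MultiIndex column tuple with underscores and collapses the repeated "chunk" prefixes. A standalone sketch of the same transform on a hypothetical result frame:

```python
import pandas as pd

# Hypothetical multilevel frame mimicking the _nml result layout.
df = pd.DataFrame(
    [[0, "reference", 0.57]],
    columns=pd.MultiIndex.from_tuples(
        [("chunk", "chunk", "index"), ("chunk", "chunk", "period"), ("domain_classifier_auroc", "value", "")]
    ),
)

# Same transform as the new to_df(multilevel=False) body.
flat = df.copy(deep=True)
flat.columns = [
    "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
    for col in df.columns.values
]
print(list(flat.columns))  # ['chunk_index', 'chunk_period', 'domain_classifier_auroc_value_']
```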
dataeval/detectors/drift/_nml/_thresholds.py
CHANGED

@@ -29,10 +29,10 @@ class Threshold(ABC):
     """Class registry lookup to get threshold subclass from threshold_type string"""
 
     def __str__(self) -> str:
-        return self.
+        return f"{self.__class__.__name__}({str(vars(self))})"
 
     def __repr__(self) -> str:
-        return
+        return str(self)
 
     def __eq__(self, other: object) -> bool:
         return isinstance(other, self.__class__) and other.__dict__ == self.__dict__
@@ -41,7 +41,7 @@ class Threshold(ABC):
         Threshold._registry[threshold_type] = cls
 
     @abstractmethod
-    def
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         """Returns lower and upper threshold values when given one or more np.ndarray instances.
 
         Parameters:
@@ -69,6 +69,61 @@ class Threshold(ABC):
 
         return threshold_cls(**obj)
 
+    def calculate(
+        self,
+        data: np.ndarray,
+        lower_limit: float | None = None,
+        upper_limit: float | None = None,
+        override_using_none: bool = False,
+        logger: logging.Logger | None = None,
+    ) -> tuple[float | None, float | None]:
+        """
+        Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
+
+        Parameters
+        ----------
+        data : np.ndarray
+            The data used by the Threshold instance to calculate the lower and upper threshold values.
+            This will often be the values of a drift detection method or performance metric on chunks of reference
+            data.
+        lower_limit : float or None, default None
+            An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
+            values that end up below this limit will be replaced by this limit value.
+            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+            metric.
+        upper_limit : float or None, default None
+            An optional value that serves as a limit for the upper threshold value. Any calculated upper threshold
+            values that end up above this limit will be replaced by this limit value.
+            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
+            metric.
+        override_using_none: bool, default False
+            When set to True use None to override threshold values that exceed value limits.
+            This will prevent them from being rendered on plots.
+        logger: Optional[logging.Logger], default=None
+            An optional Logger instance. When provided a warning will be logged when a calculated threshold value
+            gets overridden by a threshold value limit.
+        """
+
+        lower_value, upper_value = self._thresholds(data)
+
+        if lower_limit is not None and lower_value is not None and lower_value <= lower_limit:
+            override_value = None if override_using_none else lower_limit
+            if logger:
+                logger.warning(
+                    f"lower threshold value {lower_value} overridden by lower threshold value limit {override_value}"
+                )
+            lower_value = override_value
+
+        if upper_limit is not None and upper_value is not None and upper_value >= upper_limit:
+            override_value = None if override_using_none else upper_limit
+            if logger:
+                logger.warning(
+                    f"upper threshold value {upper_value} overridden by upper threshold value limit {override_value}"
+                )
+            upper_value = override_value
+
+        return lower_value, upper_value
+
 
 class ConstantThreshold(Threshold, threshold_type="constant"):
     """A `Thresholder` implementation that returns a constant lower and or upper threshold value.
@@ -91,7 +146,7 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
     None 0.1
     """
 
-    def __init__(self, lower: float | int | None = None, upper: float | int | None = None):
+    def __init__(self, lower: float | int | None = None, upper: float | int | None = None) -> None:
         """Creates a new ConstantThreshold instance.
 
         Args:
@@ -109,11 +164,11 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
         self.lower = lower
         self.upper = upper
 
-    def
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         return self.lower, self.upper
 
     @staticmethod
-    def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None):
+    def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
         if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
             raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
 
@@ -149,7 +204,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         std_lower_multiplier: float | int | None = 3,
         std_upper_multiplier: float | int | None = 3,
         offset_from: Callable[[np.ndarray], Any] = np.nanmean,
-    ):
+    ) -> None:
         """Creates a new StandardDeviationThreshold instance.
 
         Args:
@@ -173,7 +228,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         self.std_upper_multiplier = std_upper_multiplier
         self.offset_from = offset_from
 
-    def
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         aggregate = self.offset_from(data)
         std = np.nanstd(data)
 
@@ -184,7 +239,9 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         return lower_threshold, upper_threshold
 
     @staticmethod
-    def _validate_inputs(
+    def _validate_inputs(
+        std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3
+    ) -> None:
         if (
             std_lower_multiplier is not None
             and not isinstance(std_lower_multiplier, (float, int))
@@ -210,71 +267,3 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
 
         if std_upper_multiplier and std_upper_multiplier < 0:
             raise ValueError(f"'std_upper_multiplier' should be greater than 0 but got value {std_upper_multiplier}")
-
-
-def calculate_threshold_values(
-    threshold: Threshold,
-    data: np.ndarray,
-    lower_threshold_value_limit: float | None = None,
-    upper_threshold_value_limit: float | None = None,
-    override_using_none: bool = False,
-    logger: logging.Logger | None = None,
-    metric_name: str | None = None,
-) -> tuple[float | None, float | None]:
-    """Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
-
-    Parameters:
-        threshold: Threshold
-            The Threshold instance that determines how the lower and upper threshold values will be calculated.
-        data: np.ndarray
-            The data used by the Threshold instance to calculate the lower and upper threshold values.
-            This will often be the values of a drift detection method or performance metric on chunks of reference data.
-        lower_threshold_value_limit: Optional[float], default=None
-            An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-            values that end up below this limit will be replaced by this limit value.
-            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-            metric.
-        upper_threshold_value_limit: Optional[float], default=None
-            An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-            values that end up below this limit will be replaced by this limit value.
-            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-            metric.
-        override_using_none: bool, default=False
-            When set to True use None to override threshold values that exceed value limits.
-            This will prevent them from being rendered on plots.
-        logger: Optional[logging.Logger], default=None
-            An optional Logger instance. When provided a warning will be logged when a calculated threshold value
-            gets overridden by a threshold value limit.
-        metric_name: Optional[str], default=None
-            When provided the metric name will be included within any log messages for additional clarity.
-    """
-
-    lower_threshold_value, upper_threshold_value = threshold.thresholds(data)
-
-    if (
-        lower_threshold_value_limit is not None
-        and lower_threshold_value is not None
-        and lower_threshold_value <= lower_threshold_value_limit
-    ):
-        override_value = None if override_using_none else lower_threshold_value_limit
-        if logger:
-            logger.warning(
-                f"{metric_name + ' ' if metric_name else ''}lower threshold value {lower_threshold_value} "
-                f"overridden by lower threshold value limit {override_value}"
-            )
-        lower_threshold_value = override_value
-
-    if (
-        upper_threshold_value_limit is not None
-        and upper_threshold_value is not None
-        and upper_threshold_value >= upper_threshold_value_limit
-    ):
-        override_value = None if override_using_none else upper_threshold_value_limit
-        if logger:
-            logger.warning(
-                f"{metric_name + ' ' if metric_name else ''}upper threshold value {upper_threshold_value} "
-                f"overridden by upper threshold value limit {override_value}"
-            )
-        upper_threshold_value = override_value
-
-    return lower_threshold_value, upper_threshold_value
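Since the module-level `calculate_threshold_values` helper was folded into `Threshold.calculate`, callers now pass value limits directly to the threshold instance. A usage sketch against the classes above (the `_nml` package is private, so the import path is an assumption about internal use, not public API):

```python
import logging
import numpy as np

# Private vendored module per this diff; not part of the public dataeval API.
from dataeval.detectors.drift._nml._thresholds import ConstantThreshold, StandardDeviationThreshold

metric_values = np.array([0.52, 0.55, 0.49, 0.51, 0.53])  # hypothetical per-chunk metric values

# Constant bounds are returned unchanged.
print(ConstantThreshold(lower=0.4, upper=0.6).calculate(metric_values))  # (0.4, 0.6)

# Mean +/- 3 standard deviations, clamped to the metric's theoretical [0, 1] range;
# the logger records a warning whenever a limit overrides a computed bound.
sdt = StandardDeviationThreshold()
lower, upper = sdt.calculate(metric_values, lower_limit=0.0, upper_limit=1.0, logger=logging.getLogger(__name__))
```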
dataeval/detectors/linters/outliers.py
CHANGED

@@ -13,31 +13,31 @@ from dataeval.metrics.stats._imagestats import imagestats
 from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
-from dataeval.outputs._stats import
+from dataeval.outputs._stats import BASE_ATTRS
 from dataeval.typing import ArrayLike, Dataset
 
 
 def _get_outlier_mask(
     values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
 ) -> NDArray:
+    values = values.astype(np.float64)
     if method == "zscore":
         threshold = threshold if threshold else 3.0
         std = np.std(values)
         abs_diff = np.abs(values - np.mean(values))
         return std != 0 and (abs_diff / std) > threshold
-    elif method == "modzscore":
+    if method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
         med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
         mod_z_score = 0.6745 * abs_diff / med_abs_diff
         return mod_z_score > threshold
-    elif method == "iqr":
+    if method == "iqr":
         threshold = threshold if threshold else 1.5
         qrt = np.percentile(values, q=(25, 75), method="midpoint")
         iqr = (qrt[1] - qrt[0]) * threshold
         return (values < (qrt[0] - iqr)) | (values > (qrt[1] + iqr))
-    else:
-        raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
+    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
 
 
 class Outliers:
@@ -103,7 +103,7 @@ class Outliers:
         use_visual: bool = True,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
-    ):
+    ) -> None:
         self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
@@ -114,7 +114,7 @@ class Outliers:
     def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
         flagged_images: dict[int, dict[str, float]] = {}
         for stat, values in stats.items():
-            if stat in
+            if stat in BASE_ATTRS:
                 continue
             if values.ndim == 1:
                 mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
dataeval/metrics/bias/_parity.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
+from collections import defaultdict
 from typing import Any
 
 import numpy as np
@@ -246,7 +247,7 @@ def parity(metadata: Metadata) -> ParityOutput:
 
     chi_scores = np.zeros(metadata.discrete_data.shape[1])
     p_values = np.zeros_like(chi_scores)
-    insufficient_data =
+    insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
     for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -261,26 +262,22 @@ def parity(metadata: Metadata) -> ParityOutput:
         for int_factor, int_class in zip(counts[0], counts[1]):
             if contingency_matrix[int_factor, int_class] > 0:
                 factor_category = unique_factor_values[int_factor].item()
-                if current_factor_name not in insufficient_data:
-                    insufficient_data[current_factor_name] = {}
-                if factor_category not in insufficient_data[current_factor_name]:
-                    insufficient_data[current_factor_name][factor_category] = {}
                 class_name = metadata.class_names[int_class]
                 class_count = contingency_matrix[int_factor, int_class].item()
                 insufficient_data[current_factor_name][factor_category][class_name] = class_count
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
-
-        rowmask = np.nonzero(rowsums)[0]
-        contingency_matrix = contingency_matrix[rowmask]
+        contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]
 
-
-        chi2, p, _, _ = chi2_contingency(contingency_matrix)
-        chi_scores[i] = chi2
-        p_values[i] = p
+        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]
 
     if insufficient_data:
         warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
 
-    return ParityOutput(
+    return ParityOutput(
+        score=chi_scores,
+        p_value=p_values,
+        factor_names=metadata.discrete_factor_names,
+        insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
+    )
dataeval/metrics/estimators/_divergence.py
CHANGED

@@ -38,8 +38,7 @@ def divergence_mst(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     """
     mst = minimum_spanning_tree(data).toarray()
     edgelist = np.transpose(np.nonzero(mst))
-    errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
-    return errors
+    return np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
 
 
 def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
@@ -59,8 +58,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     Number of label errors when finding nearest neighbors
     """
     nn_indices = compute_neighbors(data, data)
-    errors = np.sum(np.abs(labels[nn_indices] - labels))
-    return errors
+    return np.sum(np.abs(labels[nn_indices] - labels))
 
 
 _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
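Both estimators now return their error counts inline. For the MST variant, the count is the number of spanning-tree edges whose endpoints carry different labels; a self-contained sketch of that step on toy data (building the distance matrix with `cdist` is an assumption about the expected input):

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import cdist

# Two tight clusters with different labels.
points = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = np.array([0, 0, 1, 1])

# Count MST edges that cross label boundaries, as in divergence_mst above.
mst = minimum_spanning_tree(cdist(points, points)).toarray()
edgelist = np.transpose(np.nonzero(mst))
errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
print(errors)  # 1 -- only the single edge bridging the two clusters crosses labels
```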
dataeval/metrics/stats/_base.py
CHANGED
@@ -10,23 +10,86 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
 
 import numpy as np
 import tqdm
 from numpy.typing import NDArray
 
 from dataeval.config import get_max_processes
-from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
+from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
 from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
 from dataeval.utils._array import as_numpy, to_numpy
-from dataeval.utils._image import normalize_image_shape, rescale
+from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-BoundingBox = tuple[float, float, float, float]
 TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
+_S = TypeVar("_S")
+_T = TypeVar("_T")
+
+
+@dataclass
+class BoundingBox:
+    x0: float
+    y0: float
+    x1: float
+    y1: float
+
+    def __post_init__(self) -> None:
+        # Test for invalid coordinates
+        x_swap = self.x0 > self.x1
+        y_swap = self.y0 > self.y1
+        if x_swap or y_swap:
+            warnings.warn(f"Invalid bounding box coordinates: {self} - swapping invalid coordinates.")
+        if x_swap:
+            self.x0, self.x1 = self.x1, self.x0
+        if y_swap:
+            self.y0, self.y1 = self.y1, self.y0
+
+    @property
+    def width(self) -> float:
+        return self.x1 - self.x0
+
+    @property
+    def height(self) -> float:
+        return self.y1 - self.y0
+
+    def to_int(self) -> tuple[int, int, int, int]:
+        """
+        Returns the bounding box as a tuple of integers.
+        """
+        x0_int = math.floor(self.x0)
+        y0_int = math.floor(self.y0)
+        x1_int = math.ceil(self.x1)
+        y1_int = math.ceil(self.y1)
+        return x0_int, y0_int, x1_int, y1_int
+
+
+class PoolWrapper:
+    """
+    Wraps `multiprocessing.Pool` to allow for easy switching between
+    multiprocessing and single-threaded execution.
+
+    This helps with debugging and profiling, as well as usage with Jupyter notebooks
+    in VS Code, which does not support subprocess debugging.
+    """
+
+    def __init__(self, processes: int | None) -> None:
+        self.pool = Pool(processes) if processes is not None and processes > 1 else None
+
+    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
+        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
+
+    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if self.pool is not None:
+            self.pool.close()
+            self.pool.join()
+
 
 class StatsProcessor(Generic[TStatsOutput]):
     output_class: type[TStatsOutput]
|
|
34
97
|
image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
|
35
98
|
channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
|
36
99
|
|
37
|
-
def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
|
100
|
+
def __init__(self, image: NDArray[Any], box: BoundingBox | Iterable[Any] | None, per_channel: bool) -> None:
|
38
101
|
self.raw = image
|
39
102
|
self.width: int = image.shape[-1]
|
40
103
|
self.height: int = image.shape[-2]
|
41
|
-
box =
|
42
|
-
|
43
|
-
x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
|
44
|
-
x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
|
45
|
-
self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
|
104
|
+
box = (0, 0, self.width, self.height) if box is None else box
|
105
|
+
self.box = box if isinstance(box, BoundingBox) else BoundingBox(*box)
|
46
106
|
self._per_channel = per_channel
|
47
107
|
self._image = None
|
48
108
|
self._shape = None
|
49
109
|
self._scaled = None
|
50
110
|
self._cache = {}
|
51
111
|
self._fn_map = self.channel_function_map if per_channel else self.image_function_map
|
52
|
-
self.
|
53
|
-
box[0] >= 0 and box[1] >= 0 and box[2] <= image.shape[-1] and box[3] <= image.shape[-2]
|
54
|
-
)
|
112
|
+
self._is_valid_box = is_valid_box(clip_box(image, self.box.to_int()))
|
55
113
|
|
56
114
|
def get(self, fn_key: str) -> NDArray[Any]:
|
57
115
|
if fn_key in self.cache_keys:
|
58
116
|
if fn_key not in self._cache:
|
59
117
|
self._cache[fn_key] = self._fn_map[fn_key](self)
|
60
118
|
return self._cache[fn_key]
|
61
|
-
|
62
|
-
return self._fn_map[fn_key](self)
|
119
|
+
return self._fn_map[fn_key](self)
|
63
120
|
|
64
121
|
def process(self) -> dict[str, Any]:
|
65
122
|
return {k: self._fn_map[k](self) for k in self._fn_map}
|
@@ -67,11 +124,7 @@ class StatsProcessor(Generic[TStatsOutput]):
|
|
67
124
|
@property
|
68
125
|
def image(self) -> NDArray[Any]:
|
69
126
|
if self._image is None:
|
70
|
-
|
71
|
-
norm = normalize_image_shape(self.raw)
|
72
|
-
self._image = norm[:, self.box[1] : self.box[3], self.box[0] : self.box[2]]
|
73
|
-
else:
|
74
|
-
self._image = np.zeros((self.raw.shape[0], self.box[3] - self.box[1], self.box[2] - self.box[0]))
|
127
|
+
self._image = clip_and_pad(normalize_image_shape(self.raw), self.box.to_int())
|
75
128
|
return self._image
|
76
129
|
|
77
130
|
@property
|
@@ -90,9 +143,9 @@ class StatsProcessor(Generic[TStatsOutput]):
|
|
90
143
|
|
91
144
|
@classmethod
|
92
145
|
def convert_output(
|
93
|
-
cls, source: dict[str, Any], source_index: list[SourceIndex],
|
146
|
+
cls, source: dict[str, Any], source_index: list[SourceIndex], object_count: list[int], image_count: int
|
94
147
|
) -> TStatsOutput:
|
95
|
-
output = {}
|
148
|
+
output: dict[str, Any] = {}
|
96
149
|
attrs = dict(ChainMap(*(getattr(c, "__annotations__", {}) for c in cls.output_class.__mro__)))
|
97
150
|
for key in (key for key in source if key in attrs):
|
98
151
|
stat_type: str = attrs[key]
|
@@ -101,14 +154,17 @@ class StatsProcessor(Generic[TStatsOutput]):
|
|
101
154
|
output[key] = np.asarray(source[key], dtype=np.dtype(dtype_match.group(1)))
|
102
155
|
else:
|
103
156
|
output[key] = source[key]
|
104
|
-
|
157
|
+
base_attrs: dict[str, Any] = dict(
|
158
|
+
zip(BASE_ATTRS, (source_index, np.asarray(object_count, dtype=np.uint16), image_count))
|
159
|
+
)
|
160
|
+
return cls.output_class(**output, **base_attrs)
|
105
161
|
|
106
162
|
|
107
163
|
@dataclass
|
108
164
|
class StatsProcessorOutput:
|
109
165
|
results: list[dict[str, Any]]
|
110
166
|
source_indices: list[SourceIndex]
|
111
|
-
|
167
|
+
object_counts: list[int]
|
112
168
|
warnings_list: list[str]
|
113
169
|
|
114
170
|
|
@@ -119,18 +175,18 @@ def process_stats(
|
|
119
175
|
per_channel: bool,
|
120
176
|
stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
|
121
177
|
) -> StatsProcessorOutput:
|
122
|
-
|
178
|
+
np_image = to_numpy(image)
|
123
179
|
results_list: list[dict[str, Any]] = []
|
124
180
|
source_indices: list[SourceIndex] = []
|
125
181
|
box_counts: list[int] = []
|
126
182
|
warnings_list: list[str] = []
|
127
183
|
for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
|
128
|
-
processor_list = [p(
|
129
|
-
if any(not p.
|
130
|
-
warnings_list.append(f"Bounding box [{i}][{i_b}]: {box}
|
184
|
+
processor_list = [p(np_image, box, per_channel) for p in stats_processor_cls]
|
185
|
+
if any(not p._is_valid_box for p in processor_list) and i_b is not None and box is not None:
|
186
|
+
warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} for image shape {np_image.shape} is invalid.")
|
131
187
|
results_list.append({k: v for p in processor_list for k, v in p.process().items()})
|
132
188
|
if per_channel:
|
133
|
-
source_indices.extend([SourceIndex(i, i_b, c) for c in range(
|
189
|
+
source_indices.extend([SourceIndex(i, i_b, c) for c in range(np_image.shape[-3])])
|
134
190
|
else:
|
135
191
|
source_indices.append(SourceIndex(i, i_b, None))
|
136
192
|
box_counts.append(0 if boxes is None else len(boxes))
|
@@ -145,13 +201,18 @@ def process_stats_unpack(
|
|
145
201
|
return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
|
146
202
|
|
147
203
|
|
148
|
-
def _enumerate(
|
204
|
+
def _enumerate(
|
205
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool
|
206
|
+
) -> Iterator[tuple[int, ArrayLike, Any]]:
|
149
207
|
for i in range(len(dataset)):
|
150
208
|
d = dataset[i]
|
151
209
|
image = d[0] if isinstance(d, tuple) else d
|
152
210
|
if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
|
153
|
-
|
154
|
-
|
211
|
+
try:
|
212
|
+
boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
|
213
|
+
target = [BoundingBox(*(float(box[i]) for i in range(4))) for box in boxes]
|
214
|
+
except (ValueError, IndexError):
|
215
|
+
raise ValueError(f"Invalid bounding box format for image {i}: {d[1].boxes}")
|
155
216
|
else:
|
156
217
|
target = None
|
157
218
|
|
@@ -199,12 +260,13 @@ def run_stats(
|
|
199
260
|
"""
|
200
261
|
results_list: list[dict[str, NDArray[np.float64]]] = []
|
201
262
|
source_index: list[SourceIndex] = []
|
202
|
-
|
263
|
+
object_count: list[int] = []
|
264
|
+
image_count: int = len(dataset)
|
203
265
|
|
204
266
|
warning_list = []
|
205
267
|
stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
|
206
268
|
|
207
|
-
with
|
269
|
+
with PoolWrapper(processes=get_max_processes()) as p:
|
208
270
|
for r in tqdm.tqdm(
|
209
271
|
p.imap(
|
210
272
|
partial(
|
@@ -214,14 +276,12 @@ def run_stats(
|
|
214
276
|
),
|
215
277
|
_enumerate(dataset, per_box),
|
216
278
|
),
|
217
|
-
total=
|
279
|
+
total=image_count,
|
218
280
|
):
|
219
281
|
results_list.extend(r.results)
|
220
282
|
source_index.extend(r.source_indices)
|
221
|
-
|
283
|
+
object_count.extend(r.object_counts)
|
222
284
|
warning_list.extend(r.warnings_list)
|
223
|
-
p.close()
|
224
|
-
p.join()
|
225
285
|
|
226
286
|
# warnings are not emitted while in multiprocessing pools so we emit after gathering all warnings
|
227
287
|
for w in warning_list:
|
@@ -235,8 +295,7 @@ def run_stats(
|
|
235
295
|
else:
|
236
296
|
output.setdefault(stat, []).append(result.tolist() if isinstance(result, np.ndarray) else result)
|
237
297
|
|
238
|
-
|
239
|
-
return outputs
|
298
|
+
return [s.convert_output(output, source_index, object_count, image_count) for s in stats_processor_cls]
|
240
299
|
|
241
300
|
|
242
301
|
def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
|
@@ -246,10 +305,12 @@ def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
|
|
246
305
|
sum_dict = deepcopy(a.data())
|
247
306
|
|
248
307
|
for k in sum_dict:
|
249
|
-
if isinstance(sum_dict[k],
|
308
|
+
if isinstance(sum_dict[k], Sequence):
|
250
309
|
sum_dict[k].extend(b.data()[k])
|
251
|
-
|
310
|
+
elif isinstance(sum_dict[k], Array):
|
252
311
|
sum_dict[k] = np.concatenate((sum_dict[k], b.data()[k]))
|
312
|
+
else:
|
313
|
+
sum_dict[k] += b.data()[k]
|
253
314
|
|
254
315
|
return type(a)(**sum_dict)
|
255
316
|
|
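`PoolWrapper` lets `run_stats` fall back to plain `map` whenever `get_max_processes()` resolves to None or 1, which keeps stats collection debuggable in notebooks and other environments without subprocess debugging. A sketch of the switching behavior, assuming the class as defined in this diff:

```python
from dataeval.metrics.stats._base import PoolWrapper  # private module, per this diff

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    # processes=None -> no Pool is created and imap is just the builtin map.
    with PoolWrapper(processes=None) as p:
        print(list(p.imap(square, range(5))))  # [0, 1, 4, 9, 16]

    # processes=4 -> a real multiprocessing.Pool, closed and joined on exit.
    with PoolWrapper(processes=4) as p:
        print(list(p.imap(square, range(5))))  # same result, computed in worker processes
```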