dataeval 0.82.1__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +10 -0
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_ood.py +144 -27
- dataeval/metrics/bias/_balance.py +3 -3
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +17 -18
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_metadata.py +7 -0
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +15 -10
- dataeval/utils/data/_selection.py +22 -11
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +1 -1
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/RECORD +34 -34
- dataeval/detectors/ood/metadata_ood_mi.py +0 -91
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations
 
 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.82.1"
+__version__ = "0.83.0"
 
 import logging
 
@@ -34,7 +34,12 @@ def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> None:
     logger = logging.getLogger(__name__)
     if handler is None:
         handler = logging.StreamHandler() if handler is None else handler
-    handler.setFormatter(
+    handler.setFormatter(
+        logging.Formatter(
+            "%(asctime)s %(levelname)-8s %(name)s.%(filename)s:%(lineno)s - %(funcName)10s() | %(message)s"
+        )
+    )
     logger.addHandler(handler)
     logger.setLevel(level)
+    logging.DEBUG
     logger.debug(f"Added logging handler {handler} to logger: {__name__}")
dataeval/config.py
CHANGED
@@ -17,10 +17,18 @@ else:
 import numpy as np
 import torch
 
+### GLOBALS ###
+
 _device: torch.device | None = None
 _processes: int | None = None
 _seed: int | None = None
 
+### CONSTS ###
+
+EPSILON = 1e-10
+
+### TYPES ###
+
 DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
 """
 Type alias for types that are acceptable for specifying a torch.device.
@@ -30,6 +38,8 @@ See Also
 `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
 """
 
+### FUNCS ###
+
 
 def _todevice(device: DeviceLike) -> torch.device:
     return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
dataeval/metadata/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""
 
-__all__ = ["
+__all__ = ["find_ood_predictors", "metadata_distance", "find_most_deviated_factors"]
 
 from dataeval.metadata._distance import metadata_distance
-from dataeval.metadata._ood import
+from dataeval.metadata._ood import find_most_deviated_factors, find_ood_predictors
dataeval/metadata/_ood.py
CHANGED
@@ -6,14 +6,44 @@ import warnings
 
 import numpy as np
 from numpy.typing import NDArray
+from sklearn.feature_selection import mutual_info_classif
 
+from dataeval.config import get_seed
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
-from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
+from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils.data import Metadata
 
 
-def
+def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
+    """Combines the discrete and continuous data of a :class:`Metadata` object
+
+    Returns
+    -------
+    Tuple[list[str], NDArray]
+        The combined list of factors names and the combined discrete and continuous data
+
+    Note
+    ----
+    Discrete and continuous data must have the same number of samples
+    """
+    names = []
+    data = []
+
+    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
+        names.extend(metadata.discrete_factor_names)
+        data.append(metadata.discrete_data)
+
+    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
+        names.extend(metadata.continuous_factor_names)
+        data.append(metadata.continuous_data)
+
+    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
+
+
+def _combine_metadata(
+    metadata_1: Metadata, metadata_2: Metadata
+) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
     """
     Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
     match exactly and data has the same number of columns (factors).
@@ -42,8 +72,8 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[
         If the length of keys do not match the length of the data
     """
     factor_names: list[str] = []
-    m1_data: list[NDArray] = []
-    m2_data: list[NDArray] = []
+    m1_data: list[NDArray[np.int64 | np.float64]] = []
+    m2_data: list[NDArray[np.int64 | np.float64]] = []
 
     # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
     if metadata_1.total_num_factors != metadata_2.total_num_factors:
@@ -121,36 +151,37 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
 
 
 @set_metadata
-def most_deviated_factors(
-
-
+def find_most_deviated_factors(
+    metadata_ref: Metadata,
+    metadata_tst: Metadata,
     ood: OODOutput,
 ) -> MostDeviatedFactorsOutput:
     """
-
+    Determine greatest deviation in metadata features per out of distribution sample in test metadata.
 
     Parameters
     ----------
-
+    metadata_ref : Metadata
         A reference set of Metadata containing factor names and samples
        with discrete and/or continuous values per factor
-
+    metadata_tst : Metadata
        The set of Metadata that is tested against the reference metadata.
        This set must have the same number of features but does not require the same number of samples.
     ood : OODOutput
-        A class output by
+        A class output by DataEval's OOD functions that contains which examples are OOD.
 
     Returns
     -------
-
-        An
+    MostDeviatedFactorsOutput
+        An output class containing the factor name and deviation of the highest metadata deviations for each
+        OOD example in the test metadata.
 
     Notes
     -----
     1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
       and have equivalent factor names and lengths
     2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
-      directly to sample `i` of
+      directly to sample `i` of `metadata_tst` being out-of-distribution from `metadata_ref`
 
     Examples
     --------
@@ -160,13 +191,13 @@ def most_deviated_factors(
     All samples are out-of-distribution
 
     >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
-    >>>
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
 
-
+    No samples are out-of-distribution
 
     >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
-    >>>
+    >>> find_most_deviated_factors(metadata1, metadata2, is_ood)
     MostDeviatedFactorsOutput([])
     """
 
@@ -179,31 +210,30 @@ def most_deviated_factors(
     # Combines reference and test factor names and data if exists and match exactly
     # shape -> (samples, factors)
     factor_names, md_1, md_2 = _combine_metadata(
-        metadata_1=
-        metadata_2=
+        metadata_1=metadata_ref,
+        metadata_2=metadata_tst,
     )
 
     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
-
-
+    ref_data = np.hstack(md_1) if md_1 else np.array([])  # (S, Fd + Fc)
+    tst_data = np.hstack(md_2) if md_2 else np.array([])  # (S, Fd + Fc)
 
-    if len(
+    if len(ref_data) < 3:
         warnings.warn(
-            f"At least 3 reference metadata samples are needed, got {len(
+            f"At least 3 reference metadata samples are needed, got {len(ref_data)}",
            UserWarning,
        )
        return MostDeviatedFactorsOutput([])
 
-    if len(
+    if len(tst_data) != len(ood_mask):
        raise ValueError(
-            f"ood and test metadata must have the same length, "
-            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+            f"ood and test metadata must have the same length, got {len(ood_mask)} and {len(tst_data)} respectively."
        )
 
     # Calculates deviations of all samples in m2_data
     # from the median values of the corresponding index in m1_data
     # Guaranteed for inputs to not be empty
-    deviations = _calc_median_deviations(
+    deviations = _calc_median_deviations(ref_data, tst_data)
 
     # Get most impactful factor deviation of each sample for ood samples only
     deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
@@ -217,3 +247,90 @@ def most_deviated_factors(
     # List of tuples matching the factor name with its deviation
 
     return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
+
+
+_NATS2BITS = 1.442695
+"""
+_NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
+which is what many library functions return, multiply it by _NATS2BITS to get it in bits.
+"""
+
+
+def find_ood_predictors(
+    metadata: Metadata,
+    ood: OODOutput,
+) -> OODPredictorOutput:
+    """Computes mutual information between a set of metadata features and per sample out-of-distribution flags.
+
+    Given a set of metadata features per sample and a corresponding OODOutput that indicates whether a sample was
+    determined to be out of distribution, this function calculates the mutual information between each factor and being
+    out of distribution. In other words, it finds which metadata factors most likely correlate to an
+    out of distribution sample.
+
+    Note
+    ----
+    A high mutual information between a factor and ood samples is an indication of correlation, but not causation.
+    Additional analysis should be done to determine how to handle factors with a high mutual information.
+
+
+    Parameters
+    ----------
+    metadata : Metadata
+        A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
+    ood : OODOutput
+        A class output by DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    OODPredictorOutput
+        A dictionary with keys corresponding to metadata feature names, and values indicating the strength of
+        association between each named feature and the OOD flag, as mutual information measured in bits.
+
+    Examples
+    --------
+    >>> from dataeval.outputs import OODOutput
+
+    All samples are out-of-distribution
+
+    >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+    >>> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({'time': 8.008566032557951e-17, 'altitude': 8.008566032557951e-17})
+
+    No out-of-distribution samples
+
+    >> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+    >> find_ood_predictors(metadata1, is_ood)
+    OODPredictorOutput({})
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    discrete_features_count = len(metadata.discrete_factor_names)
+    factors, data = _combine_discrete_continuous(metadata)  # (F, ), (S, F) => F = Fd + Fc
+
+    # No metadata correlated with out of distribution data, return 0.0 for all factors
+    if not any(ood_mask):
+        return OODPredictorOutput(dict.fromkeys(factors, 0.0))
+
+    if len(data) != len(ood_mask):
+        raise ValueError(
+            f"ood and metadata must have the same length, got {len(ood_mask)} and {len(data)} respectively."
+        )
+
+    # Calculate mean, std of each factor over all samples
+    scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)
+
+    discrete_features = np.zeros_like(factors, dtype=np.bool)
+    discrete_features[:discrete_features_count] = True
+
+    mutual_info_values = (
+        mutual_info_classif(
+            X=scaled_data,
+            y=ood_mask,
+            discrete_features=discrete_features,  # type: ignore -> sklearn issue - NDArray[bool] not of accepted type Union[ArrayLike, 'auto']
+            random_state=get_seed(),
+        )
+        * _NATS2BITS
+    )
+
+    return OODPredictorOutput({k: mutual_info_values[i] for i, k in enumerate(factors)})
dataeval/metrics/bias/_balance.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
 import scipy as sp
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
-from dataeval.config import get_seed
+from dataeval.config import EPSILON, get_seed
 from dataeval.outputs import BalanceOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.utils._bin import get_counts
@@ -128,7 +128,7 @@ def balance(
     # Normalization via entropy
     bin_cnts = get_counts(discretized_data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) +
+    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
 
     # in principle MI should be symmetric, but it is not in practice.
     nmi = 0.5 * (mi + mi.T) / norm_factor
@@ -157,7 +157,7 @@ def balance(
     # Classwise normalization via entropy
     classwise_bin_cnts = get_counts(tgt_bin)
     ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
-    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) +
+    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + EPSILON
     classwise = classwise_mi / norm_factor
 
     # Grabbing factor names for plotting function
dataeval/metrics/estimators/_ber.py
CHANGED
@@ -19,6 +19,7 @@ from numpy.typing import NDArray
 from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
+from dataeval.config import EPSILON
 from dataeval.outputs import BEROutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
@@ -82,7 +83,7 @@ def ber_knn(images: NDArray[np.float64], labels: NDArray[np.int_], k: int) -> tu
 
 def knn_lowerbound(value: float, classes: int, k: int) -> float:
     """Several cases for computing the BER lower bound"""
-    if value <=
+    if value <= EPSILON:
        return 0.0
 
     if classes == 2 and k != 1:
dataeval/metrics/stats/_base.py
CHANGED
@@ -9,7 +9,7 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
 
 import numpy as np
 import tqdm
@@ -17,7 +17,7 @@ from numpy.typing import NDArray
 
 from dataeval.config import get_max_processes
 from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
-from dataeval.typing import ArrayLike, Dataset, ObjectDetectionTarget
+from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
 from dataeval.utils._array import to_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
@@ -122,22 +122,19 @@ class StatsProcessorOutput:
 
 def process_stats(
     i: int,
-
+    image: ArrayLike,
+    target: Any,
     per_box: bool,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-
-
-    target = None if not isinstance(target, ObjectDetectionTarget) else target
-    boxes = to_numpy(target.boxes) if target is not None else None
+    image = to_numpy(image)
+    boxes = to_numpy(target.boxes) if isinstance(target, ObjectDetectionTarget) else None
     results_list: list[dict[str, Any]] = []
     source_indices: list[SourceIndex] = []
     box_counts: list[int] = []
     warnings_list: list[str] = []
-
-    for i_b, box in enumerate(nboxes):
-        i_b = None if box is None else i_b
+    for i_b, box in [(None, None)] if boxes is None else enumerate(normalize_box_shape(boxes)):
        processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
        if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
            warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -151,17 +148,16 @@ def process_stats(
 
 
 def process_stats_unpack(
-
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    args: tuple[int, ArrayLike, Any],
     per_box: bool,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-    return process_stats(
+    return process_stats(*args, per_box=per_box, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
 
 
 def run_stats(
-    dataset: Dataset[
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     per_box: bool,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
@@ -175,7 +171,7 @@ def run_stats(
 
     Parameters
     ----------
-    data : Dataset[
+    data : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
        A dataset of images and targets to compute statistics on.
     per_box : bool
        A flag which determines if the statistics should be evaluated on a per-box basis or not.
@@ -206,18 +202,21 @@ def run_stats(
     warning_list = []
     stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-
+    def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
+        for i in range(len(dataset)):
+            d = dataset[i]
+            yield i, d[0] if isinstance(d, tuple) else d, d[1] if isinstance(d, tuple) and per_box else None
+
     with Pool(processes=get_max_processes()) as p:
         for r in tqdm.tqdm(
             p.imap(
                 partial(
                     process_stats_unpack,
-                    dataset=dataset,
                    per_box=per_box,
                    per_channel=per_channel,
                    stats_processor_cls=stats_processor_cls,
                ),
-
+                _enumerate(dataset, per_box),
            ),
            total=len(dataset),
        ):
dataeval/metrics/stats/_dimensionstats.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import DimensionStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import
+from dataeval.typing import Array, Dataset
 from dataeval.utils._image import get_bitdepth
 
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
 
 @set_metadata
 def dimensionstats(
-    dataset: Dataset[
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
 ) -> DimensionStatsOutput:
dataeval/metrics/stats/_hashstats.py
CHANGED
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import HashStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
 from dataeval.utils._array import as_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
 
 @set_metadata
 def hashstats(
-    dataset: Dataset[
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
 ) -> HashStatsOutput:
|
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
|
|
10
10
|
from dataeval.metrics.stats._visualstats import VisualStatsProcessor
|
11
11
|
from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
|
12
12
|
from dataeval.outputs._base import set_metadata
|
13
|
-
from dataeval.typing import
|
13
|
+
from dataeval.typing import Array, Dataset
|
14
14
|
|
15
15
|
|
16
16
|
@overload
|
17
17
|
def imagestats(
|
18
|
-
dataset: Dataset[
|
18
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
19
19
|
*,
|
20
20
|
per_box: bool = False,
|
21
21
|
per_channel: Literal[True],
|
@@ -24,7 +24,7 @@ def imagestats(
|
|
24
24
|
|
25
25
|
@overload
|
26
26
|
def imagestats(
|
27
|
-
dataset: Dataset[
|
27
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
28
28
|
*,
|
29
29
|
per_box: bool = False,
|
30
30
|
per_channel: Literal[False] = False,
|
@@ -33,7 +33,7 @@ def imagestats(
|
|
33
33
|
|
34
34
|
@set_metadata
|
35
35
|
def imagestats(
|
36
|
-
dataset: Dataset[
|
36
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
37
37
|
*,
|
38
38
|
per_box: bool = False,
|
39
39
|
per_channel: bool = False,
|
dataeval/metrics/stats/_pixelstats.py
CHANGED
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import PixelStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import
+from dataeval.typing import Array, Dataset
 
 
 class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
 
 @set_metadata
 def pixelstats(
-    dataset: Dataset[
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
dataeval/metrics/stats/_visualstats.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import VisualStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import
+from dataeval.typing import Array, Dataset
 from dataeval.utils._image import edge_filter
 
 QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
 
 @set_metadata
 def visualstats(
-    dataset: Dataset[
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
dataeval/outputs/__init__.py
CHANGED
@@ -8,7 +8,7 @@ from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput
 from ._drift import DriftMMDOutput, DriftOutput
 from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
 from ._linters import DuplicatesOutput, OutliersOutput
-from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
+from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput, OODPredictorOutput
 from ._ood import OODOutput, OODScoreOutput
 from ._stats import (
     ChannelStatsOutput,
@@ -44,6 +44,7 @@ __all__ = [
     "MetadataDistanceValues",
     "MostDeviatedFactorsOutput",
     "OODOutput",
+    "OODPredictorOutput",
     "OODScoreOutput",
     "OutliersOutput",
     "ParityOutput",
dataeval/outputs/_metadata.py
CHANGED
@@ -52,3 +52,10 @@ class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
     value : :class:`.MetadataDistanceValues`
         Output per feature name containing the statistic, statistic location, distance, and pvalue.
     """
+
+
+class OODPredictorOutput(MappingOutput[str, float]):
+    """
+    Output class for results of :func:`find_ood_predictors` for the
+    mutual information between factors and being out of distribution
+    """