dataeval 0.70.0__py3-none-any.whl → 0.71.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +6 -6
- dataeval/_internal/datasets.py +235 -131
- dataeval/_internal/detectors/clusterer.py +2 -0
- dataeval/_internal/detectors/drift/base.py +2 -2
- dataeval/_internal/detectors/drift/mmd.py +1 -1
- dataeval/_internal/detectors/duplicates.py +2 -0
- dataeval/_internal/detectors/ood/ae.py +5 -3
- dataeval/_internal/detectors/ood/aegmm.py +6 -4
- dataeval/_internal/detectors/ood/base.py +12 -7
- dataeval/_internal/detectors/ood/llr.py +6 -4
- dataeval/_internal/detectors/ood/vae.py +5 -3
- dataeval/_internal/detectors/ood/vaegmm.py +6 -4
- dataeval/_internal/detectors/outliers.py +6 -9
- dataeval/_internal/metrics/balance.py +4 -2
- dataeval/_internal/metrics/ber.py +2 -0
- dataeval/_internal/metrics/coverage.py +4 -0
- dataeval/_internal/metrics/divergence.py +6 -2
- dataeval/_internal/metrics/diversity.py +8 -6
- dataeval/_internal/metrics/parity.py +8 -6
- dataeval/_internal/metrics/stats/base.py +105 -46
- dataeval/_internal/metrics/stats/datasetstats.py +96 -22
- dataeval/_internal/metrics/stats/dimensionstats.py +22 -20
- dataeval/_internal/metrics/stats/hashstats.py +11 -9
- dataeval/_internal/metrics/stats/labelstats.py +1 -1
- dataeval/_internal/metrics/stats/pixelstats.py +28 -26
- dataeval/_internal/metrics/stats/visualstats.py +37 -35
- dataeval/_internal/metrics/uap.py +6 -2
- dataeval/_internal/metrics/utils.py +2 -2
- dataeval/_internal/models/pytorch/autoencoder.py +5 -5
- dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
- dataeval/_internal/utils.py +11 -16
- dataeval/_internal/workflows/sufficiency.py +44 -33
- dataeval/detectors/__init__.py +4 -0
- dataeval/detectors/drift/__init__.py +8 -3
- dataeval/detectors/drift/kernels/__init__.py +4 -0
- dataeval/detectors/drift/updates/__init__.py +4 -0
- dataeval/detectors/linters/__init__.py +15 -4
- dataeval/detectors/ood/__init__.py +14 -2
- dataeval/metrics/__init__.py +5 -0
- dataeval/metrics/bias/__init__.py +13 -4
- dataeval/metrics/estimators/__init__.py +8 -8
- dataeval/metrics/stats/__init__.py +24 -6
- dataeval/utils/__init__.py +16 -3
- dataeval/utils/tensorflow/__init__.py +11 -0
- dataeval/utils/torch/__init__.py +12 -0
- dataeval/utils/torch/datasets/__init__.py +7 -0
- dataeval/workflows/__init__.py +4 -0
- {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/METADATA +11 -2
- dataeval-0.71.0.dist-info/RECORD +80 -0
- dataeval/tensorflow/__init__.py +0 -3
- dataeval/torch/__init__.py +0 -3
- dataeval-0.70.0.dist-info/RECORD +0 -79
- /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
- {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/WHEEL +0 -0
@@ -10,7 +10,7 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
from abc import ABC, abstractmethod
|
12
12
|
from dataclasses import dataclass
|
13
|
-
from typing import Callable, Literal,
|
13
|
+
from typing import Callable, Literal, cast
|
14
14
|
|
15
15
|
import keras
|
16
16
|
import numpy as np
|
@@ -26,6 +26,9 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
26
26
|
@dataclass(frozen=True)
|
27
27
|
class OODOutput(OutputMetadata):
|
28
28
|
"""
|
29
|
+
Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
|
30
|
+
:class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
31
|
+
|
29
32
|
Attributes
|
30
33
|
----------
|
31
34
|
is_ood : NDArray
|
@@ -41,9 +44,11 @@ class OODOutput(OutputMetadata):
|
|
41
44
|
feature_score: NDArray[np.float32] | None
|
42
45
|
|
43
46
|
|
44
|
-
|
47
|
+
@dataclass(frozen=True)
|
48
|
+
class OODScoreOutput(OutputMetadata):
|
45
49
|
"""
|
46
|
-
|
50
|
+
Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
|
51
|
+
:class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
|
47
52
|
|
48
53
|
Parameters
|
49
54
|
----------
|
@@ -76,7 +81,7 @@ class OODBase(ABC):
|
|
76
81
|
def __init__(self, model: keras.Model) -> None:
|
77
82
|
self.model = model
|
78
83
|
|
79
|
-
self._ref_score:
|
84
|
+
self._ref_score: OODScoreOutput
|
80
85
|
self._threshold_perc: float
|
81
86
|
self._data_info: tuple[tuple, type] | None = None
|
82
87
|
|
@@ -102,7 +107,7 @@ class OODBase(ABC):
|
|
102
107
|
self._validate(X)
|
103
108
|
|
104
109
|
@abstractmethod
|
105
|
-
def score(self, X: ArrayLike, batch_size: int = int(1e10)) ->
|
110
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
106
111
|
"""
|
107
112
|
Compute the out-of-distribution (OOD) scores for a given dataset.
|
108
113
|
|
@@ -116,7 +121,7 @@ class OODBase(ABC):
|
|
116
121
|
|
117
122
|
Returns
|
118
123
|
-------
|
119
|
-
|
124
|
+
OODScoreOutput
|
120
125
|
An object containing the instance-level and feature-level OOD scores.
|
121
126
|
"""
|
122
127
|
|
@@ -197,7 +202,7 @@ class OODBase(ABC):
|
|
197
202
|
# compute outlier scores
|
198
203
|
score = self.score(X, batch_size=batch_size)
|
199
204
|
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
200
|
-
return OODOutput(is_ood=ood_pred, **score.
|
205
|
+
return OODOutput(is_ood=ood_pred, **score.dict())
|
201
206
|
|
202
207
|
|
203
208
|
class OODGMMBase(OODBase):
|
@@ -18,11 +18,12 @@ from keras.layers import Input
|
|
18
18
|
from keras.models import Model
|
19
19
|
from numpy.typing import ArrayLike, NDArray
|
20
20
|
|
21
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
21
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
22
22
|
from dataeval._internal.interop import to_numpy
|
23
23
|
from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
|
24
24
|
from dataeval._internal.models.tensorflow.trainer import trainer
|
25
25
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
26
|
+
from dataeval._internal.output import set_metadata
|
26
27
|
|
27
28
|
|
28
29
|
def build_model(
|
@@ -124,7 +125,7 @@ class OOD_LLR(OODBase):
|
|
124
125
|
self.sequential = sequential
|
125
126
|
self.log_prob = log_prob
|
126
127
|
|
127
|
-
self._ref_score:
|
128
|
+
self._ref_score: OODScoreOutput
|
128
129
|
self._threshold_perc: float
|
129
130
|
self._data_info: tuple[tuple, type] | None = None
|
130
131
|
|
@@ -279,12 +280,13 @@ class OOD_LLR(OODBase):
|
|
279
280
|
logp_b = logp_fn(self.dist_b, X, return_per_feature=return_per_feature, batch_size=batch_size)
|
280
281
|
return logp_s - logp_b
|
281
282
|
|
283
|
+
@set_metadata("dataeval.detectors")
|
282
284
|
def score(
|
283
285
|
self,
|
284
286
|
X: ArrayLike,
|
285
287
|
batch_size: int = int(1e10),
|
286
|
-
) ->
|
288
|
+
) -> OODScoreOutput:
|
287
289
|
self._validate(X := to_numpy(X))
|
288
290
|
fscore = -self._llr(X, True, batch_size=batch_size)
|
289
291
|
iscore = -self._llr(X, False, batch_size=batch_size)
|
290
|
-
return
|
292
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -15,11 +15,12 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
18
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
19
19
|
from dataeval._internal.interop import to_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import VAE
|
21
21
|
from dataeval._internal.models.tensorflow.losses import Elbo
|
22
22
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
23
|
+
from dataeval._internal.output import set_metadata
|
23
24
|
|
24
25
|
|
25
26
|
class OOD_VAE(OODBase):
|
@@ -67,7 +68,8 @@ class OOD_VAE(OODBase):
|
|
67
68
|
loss_fn = Elbo(0.05)
|
68
69
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
69
70
|
|
70
|
-
|
71
|
+
@set_metadata("dataeval.detectors")
|
72
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
71
73
|
self._validate(X := to_numpy(X))
|
72
74
|
|
73
75
|
# sample reconstructed instances
|
@@ -86,4 +88,4 @@ class OOD_VAE(OODBase):
|
|
86
88
|
sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
|
87
89
|
iscore = np.mean(sorted_fscore_perc, axis=1)
|
88
90
|
|
89
|
-
return
|
91
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -15,12 +15,13 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODGMMBase,
|
18
|
+
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
|
19
19
|
from dataeval._internal.interop import to_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
|
21
21
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
22
22
|
from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
|
23
23
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
24
|
+
from dataeval._internal.output import set_metadata
|
24
25
|
|
25
26
|
|
26
27
|
class OOD_VAEGMM(OODGMMBase):
|
@@ -53,7 +54,8 @@ class OOD_VAEGMM(OODGMMBase):
|
|
53
54
|
loss_fn = LossGMM(elbo=Elbo(0.05))
|
54
55
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
55
56
|
|
56
|
-
|
57
|
+
@set_metadata("dataeval.detectors")
|
58
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
57
59
|
"""
|
58
60
|
Compute the out-of-distribution (OOD) score for a given dataset.
|
59
61
|
|
@@ -67,7 +69,7 @@ class OOD_VAEGMM(OODGMMBase):
|
|
67
69
|
|
68
70
|
Returns
|
69
71
|
-------
|
70
|
-
|
72
|
+
OODScoreOutput
|
71
73
|
An object containing the instance-level OOD score.
|
72
74
|
|
73
75
|
Note
|
@@ -84,4 +86,4 @@ class OOD_VAEGMM(OODGMMBase):
|
|
84
86
|
energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
|
85
87
|
energy_samples = energy.numpy().reshape((-1, self.samples)) # type: ignore
|
86
88
|
iscore = np.mean(energy_samples, axis=-1)
|
87
|
-
return
|
89
|
+
return OODScoreOutput(iscore)
|
@@ -22,6 +22,8 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
|
|
22
22
|
@dataclass(frozen=True)
|
23
23
|
class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
|
24
24
|
"""
|
25
|
+
Output class for :class:`Outliers` lint detector
|
26
|
+
|
25
27
|
Attributes
|
26
28
|
----------
|
27
29
|
issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
|
@@ -86,8 +88,8 @@ class Outliers:
|
|
86
88
|
--------
|
87
89
|
Duplicates
|
88
90
|
|
89
|
-
|
90
|
-
|
91
|
+
Note
|
92
|
+
----
|
91
93
|
There are 3 different statistical methods:
|
92
94
|
|
93
95
|
- zscore
|
@@ -259,11 +261,6 @@ class Outliers:
|
|
259
261
|
>>> results.issues[10]
|
260
262
|
{'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
|
261
263
|
"""
|
262
|
-
self.stats = datasetstats(
|
263
|
-
|
264
|
-
use_dimension=self.use_dimension,
|
265
|
-
use_pixel=self.use_pixel,
|
266
|
-
use_visual=self.use_visual,
|
267
|
-
)
|
268
|
-
outliers = self._get_outliers({k: v for o in self.stats.outputs() for k, v in o.dict().items()})
|
264
|
+
self.stats = datasetstats(images=data)
|
265
|
+
outliers = self._get_outliers(self.stats.dict())
|
269
266
|
return OutliersOutput(outliers)
|
@@ -15,6 +15,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
15
15
|
@dataclass(frozen=True)
|
16
16
|
class BalanceOutput(OutputMetadata):
|
17
17
|
"""
|
18
|
+
Output class for :func:`balance` bias metric
|
19
|
+
|
18
20
|
Attributes
|
19
21
|
----------
|
20
22
|
balance : NDArray[np.float64]
|
@@ -71,8 +73,8 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
71
73
|
(num_factors+1) x (num_factors+1) estimate of mutual information
|
72
74
|
between num_factors metadata factors and class label. Symmetry is enforced.
|
73
75
|
|
74
|
-
|
75
|
-
|
76
|
+
Note
|
77
|
+
----
|
76
78
|
We use `mutual_info_classif` from sklearn since class label is categorical.
|
77
79
|
`mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
|
78
80
|
seed. MI is computed differently for categorical and continuous variables, and
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import math
|
2
4
|
from dataclasses import dataclass
|
3
5
|
from typing import Literal
|
@@ -14,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
14
16
|
@dataclass(frozen=True)
|
15
17
|
class CoverageOutput(OutputMetadata):
|
16
18
|
"""
|
19
|
+
Output class for :func:`coverage` bias metric
|
20
|
+
|
17
21
|
Attributes
|
18
22
|
----------
|
19
23
|
indices : NDArray
|
@@ -3,6 +3,8 @@ This module contains the implementation of HP Divergence
|
|
3
3
|
using the Fast Nearest Neighbor and Minimum Spanning Tree algorithms
|
4
4
|
"""
|
5
5
|
|
6
|
+
from __future__ import annotations
|
7
|
+
|
6
8
|
from dataclasses import dataclass
|
7
9
|
from typing import Literal
|
8
10
|
|
@@ -17,6 +19,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
17
19
|
@dataclass(frozen=True)
|
18
20
|
class DivergenceOutput(OutputMetadata):
|
19
21
|
"""
|
22
|
+
Output class for :func:`divergence` estimator metric
|
23
|
+
|
20
24
|
Attributes
|
21
25
|
----------
|
22
26
|
divergence : float
|
@@ -96,8 +100,8 @@ def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST
|
|
96
100
|
DivergenceOutput
|
97
101
|
The divergence value (0.0..1.0) and the number of differing edges between the datasets
|
98
102
|
|
99
|
-
|
100
|
-
|
103
|
+
Note
|
104
|
+
----
|
101
105
|
The divergence value indicates how similar the 2 datasets are
|
102
106
|
with 0 indicating approximately identical data distributions.
|
103
107
|
|
@@ -13,6 +13,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
13
13
|
@dataclass(frozen=True)
|
14
14
|
class DiversityOutput(OutputMetadata):
|
15
15
|
"""
|
16
|
+
Output class for :func:`diversity` bias metric
|
17
|
+
|
16
18
|
Attributes
|
17
19
|
----------
|
18
20
|
diversity_index : NDArray[np.float64]
|
@@ -52,8 +54,8 @@ def diversity_shannon(
|
|
52
54
|
subset_mask: NDArray[np.bool_] | None
|
53
55
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
54
56
|
|
55
|
-
|
56
|
-
|
57
|
+
Note
|
58
|
+
----
|
57
59
|
For continuous variables, histogram bins are chosen automatically. See `numpy.histogram` for details.
|
58
60
|
|
59
61
|
Returns
|
@@ -103,8 +105,8 @@ def diversity_simpson(
|
|
103
105
|
subset_mask: NDArray[np.bool_] | None
|
104
106
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
105
107
|
|
106
|
-
|
107
|
-
|
108
|
+
Note
|
109
|
+
----
|
108
110
|
For continuous variables, histogram bins are chosen automatically. See
|
109
111
|
numpy.histogram for details.
|
110
112
|
If there is only one category, the diversity index takes a value of 0.
|
@@ -162,8 +164,8 @@ def diversity(
|
|
162
164
|
method: Literal["shannon", "simpson"], default "simpson"
|
163
165
|
Indicates which diversity index should be computed
|
164
166
|
|
165
|
-
|
166
|
-
|
167
|
+
Note
|
168
|
+
----
|
167
169
|
- For continuous variables, histogram bins are chosen automatically. See numpy.histogram for details.
|
168
170
|
- The expression is undefined for q=1, but it approaches the Shannon entropy in the limit.
|
169
171
|
- If there is only one category, the diversity index takes a value of 1 = 1/N = 1/1. Entropy will take a value of 0.
|
@@ -17,6 +17,8 @@ TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
|
17
17
|
@dataclass(frozen=True)
|
18
18
|
class ParityOutput(Generic[TData], OutputMetadata):
|
19
19
|
"""
|
20
|
+
Output class for :func:`parity` and :func:`label_parity` bias metrics
|
21
|
+
|
20
22
|
Attributes
|
21
23
|
----------
|
22
24
|
score : np.float64 | NDArray[np.float64]
|
@@ -137,8 +139,8 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
|
|
137
139
|
ValueError
|
138
140
|
If the expected distribution is all zeros.
|
139
141
|
|
140
|
-
|
141
|
-
|
142
|
+
Note
|
143
|
+
----
|
142
144
|
The function ensures that the total number of labels in the expected distribution matches the total
|
143
145
|
number of labels in the observed distribution by scaling the expected distribution.
|
144
146
|
"""
|
@@ -224,8 +226,8 @@ def label_parity(
|
|
224
226
|
of unique classes between the observed and expected distributions.
|
225
227
|
|
226
228
|
|
227
|
-
|
228
|
-
|
229
|
+
Note
|
230
|
+
----
|
229
231
|
- Providing ``num_classes`` can be helpful if there are classes with zero instances in one of the distributions.
|
230
232
|
- The function first validates the observed distribution and normalizes the expected distribution so that it
|
231
233
|
has the same total number of labels as the observed distribution.
|
@@ -317,8 +319,8 @@ def parity(
|
|
317
319
|
factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
|
318
320
|
into fewer bins.
|
319
321
|
|
320
|
-
|
321
|
-
|
322
|
+
Note
|
323
|
+
----
|
322
324
|
- Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
|
323
325
|
- A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
|
324
326
|
- The function creates a contingency matrix for each factor, where each entry represents the frequency of a
|
@@ -3,9 +3,13 @@ from __future__ import annotations
|
|
3
3
|
import re
|
4
4
|
import warnings
|
5
5
|
from dataclasses import dataclass
|
6
|
-
from
|
6
|
+
from functools import partial
|
7
|
+
from itertools import repeat
|
8
|
+
from multiprocessing import Pool
|
9
|
+
from typing import Any, Callable, Generic, Iterable, NamedTuple, Optional, TypeVar, Union
|
7
10
|
|
8
11
|
import numpy as np
|
12
|
+
import tqdm
|
9
13
|
from numpy.typing import ArrayLike, NDArray
|
10
14
|
|
11
15
|
from dataeval._internal.interop import to_numpy_iter
|
@@ -91,7 +95,11 @@ class BaseStatsOutput(OutputMetadata):
|
|
91
95
|
return len(self.source_index)
|
92
96
|
|
93
97
|
|
94
|
-
|
98
|
+
TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
|
99
|
+
|
100
|
+
|
101
|
+
class StatsProcessor(Generic[TStatsOutput]):
|
102
|
+
output_class: type[TStatsOutput]
|
95
103
|
cache_keys: list[str] = []
|
96
104
|
image_function_map: dict[str, Callable[[StatsProcessor], Any]] = {}
|
97
105
|
channel_function_map: dict[str, Callable[[StatsProcessor], Any]] = {}
|
@@ -119,6 +127,9 @@ class StatsProcessor:
|
|
119
127
|
else:
|
120
128
|
return self.fn_map[fn_key](self)
|
121
129
|
|
130
|
+
def process(self) -> dict:
|
131
|
+
return {k: self.fn_map[k](self) for k in self.fn_map}
|
132
|
+
|
122
133
|
@property
|
123
134
|
def image(self) -> NDArray:
|
124
135
|
if self._image is None:
|
@@ -143,14 +154,66 @@ class StatsProcessor:
|
|
143
154
|
self._scaled = self._scaled.reshape(self.image.shape[0], -1)
|
144
155
|
return self._scaled
|
145
156
|
|
157
|
+
@classmethod
|
158
|
+
def convert_output(
|
159
|
+
cls, source: dict[str, Any], source_index: list[SourceIndex], box_count: list[int]
|
160
|
+
) -> TStatsOutput:
|
161
|
+
output = {}
|
162
|
+
for key in source:
|
163
|
+
if key not in cls.output_class.__annotations__:
|
164
|
+
continue
|
165
|
+
stat_type: str = cls.output_class.__annotations__[key]
|
166
|
+
dtype_match = re.match(DTYPE_REGEX, stat_type)
|
167
|
+
if dtype_match is not None:
|
168
|
+
output[key] = np.asarray(source[key], dtype=np.dtype(dtype_match.group(1)))
|
169
|
+
else:
|
170
|
+
output[key] = source[key]
|
171
|
+
return cls.output_class(**output, source_index=source_index, box_count=np.asarray(box_count, dtype=np.uint16))
|
172
|
+
|
173
|
+
|
174
|
+
class StatsProcessorOutput(NamedTuple):
|
175
|
+
results: list[dict[str, Any]]
|
176
|
+
source_indices: list[SourceIndex]
|
177
|
+
box_counts: list[int]
|
178
|
+
warnings_list: list[tuple[int, int, NDArray, tuple[int, ...]]]
|
179
|
+
|
180
|
+
|
181
|
+
def process_stats(
|
182
|
+
i: int,
|
183
|
+
image_boxes: tuple[NDArray, NDArray | None],
|
184
|
+
per_channel: bool,
|
185
|
+
stats_processor_cls: Iterable[type[StatsProcessor]],
|
186
|
+
) -> StatsProcessorOutput:
|
187
|
+
image, boxes = image_boxes
|
188
|
+
results_list: list[dict[str, Any]] = []
|
189
|
+
source_indices: list[SourceIndex] = []
|
190
|
+
box_counts: list[int] = []
|
191
|
+
warnings_list: list[tuple[int, int, NDArray, tuple[int, ...]]] = []
|
192
|
+
nboxes = [None] if boxes is None else normalize_box_shape(boxes)
|
193
|
+
for i_b, box in enumerate(nboxes):
|
194
|
+
i_b = None if box is None else i_b
|
195
|
+
processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
|
196
|
+
if any(not p.is_valid_slice for p in processor_list) and i_b is not None and box is not None:
|
197
|
+
warnings_list.append((i, i_b, box, image.shape))
|
198
|
+
results_list.append({k: v for p in processor_list for k, v in p.process().items()})
|
199
|
+
if per_channel:
|
200
|
+
source_indices.extend([SourceIndex(i, i_b, c) for c in range(image_boxes[0].shape[-3])])
|
201
|
+
else:
|
202
|
+
source_indices.append(SourceIndex(i, i_b, None))
|
203
|
+
box_counts.append(0 if boxes is None else len(boxes))
|
204
|
+
return StatsProcessorOutput(results_list, source_indices, box_counts, warnings_list)
|
205
|
+
|
206
|
+
|
207
|
+
def process_stats_unpack(args, per_channel: bool, stats_processor_cls: Iterable[type[StatsProcessor]]):
|
208
|
+
return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
|
209
|
+
|
146
210
|
|
147
211
|
def run_stats(
|
148
212
|
images: Iterable[ArrayLike],
|
149
213
|
bboxes: Iterable[ArrayLike] | None,
|
150
214
|
per_channel: bool,
|
151
|
-
stats_processor_cls: type,
|
152
|
-
|
153
|
-
) -> dict:
|
215
|
+
stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
|
216
|
+
) -> list[TStatsOutput]:
|
154
217
|
"""
|
155
218
|
Compute specified statistics on a set of images.
|
156
219
|
|
@@ -169,18 +232,16 @@ def run_stats(
|
|
169
232
|
iterable should match the length of the input images.
|
170
233
|
per_channel : bool
|
171
234
|
A flag which determines if the stats should be evaluated on a per-channel basis or not.
|
172
|
-
|
173
|
-
|
235
|
+
stats_processor_cls : Iterable[type[StatsProcessor]]
|
236
|
+
An iterable of stats processor classes that calculate stats and return output classes.
|
174
237
|
|
175
238
|
Returns
|
176
239
|
-------
|
177
|
-
|
178
|
-
A
|
179
|
-
The dictionary keys correspond to the names of the statistics, and the values are NumPy arrays
|
180
|
-
with the results of the computations.
|
240
|
+
list[TStatsOutput]
|
241
|
+
A list of output classes corresponding to the input processor types.
|
181
242
|
|
182
|
-
|
183
|
-
|
243
|
+
Note
|
244
|
+
----
|
184
245
|
- The function performs image normalization (rescaling the image values)
|
185
246
|
before applying some of the statistics.
|
186
247
|
- Pixel-level statistics (e.g., brightness, entropy) are computed after
|
@@ -189,43 +250,41 @@ def run_stats(
|
|
189
250
|
be reused to avoid redundant computation.
|
190
251
|
"""
|
191
252
|
results_list: list[dict[str, NDArray]] = []
|
192
|
-
output_list = list(output_cls.__annotations__)
|
193
253
|
source_index = []
|
194
254
|
box_count = []
|
195
|
-
bbox_iter = (None
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
255
|
+
bbox_iter = repeat(None) if bboxes is None else to_numpy_iter(bboxes)
|
256
|
+
|
257
|
+
warning_list = []
|
258
|
+
total_for_status = getattr(images, "__len__")() if hasattr(images, "__len__") else None
|
259
|
+
stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
|
260
|
+
|
261
|
+
# TODO: Introduce global controls for CPU job parallelism and GPU configurations
|
262
|
+
with Pool(16) as p:
|
263
|
+
for r in tqdm.tqdm(
|
264
|
+
p.imap(
|
265
|
+
partial(process_stats_unpack, per_channel=per_channel, stats_processor_cls=stats_processor_cls),
|
266
|
+
enumerate(zip(to_numpy_iter(images), bbox_iter)),
|
267
|
+
),
|
268
|
+
total=total_for_status,
|
269
|
+
):
|
270
|
+
results_list.extend(r.results)
|
271
|
+
source_index.extend(r.source_indices)
|
272
|
+
box_count.extend(r.box_counts)
|
273
|
+
warning_list.extend(r.warnings_list)
|
274
|
+
p.close()
|
275
|
+
p.join()
|
276
|
+
|
277
|
+
# warnings are not emitted while in multiprocessing pools so we emit after gathering all warnings
|
278
|
+
for w in warning_list:
|
279
|
+
warnings.warn(f"Bounding box [{w[0]}][{w[1]}]: {w[2]} is out of bounds of {w[3]}.", UserWarning)
|
210
280
|
|
211
281
|
output = {}
|
212
|
-
|
213
|
-
for
|
214
|
-
|
282
|
+
for results in results_list:
|
283
|
+
for stat, result in results.items():
|
284
|
+
if per_channel:
|
215
285
|
output.setdefault(stat, []).extend(result.tolist())
|
216
|
-
|
217
|
-
for results in results_list:
|
218
|
-
for stat, result in results.items():
|
286
|
+
else:
|
219
287
|
output.setdefault(stat, []).append(result.tolist() if isinstance(result, np.ndarray) else result)
|
220
288
|
|
221
|
-
for
|
222
|
-
|
223
|
-
|
224
|
-
dtype_match = re.match(DTYPE_REGEX, stat_type)
|
225
|
-
if dtype_match is not None:
|
226
|
-
output[stat] = np.asarray(output[stat], dtype=np.dtype(dtype_match.group(1)))
|
227
|
-
|
228
|
-
output[SOURCE_INDEX] = source_index
|
229
|
-
output[BOX_COUNT] = np.asarray(box_count, dtype=np.uint16)
|
230
|
-
|
231
|
-
return output
|
289
|
+
outputs = [s.convert_output(output, source_index, box_count) for s in stats_processor_cls]
|
290
|
+
return outputs
|