dataeval 0.70.0__py3-none-any.whl → 0.71.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +6 -6
  2. dataeval/_internal/datasets.py +235 -131
  3. dataeval/_internal/detectors/clusterer.py +2 -0
  4. dataeval/_internal/detectors/drift/base.py +2 -2
  5. dataeval/_internal/detectors/drift/mmd.py +1 -1
  6. dataeval/_internal/detectors/duplicates.py +2 -0
  7. dataeval/_internal/detectors/ood/ae.py +5 -3
  8. dataeval/_internal/detectors/ood/aegmm.py +6 -4
  9. dataeval/_internal/detectors/ood/base.py +12 -7
  10. dataeval/_internal/detectors/ood/llr.py +6 -4
  11. dataeval/_internal/detectors/ood/vae.py +5 -3
  12. dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  13. dataeval/_internal/detectors/outliers.py +6 -9
  14. dataeval/_internal/metrics/balance.py +4 -2
  15. dataeval/_internal/metrics/ber.py +2 -0
  16. dataeval/_internal/metrics/coverage.py +4 -0
  17. dataeval/_internal/metrics/divergence.py +6 -2
  18. dataeval/_internal/metrics/diversity.py +8 -6
  19. dataeval/_internal/metrics/parity.py +8 -6
  20. dataeval/_internal/metrics/stats/base.py +105 -46
  21. dataeval/_internal/metrics/stats/datasetstats.py +96 -22
  22. dataeval/_internal/metrics/stats/dimensionstats.py +22 -20
  23. dataeval/_internal/metrics/stats/hashstats.py +11 -9
  24. dataeval/_internal/metrics/stats/labelstats.py +1 -1
  25. dataeval/_internal/metrics/stats/pixelstats.py +28 -26
  26. dataeval/_internal/metrics/stats/visualstats.py +37 -35
  27. dataeval/_internal/metrics/uap.py +6 -2
  28. dataeval/_internal/metrics/utils.py +2 -2
  29. dataeval/_internal/models/pytorch/autoencoder.py +5 -5
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
  31. dataeval/_internal/utils.py +11 -16
  32. dataeval/_internal/workflows/sufficiency.py +44 -33
  33. dataeval/detectors/__init__.py +4 -0
  34. dataeval/detectors/drift/__init__.py +8 -3
  35. dataeval/detectors/drift/kernels/__init__.py +4 -0
  36. dataeval/detectors/drift/updates/__init__.py +4 -0
  37. dataeval/detectors/linters/__init__.py +15 -4
  38. dataeval/detectors/ood/__init__.py +14 -2
  39. dataeval/metrics/__init__.py +5 -0
  40. dataeval/metrics/bias/__init__.py +13 -4
  41. dataeval/metrics/estimators/__init__.py +8 -8
  42. dataeval/metrics/stats/__init__.py +24 -6
  43. dataeval/utils/__init__.py +16 -3
  44. dataeval/utils/tensorflow/__init__.py +11 -0
  45. dataeval/utils/torch/__init__.py +12 -0
  46. dataeval/utils/torch/datasets/__init__.py +7 -0
  47. dataeval/workflows/__init__.py +4 -0
  48. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/METADATA +11 -2
  49. dataeval-0.71.0.dist-info/RECORD +80 -0
  50. dataeval/tensorflow/__init__.py +0 -3
  51. dataeval/torch/__init__.py +0 -3
  52. dataeval-0.70.0.dist-info/RECORD +0 -79
  53. /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
  54. /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
  55. /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
  56. /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
  57. /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
  58. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.70.0.dist-info → dataeval-0.71.0.dist-info}/WHEEL +0 -0
--- a/dataeval/_internal/detectors/ood/base.py
+++ b/dataeval/_internal/detectors/ood/base.py
@@ -10,7 +10,7 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Callable, Literal, NamedTuple, cast
+from typing import Callable, Literal, cast
 
 import keras
 import numpy as np
@@ -26,6 +26,9 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class OODOutput(OutputMetadata):
     """
+    Output class for predictions from :class:`OOD_AE`, :class:`OOD_AEGMM`, :class:`OOD_LLR`,
+    :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
+
     Attributes
     ----------
     is_ood : NDArray
@@ -41,9 +44,11 @@ class OODOutput(OutputMetadata):
     feature_score: NDArray[np.float32] | None
 
 
-class OODScore(NamedTuple):
+@dataclass(frozen=True)
+class OODScoreOutput(OutputMetadata):
     """
-    NamedTuple containing the instance and (optionally) feature score.
+    Output class for instance and feature scores from :class:`OOD_AE`, :class:`OOD_AEGMM`,
+    :class:`OOD_LLR`, :class:`OOD_VAE`, and :class:`OOD_VAEGMM` out-of-distribution detectors
 
     Parameters
     ----------
@@ -76,7 +81,7 @@ class OODBase(ABC):
     def __init__(self, model: keras.Model) -> None:
         self.model = model
 
-        self._ref_score: OODScore
+        self._ref_score: OODScoreOutput
         self._threshold_perc: float
         self._data_info: tuple[tuple, type] | None = None
 
@@ -102,7 +107,7 @@ class OODBase(ABC):
         self._validate(X)
 
     @abstractmethod
-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         """
         Compute the out-of-distribution (OOD) scores for a given dataset.
 
@@ -116,7 +121,7 @@ class OODBase(ABC):
 
         Returns
        -------
-        OODScore
+        OODScoreOutput
            An object containing the instance-level and feature-level OOD scores.
        """
 
@@ -197,7 +202,7 @@ class OODBase(ABC):
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
-        return OODOutput(is_ood=ood_pred, **score._asdict())
+        return OODOutput(is_ood=ood_pred, **score.dict())
 
 
 class OODGMMBase(OODBase):
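
The OODScore NamedTuple is now a frozen OODScoreOutput dataclass built on OutputMetadata, so consumers swap score._asdict() for score.dict(), as predict now does above. A minimal sketch of the new shape, assuming the fields are named instance_score and feature_score to match the OODOutput attribute list:

import numpy as np

from dataeval._internal.detectors.ood.base import OODScoreOutput

# Field names here are assumed from the OODOutput attribute docs above.
score = OODScoreOutput(
    instance_score=np.array([0.2, 4.7], dtype=np.float32),
    feature_score=None,  # detectors such as OOD_VAEGMM produce instance scores only
)

# The NamedTuple-era score._asdict() becomes score.dict(), which predict()
# splats into OODOutput(is_ood=..., **score.dict()).
kwargs = score.dict()
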
--- a/dataeval/_internal/detectors/ood/llr.py
+++ b/dataeval/_internal/detectors/ood/llr.py
@@ -18,11 +18,12 @@ from keras.layers import Input
 from keras.models import Model
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.detectors.ood.base import OODBase, OODScore
+from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
 from dataeval._internal.models.tensorflow.trainer import trainer
 from dataeval._internal.models.tensorflow.utils import predict_batch
+from dataeval._internal.output import set_metadata
 
 
 def build_model(
@@ -124,7 +125,7 @@ class OOD_LLR(OODBase):
         self.sequential = sequential
         self.log_prob = log_prob
 
-        self._ref_score: OODScore
+        self._ref_score: OODScoreOutput
         self._threshold_perc: float
         self._data_info: tuple[tuple, type] | None = None
 
@@ -279,12 +280,13 @@ class OOD_LLR(OODBase):
         logp_b = logp_fn(self.dist_b, X, return_per_feature=return_per_feature, batch_size=batch_size)
         return logp_s - logp_b
 
+    @set_metadata("dataeval.detectors")
     def score(
         self,
         X: ArrayLike,
         batch_size: int = int(1e10),
-    ) -> OODScore:
+    ) -> OODScoreOutput:
         self._validate(X := to_numpy(X))
         fscore = -self._llr(X, True, batch_size=batch_size)
         iscore = -self._llr(X, False, batch_size=batch_size)
-        return OODScore(iscore, fscore)
+        return OODScoreOutput(iscore, fscore)
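
The score methods also gain a @set_metadata("dataeval.detectors") decorator imported from dataeval._internal.output. Its implementation is not part of this diff; the sketch below only illustrates the pattern, a wrapper that stamps provenance onto the returned OutputMetadata object, and is not dataeval's actual code:

from functools import wraps

def set_metadata(module: str):
    # Illustrative stand-in for dataeval._internal.output.set_metadata.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            output = fn(*args, **kwargs)
            # The outputs are frozen dataclasses, so any hypothetical extra
            # state has to bypass the frozen __setattr__.
            object.__setattr__(output, "_meta", {"module": module, "function": fn.__qualname__})
            return output
        return wrapper
    return decorator
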
--- a/dataeval/_internal/detectors/ood/vae.py
+++ b/dataeval/_internal/detectors/ood/vae.py
@@ -15,11 +15,12 @@ import numpy as np
 import tensorflow as tf
 from numpy.typing import ArrayLike
 
-from dataeval._internal.detectors.ood.base import OODBase, OODScore
+from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAE
 from dataeval._internal.models.tensorflow.losses import Elbo
 from dataeval._internal.models.tensorflow.utils import predict_batch
+from dataeval._internal.output import set_metadata
 
 
 class OOD_VAE(OODBase):
@@ -67,7 +68,8 @@ class OOD_VAE(OODBase):
         loss_fn = Elbo(0.05)
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+    @set_metadata("dataeval.detectors")
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         self._validate(X := to_numpy(X))
 
         # sample reconstructed instances
@@ -86,4 +88,4 @@ class OOD_VAE(OODBase):
         sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
         iscore = np.mean(sorted_fscore_perc, axis=1)
 
-        return OODScore(iscore, fscore)
+        return OODScoreOutput(iscore, fscore)
--- a/dataeval/_internal/detectors/ood/vaegmm.py
+++ b/dataeval/_internal/detectors/ood/vaegmm.py
@@ -15,12 +15,13 @@ import numpy as np
 import tensorflow as tf
 from numpy.typing import ArrayLike
 
-from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
+from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
 from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
 from dataeval._internal.models.tensorflow.utils import predict_batch
+from dataeval._internal.output import set_metadata
 
 
 class OOD_VAEGMM(OODGMMBase):
@@ -53,7 +54,8 @@ class OOD_VAEGMM(OODGMMBase):
         loss_fn = LossGMM(elbo=Elbo(0.05))
         super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
 
-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+    @set_metadata("dataeval.detectors")
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
         """
         Compute the out-of-distribution (OOD) score for a given dataset.
 
@@ -67,7 +69,7 @@ class OOD_VAEGMM(OODGMMBase):
 
         Returns
         -------
-        OODScore
+        OODScoreOutput
            An object containing the instance-level OOD score.
 
         Note
@@ -84,4 +86,4 @@ class OOD_VAEGMM(OODGMMBase):
         energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
         energy_samples = energy.numpy().reshape((-1, self.samples))  # type: ignore
         iscore = np.mean(energy_samples, axis=-1)
-        return OODScore(iscore)
+        return OODScoreOutput(iscore)
--- a/dataeval/_internal/detectors/outliers.py
+++ b/dataeval/_internal/detectors/outliers.py
@@ -22,6 +22,8 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 @dataclass(frozen=True)
 class OutliersOutput(Generic[TIndexIssueMap], OutputMetadata):
     """
+    Output class for :class:`Outliers` lint detector
+
     Attributes
     ----------
     issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
@@ -86,8 +88,8 @@ class Outliers:
     --------
     Duplicates
 
-    Notes
-    ------
+    Note
+    ----
     There are 3 different statistical methods:
 
     - zscore
@@ -259,11 +261,6 @@ class Outliers:
         >>> results.issues[10]
         {'skew': -3.906, 'kurtosis': 13.266, 'entropy': 0.2128, 'contrast': 1.25, 'zeros': 0.05493}
         """
-        self.stats = datasetstats(
-            images=data,
-            use_dimension=self.use_dimension,
-            use_pixel=self.use_pixel,
-            use_visual=self.use_visual,
-        )
-        outliers = self._get_outliers({k: v for o in self.stats.outputs() for k, v in o.dict().items()})
+        self.stats = datasetstats(images=data)
+        outliers = self._get_outliers(self.stats.dict())
         return OutliersOutput(outliers)
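
With this change the detector always runs the dimension, pixel, and visual stats through a single datasetstats call and flattens the combined output with stats.dict(). A hypothetical end-to-end call, assuming Outliers is re-exported from dataeval.detectors.linters (consistent with the __init__.py changes listed above) and that its entry point is evaluate(), matching the docstring example:

import numpy as np

from dataeval.detectors.linters import Outliers

rng = np.random.default_rng(0)
images = rng.random((16, 3, 64, 64))  # synthetic NCHW stand-in data

outliers = Outliers()
results = outliers.evaluate(images)  # datasetstats now always computes all three stat groups
print(results.issues)                # e.g. {10: {'skew': -3.906, 'kurtosis': 13.266, ...}}
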
--- a/dataeval/_internal/metrics/balance.py
+++ b/dataeval/_internal/metrics/balance.py
@@ -15,6 +15,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class BalanceOutput(OutputMetadata):
     """
+    Output class for :func:`balance` bias metric
+
     Attributes
     ----------
     balance : NDArray[np.float64]
@@ -71,8 +73,8 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
     (num_factors+1) x (num_factors+1) estimate of mutual information
     between num_factors metadata factors and class label. Symmetry is enforced.
 
-    Notes
-    -----
+    Note
+    ----
     We use `mutual_info_classif` from sklearn since class label is categorical.
     `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
     seed. MI is computed differently for categorical and continuous variables, and
--- a/dataeval/_internal/metrics/ber.py
+++ b/dataeval/_internal/metrics/ber.py
@@ -25,6 +25,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class BEROutput(OutputMetadata):
     """
+    Output class for :func:`ber` estimator metric
+
     Attributes
     ----------
     ber : float
--- a/dataeval/_internal/metrics/coverage.py
+++ b/dataeval/_internal/metrics/coverage.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 from dataclasses import dataclass
 from typing import Literal
@@ -14,6 +16,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class CoverageOutput(OutputMetadata):
     """
+    Output class for :func:`coverage` bias metric
+
     Attributes
     ----------
     indices : NDArray
--- a/dataeval/_internal/metrics/divergence.py
+++ b/dataeval/_internal/metrics/divergence.py
@@ -3,6 +3,8 @@ This module contains the implementation of HP Divergence
 using the Fast Nearest Neighbor and Minimum Spanning Tree algorithms
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import Literal
 
@@ -17,6 +19,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class DivergenceOutput(OutputMetadata):
     """
+    Output class for :func:`divergence` estimator metric
+
     Attributes
     ----------
     divergence : float
@@ -96,8 +100,8 @@ def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST
     DivergenceOutput
         The divergence value (0.0..1.0) and the number of differing edges between the datasets
 
-    Notes
-    -----
+    Note
+    ----
     The divergence value indicates how similar the 2 datasets are
     with 0 indicating approximately identical data distributions.
 
--- a/dataeval/_internal/metrics/diversity.py
+++ b/dataeval/_internal/metrics/diversity.py
@@ -13,6 +13,8 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 @dataclass(frozen=True)
 class DiversityOutput(OutputMetadata):
     """
+    Output class for :func:`diversity` bias metric
+
     Attributes
     ----------
     diversity_index : NDArray[np.float64]
@@ -52,8 +54,8 @@ def diversity_shannon(
     subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
-    Notes
-    -----
+    Note
+    ----
     For continuous variables, histogram bins are chosen automatically. See `numpy.histogram` for details.
 
     Returns
@@ -103,8 +105,8 @@ def diversity_simpson(
     subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
-    Notes
-    -----
+    Note
+    ----
     For continuous variables, histogram bins are chosen automatically. See
     numpy.histogram for details.
     If there is only one category, the diversity index takes a value of 0.
@@ -162,8 +164,8 @@ def diversity(
     method: Literal["shannon", "simpson"], default "simpson"
         Indicates which diversity index should be computed
 
-    Notes
-    -----
+    Note
+    ----
     - For continuous variables, histogram bins are chosen automatically. See numpy.histogram for details.
     - The expression is undefined for q=1, but it approaches the Shannon entropy in the limit.
     - If there is only one category, the diversity index takes a value of 1 = 1/N = 1/1. Entropy will take a value of 0.
--- a/dataeval/_internal/metrics/parity.py
+++ b/dataeval/_internal/metrics/parity.py
@@ -17,6 +17,8 @@ TData = TypeVar("TData", np.float64, NDArray[np.float64])
 @dataclass(frozen=True)
 class ParityOutput(Generic[TData], OutputMetadata):
     """
+    Output class for :func:`parity` and :func:`label_parity` bias metrics
+
     Attributes
     ----------
     score : np.float64 | NDArray[np.float64]
@@ -137,8 +139,8 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
     ValueError
         If the expected distribution is all zeros.
 
-    Notes
-    -----
+    Note
+    ----
     The function ensures that the total number of labels in the expected distribution matches the total
     number of labels in the observed distribution by scaling the expected distribution.
     """
@@ -224,8 +226,8 @@ def label_parity(
     of unique classes between the observed and expected distributions.
 
 
-    Notes
-    -----
+    Note
+    ----
     - Providing ``num_classes`` can be helpful if there are classes with zero instances in one of the distributions.
     - The function first validates the observed distribution and normalizes the expected distribution so that it
       has the same total number of labels as the observed distribution.
@@ -317,8 +319,8 @@ def parity(
     factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
     into fewer bins.
 
-    Notes
-    -----
+    Note
+    ----
     - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
     - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
     - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
--- a/dataeval/_internal/metrics/stats/base.py
+++ b/dataeval/_internal/metrics/stats/base.py
@@ -3,9 +3,13 @@ from __future__ import annotations
 import re
 import warnings
 from dataclasses import dataclass
-from typing import Any, Callable, Iterable, NamedTuple, Optional, Union
+from functools import partial
+from itertools import repeat
+from multiprocessing import Pool
+from typing import Any, Callable, Generic, Iterable, NamedTuple, Optional, TypeVar, Union
 
 import numpy as np
+import tqdm
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.interop import to_numpy_iter
@@ -91,7 +95,11 @@ class BaseStatsOutput(OutputMetadata):
         return len(self.source_index)
 
 
-class StatsProcessor:
+TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
+
+
+class StatsProcessor(Generic[TStatsOutput]):
+    output_class: type[TStatsOutput]
     cache_keys: list[str] = []
     image_function_map: dict[str, Callable[[StatsProcessor], Any]] = {}
     channel_function_map: dict[str, Callable[[StatsProcessor], Any]] = {}
@@ -119,6 +127,9 @@ class StatsProcessor:
         else:
             return self.fn_map[fn_key](self)
 
+    def process(self) -> dict:
+        return {k: self.fn_map[k](self) for k in self.fn_map}
+
     @property
     def image(self) -> NDArray:
         if self._image is None:
@@ -143,14 +154,66 @@ class StatsProcessor:
             self._scaled = self._scaled.reshape(self.image.shape[0], -1)
         return self._scaled
 
+    @classmethod
+    def convert_output(
+        cls, source: dict[str, Any], source_index: list[SourceIndex], box_count: list[int]
+    ) -> TStatsOutput:
+        output = {}
+        for key in source:
+            if key not in cls.output_class.__annotations__:
+                continue
+            stat_type: str = cls.output_class.__annotations__[key]
+            dtype_match = re.match(DTYPE_REGEX, stat_type)
+            if dtype_match is not None:
+                output[key] = np.asarray(source[key], dtype=np.dtype(dtype_match.group(1)))
+            else:
+                output[key] = source[key]
+        return cls.output_class(**output, source_index=source_index, box_count=np.asarray(box_count, dtype=np.uint16))
+
+
+class StatsProcessorOutput(NamedTuple):
+    results: list[dict[str, Any]]
+    source_indices: list[SourceIndex]
+    box_counts: list[int]
+    warnings_list: list[tuple[int, int, NDArray, tuple[int, ...]]]
+
+
+def process_stats(
+    i: int,
+    image_boxes: tuple[NDArray, NDArray | None],
+    per_channel: bool,
+    stats_processor_cls: Iterable[type[StatsProcessor]],
+) -> StatsProcessorOutput:
+    image, boxes = image_boxes
+    results_list: list[dict[str, Any]] = []
+    source_indices: list[SourceIndex] = []
+    box_counts: list[int] = []
+    warnings_list: list[tuple[int, int, NDArray, tuple[int, ...]]] = []
+    nboxes = [None] if boxes is None else normalize_box_shape(boxes)
+    for i_b, box in enumerate(nboxes):
+        i_b = None if box is None else i_b
+        processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
+        if any(not p.is_valid_slice for p in processor_list) and i_b is not None and box is not None:
+            warnings_list.append((i, i_b, box, image.shape))
+        results_list.append({k: v for p in processor_list for k, v in p.process().items()})
+        if per_channel:
+            source_indices.extend([SourceIndex(i, i_b, c) for c in range(image_boxes[0].shape[-3])])
+        else:
+            source_indices.append(SourceIndex(i, i_b, None))
+    box_counts.append(0 if boxes is None else len(boxes))
+    return StatsProcessorOutput(results_list, source_indices, box_counts, warnings_list)
+
+
+def process_stats_unpack(args, per_channel: bool, stats_processor_cls: Iterable[type[StatsProcessor]]):
+    return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+
 
 def run_stats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None,
     per_channel: bool,
-    stats_processor_cls: type,
-    output_cls: type,
-) -> dict:
+    stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
+) -> list[TStatsOutput]:
     """
     Compute specified statistics on a set of images.
 
@@ -169,18 +232,16 @@ def run_stats(
         iterable should match the length of the input images.
     per_channel : bool
         A flag which determines if the states should be evaluated on a per-channel basis or not.
-    output_cls : type
-        The output class for which stats values will be calculated.
+    stats_processor_cls : Iterable[type[StatsProcessor]]
+        An iterable of stats processor classes that calculate stats and return output classes.
 
     Returns
     -------
-    dict[str, NDArray]]
-        A dictionary containing the computed statistics for each image.
-        The dictionary keys correspond to the names of the statistics, and the values are NumPy arrays
-        with the results of the computations.
+    list[TStatsOutput]
+        A list of output classes corresponding to the input processor types.
 
-    Notes
-    -----
+    Note
+    ----
     - The function performs image normalization (rescaling the image values)
       before applying some of the statistics.
     - Pixel-level statistics (e.g., brightness, entropy) are computed after
@@ -189,43 +250,41 @@ def run_stats(
       be reused to avoid redundant computation.
     """
     results_list: list[dict[str, NDArray]] = []
-    output_list = list(output_cls.__annotations__)
     source_index = []
     box_count = []
-    bbox_iter = (None for _ in images) if bboxes is None else to_numpy_iter(bboxes)
-
-    for i, (boxes, image) in enumerate(zip(bbox_iter, to_numpy_iter(images))):
-        nboxes = [None] if boxes is None else normalize_box_shape(boxes)
-        for i_b, box in enumerate(nboxes):
-            i_b = None if box is None else i_b
-            processor: StatsProcessor = stats_processor_cls(image, box, per_channel)
-            if not processor.is_valid_slice:
-                warnings.warn(f"Bounding box {i_b}: {box} is out of bounds of image {i}: {image.shape}.")
-            results_list.append({stat: processor.get(stat) for stat in output_list})
-            if per_channel:
-                source_index.extend([SourceIndex(i, i_b, c) for c in range(image.shape[-3])])
-            else:
-                source_index.append(SourceIndex(i, i_b, None))
-        box_count.append(0 if boxes is None else len(boxes))
+    bbox_iter = repeat(None) if bboxes is None else to_numpy_iter(bboxes)
+
+    warning_list = []
+    total_for_status = getattr(images, "__len__")() if hasattr(images, "__len__") else None
+    stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
+
+    # TODO: Introduce global controls for CPU job parallelism and GPU configurations
+    with Pool(16) as p:
+        for r in tqdm.tqdm(
+            p.imap(
+                partial(process_stats_unpack, per_channel=per_channel, stats_processor_cls=stats_processor_cls),
+                enumerate(zip(to_numpy_iter(images), bbox_iter)),
+            ),
+            total=total_for_status,
+        ):
+            results_list.extend(r.results)
+            source_index.extend(r.source_indices)
+            box_count.extend(r.box_counts)
+            warning_list.extend(r.warnings_list)
+        p.close()
+        p.join()
+
+    # warnings are not emitted while in multiprocessing pools so we emit after gathering all warnings
+    for w in warning_list:
+        warnings.warn(f"Bounding box [{w[0]}][{w[1]}]: {w[2]} is out of bounds of {w[3]}.", UserWarning)
 
     output = {}
-    if per_channel:
-        for i, results in enumerate(results_list):
-            for stat, result in results.items():
+    for results in results_list:
+        for stat, result in results.items():
+            if per_channel:
                 output.setdefault(stat, []).extend(result.tolist())
-    else:
-        for results in results_list:
-            for stat, result in results.items():
+            else:
                 output.setdefault(stat, []).append(result.tolist() if isinstance(result, np.ndarray) else result)
 
-    for stat in output:
-        stat_type: str = output_cls.__annotations__[stat]
-
-        dtype_match = re.match(DTYPE_REGEX, stat_type)
-        if dtype_match is not None:
-            output[stat] = np.asarray(output[stat], dtype=np.dtype(dtype_match.group(1)))
-
-    output[SOURCE_INDEX] = source_index
-    output[BOX_COUNT] = np.asarray(box_count, dtype=np.uint16)
-
-    return output
+    outputs = [s.convert_output(output, source_index, box_count) for s in stats_processor_cls]
+    return outputs
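
run_stats now accepts an iterable of StatsProcessor subclasses, fans images out to a 16-worker multiprocessing pool with a tqdm progress bar, and returns one typed output per processor via convert_output. A hedged sketch of the new call shape; the concrete processor class names below are assumptions based on the stats modules listed in this diff:

import numpy as np

from dataeval._internal.metrics.stats.base import run_stats
# Hypothetical processor names, assumed from the dimensionstats and pixelstats modules.
from dataeval._internal.metrics.stats.dimensionstats import DimensionStatsProcessor
from dataeval._internal.metrics.stats.pixelstats import PixelStatsProcessor

images = np.zeros((8, 3, 32, 32))  # synthetic stand-in dataset

# One pass over the images drives every processor; the result is a list with
# one typed output per processor, in the order the classes were given.
dimension_stats, pixel_stats = run_stats(
    images,
    bboxes=None,  # or an iterable of per-image bounding boxes
    per_channel=False,
    stats_processor_cls=[DimensionStatsProcessor, PixelStatsProcessor],
)

Because the pool comes from multiprocessing, callers on spawn-based platforms would also need the usual if __name__ == "__main__": guard around the entry point.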