dataeval 0.86.0__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +48 -37
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metrics/bias/_parity.py +10 -13
  22. dataeval/metrics/estimators/_divergence.py +2 -4
  23. dataeval/metrics/stats/_base.py +103 -42
  24. dataeval/metrics/stats/_boxratiostats.py +21 -19
  25. dataeval/metrics/stats/_dimensionstats.py +14 -10
  26. dataeval/metrics/stats/_hashstats.py +1 -1
  27. dataeval/metrics/stats/_pixelstats.py +6 -6
  28. dataeval/metrics/stats/_visualstats.py +3 -3
  29. dataeval/outputs/_base.py +22 -7
  30. dataeval/outputs/_bias.py +26 -28
  31. dataeval/outputs/_drift.py +1 -9
  32. dataeval/outputs/_linters.py +11 -11
  33. dataeval/outputs/_stats.py +82 -23
  34. dataeval/outputs/_workflows.py +2 -2
  35. dataeval/utils/_array.py +6 -9
  36. dataeval/utils/_bin.py +1 -2
  37. dataeval/utils/_clusterer.py +7 -4
  38. dataeval/utils/_fast_mst.py +27 -13
  39. dataeval/utils/_image.py +65 -11
  40. dataeval/utils/_mst.py +1 -3
  41. dataeval/utils/_plot.py +15 -10
  42. dataeval/utils/data/_dataset.py +32 -20
  43. dataeval/utils/data/metadata.py +104 -82
  44. dataeval/utils/datasets/__init__.py +2 -0
  45. dataeval/utils/datasets/_antiuav.py +189 -0
  46. dataeval/utils/datasets/_base.py +11 -8
  47. dataeval/utils/datasets/_cifar10.py +104 -45
  48. dataeval/utils/datasets/_fileio.py +21 -47
  49. dataeval/utils/datasets/_milco.py +19 -11
  50. dataeval/utils/datasets/_mixin.py +2 -4
  51. dataeval/utils/datasets/_mnist.py +3 -4
  52. dataeval/utils/datasets/_ships.py +14 -7
  53. dataeval/utils/datasets/_voc.py +229 -42
  54. dataeval/utils/torch/models.py +5 -10
  55. dataeval/utils/torch/trainer.py +3 -3
  56. dataeval/workflows/sufficiency.py +2 -2
  57. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +1 -1
  58. dataeval-0.86.1.dist-info/RECORD +114 -0
  59. dataeval/detectors/ood/vae.py +0 -74
  60. dataeval-0.86.0.dist-info/RECORD +0 -114
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  62. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/detectors/drift/_nml/_result.py
@@ -42,14 +42,13 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
         """Export results to pandas dataframe."""
         if multilevel:
             return self._data
-        else:
-            column_names = [
-                "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
-                for col in self._data.columns.values
-            ]
-            single_level_data = self._data.copy(deep=True)
-            single_level_data.columns = column_names
-            return single_level_data
+        column_names = [
+            "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
+            for col in self._data.columns.values
+        ]
+        single_level_data = self._data.copy(deep=True)
+        single_level_data.columns = column_names
+        return single_level_data
 
     def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
         """Returns filtered result metric data."""
@@ -67,7 +66,7 @@ class Abstract1DResult(AbstractResult, ABC):
     def __init__(self, results_data: pd.DataFrame) -> None:
         super().__init__(results_data)
 
-    def _filter(self, period: str, metrics=None) -> Self:
+    def _filter(self, period: str, metrics: Sequence[str] | None = None) -> Self:
         data = self._data
         if period != "all":
             data = self._data.loc[self._data.loc[:, ("chunk", "period")] == period, :]  # type: ignore | dataframe loc
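
For context, a quick sketch (synthetic column names, not from the package) of the flattening that `to_df()` performs above: MultiIndex columns are joined with "_" and the repeated "chunk" prefix collapses to a single "chunk".

    import pandas as pd

    df = pd.DataFrame(
        [[0, 0.5], [1, 0.9]],
        columns=pd.MultiIndex.from_tuples([("chunk", "chunk", "index"), ("metric", "value", "upper")]),
    )
    # Same flattening expression as to_df(multilevel=False)
    df.columns = [
        "_".join(col).replace("chunk_chunk_chunk", "chunk").replace("chunk_chunk", "chunk")
        for col in df.columns.values
    ]
    print(df.columns.tolist())  # ['chunk_index', 'metric_value_upper']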
dataeval/detectors/drift/_nml/_thresholds.py
@@ -29,10 +29,10 @@ class Threshold(ABC):
     """Class registry lookup to get threshold subclass from threshold_type string"""
 
     def __str__(self) -> str:
-        return self.__str__()
+        return f"{self.__class__.__name__}({str(vars(self))})"
 
     def __repr__(self) -> str:
-        return self.__class__.__name__ + str(vars(self))
+        return str(self)
 
     def __eq__(self, other: object) -> bool:
         return isinstance(other, self.__class__) and other.__dict__ == self.__dict__
@@ -41,7 +41,7 @@ class Threshold(ABC):
         Threshold._registry[threshold_type] = cls
 
     @abstractmethod
-    def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         """Returns lower and upper threshold values when given one or more np.ndarray instances.
 
         Parameters:
@@ -69,6 +69,61 @@ class Threshold(ABC):
 
         return threshold_cls(**obj)
 
+    def calculate(
+        self,
+        data: np.ndarray,
+        lower_limit: float | None = None,
+        upper_limit: float | None = None,
+        override_using_none: bool = False,
+        logger: logging.Logger | None = None,
+    ) -> tuple[float | None, float | None]:
+        """
+        Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
+
+        Parameters
+        ----------
+        data : np.ndarray
+            The data used by the Threshold instance to calculate the lower and upper threshold values.
+            This will often be the values of a drift detection method or performance metric on chunks of
+            reference data.
+        lower_limit : float or None, default None
+            An optional value that serves as a limit for the lower threshold value. Any calculated lower
+            threshold values that end up below this limit will be replaced by this limit value.
+            The limit is often a theoretical constraint enforced by a specific drift detection method or
+            performance metric.
+        upper_limit : float or None, default None
+            An optional value that serves as a limit for the upper threshold value. Any calculated upper
+            threshold values that end up above this limit will be replaced by this limit value.
+            The limit is often a theoretical constraint enforced by a specific drift detection method or
+            performance metric.
+        override_using_none : bool, default False
+            When set to True, use None to override threshold values that exceed value limits.
+            This will prevent them from being rendered on plots.
+        logger : logging.Logger or None, default None
+            An optional Logger instance. When provided, a warning will be logged when a calculated threshold
+            value gets overridden by a threshold value limit.
+        """
+
+        lower_value, upper_value = self._thresholds(data)
+
+        if lower_limit is not None and lower_value is not None and lower_value <= lower_limit:
+            override_value = None if override_using_none else lower_limit
+            if logger:
+                logger.warning(
+                    f"lower threshold value {lower_value} overridden by lower threshold value limit {override_value}"
+                )
+            lower_value = override_value
+
+        if upper_limit is not None and upper_value is not None and upper_value >= upper_limit:
+            override_value = None if override_using_none else upper_limit
+            if logger:
+                logger.warning(
+                    f"upper threshold value {upper_value} overridden by upper threshold value limit {override_value}"
+                )
+            upper_value = override_value
+
+        return lower_value, upper_value
+
 
 class ConstantThreshold(Threshold, threshold_type="constant"):
     """A `Thresholder` implementation that returns a constant lower and or upper threshold value.
@@ -91,7 +146,7 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
     None 0.1
     """
 
-    def __init__(self, lower: float | int | None = None, upper: float | int | None = None):
+    def __init__(self, lower: float | int | None = None, upper: float | int | None = None) -> None:
         """Creates a new ConstantThreshold instance.
 
         Args:
@@ -109,11 +164,11 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
         self.lower = lower
         self.upper = upper
 
-    def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         return self.lower, self.upper
 
     @staticmethod
-    def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None):
+    def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
         if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
             raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
 
@@ -149,7 +204,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         std_lower_multiplier: float | int | None = 3,
         std_upper_multiplier: float | int | None = 3,
         offset_from: Callable[[np.ndarray], Any] = np.nanmean,
-    ):
+    ) -> None:
         """Creates a new StandardDeviationThreshold instance.
 
         Args:
@@ -173,7 +228,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         self.std_upper_multiplier = std_upper_multiplier
         self.offset_from = offset_from
 
-    def thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
+    def _thresholds(self, data: np.ndarray) -> tuple[float | None, float | None]:
         aggregate = self.offset_from(data)
         std = np.nanstd(data)
 
@@ -184,7 +239,9 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
         return lower_threshold, upper_threshold
 
     @staticmethod
-    def _validate_inputs(std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3):
+    def _validate_inputs(
+        std_lower_multiplier: float | int | None = 3, std_upper_multiplier: float | int | None = 3
+    ) -> None:
         if (
             std_lower_multiplier is not None
             and not isinstance(std_lower_multiplier, (float, int))
@@ -210,71 +267,3 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
 
         if std_upper_multiplier and std_upper_multiplier < 0:
             raise ValueError(f"'std_upper_multiplier' should be greater than 0 but got value {std_upper_multiplier}")
-
-
-def calculate_threshold_values(
-    threshold: Threshold,
-    data: np.ndarray,
-    lower_threshold_value_limit: float | None = None,
-    upper_threshold_value_limit: float | None = None,
-    override_using_none: bool = False,
-    logger: logging.Logger | None = None,
-    metric_name: str | None = None,
-) -> tuple[float | None, float | None]:
-    """Calculate lower and upper threshold values with respect to the provided Threshold and value limits.
-
-    Parameters:
-        threshold: Threshold
-            The Threshold instance that determines how the lower and upper threshold values will be calculated.
-        data: np.ndarray
-            The data used by the Threshold instance to calculate the lower and upper threshold values.
-            This will often be the values of a drift detection method or performance metric on chunks of reference data.
-        lower_threshold_value_limit: Optional[float], default=None
-            An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-            values that end up below this limit will be replaced by this limit value.
-            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-            metric.
-        upper_threshold_value_limit: Optional[float], default=None
-            An optional value that serves as a limit for the lower threshold value. Any calculated lower threshold
-            values that end up below this limit will be replaced by this limit value.
-            The limit is often a theoretical constraint enforced by a specific drift detection method or performance
-            metric.
-        override_using_none: bool, default=False
-            When set to True use None to override threshold values that exceed value limits.
-            This will prevent them from being rendered on plots.
-        logger: Optional[logging.Logger], default=None
-            An optional Logger instance. When provided a warning will be logged when a calculated threshold value
-            gets overridden by a threshold value limit.
-        metric_name: Optional[str], default=None
-            When provided the metric name will be included within any log messages for additional clarity.
-    """
-
-    lower_threshold_value, upper_threshold_value = threshold.thresholds(data)
-
-    if (
-        lower_threshold_value_limit is not None
-        and lower_threshold_value is not None
-        and lower_threshold_value <= lower_threshold_value_limit
-    ):
-        override_value = None if override_using_none else lower_threshold_value_limit
-        if logger:
-            logger.warning(
-                f"{metric_name + ' ' if metric_name else ''}lower threshold value {lower_threshold_value} "
-                f"overridden by lower threshold value limit {override_value}"
-            )
-        lower_threshold_value = override_value
-
-    if (
-        upper_threshold_value_limit is not None
-        and upper_threshold_value is not None
-        and upper_threshold_value >= upper_threshold_value_limit
-    ):
-        override_value = None if override_using_none else upper_threshold_value_limit
-        if logger:
-            logger.warning(
-                f"{metric_name + ' ' if metric_name else ''}upper threshold value {upper_threshold_value} "
-                f"overridden by upper threshold value limit {override_value}"
-            )
-        upper_threshold_value = override_value
-
-    return lower_threshold_value, upper_threshold_value
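
Taken together, these hunks fold the removed module-level calculate_threshold_values() helper into a Threshold.calculate() method (dropping the metric_name log prefix). A usage sketch, assuming only the behavior visible in the diff:

    import numpy as np

    data = np.array([0.2, 0.4, 0.6, 0.8])

    # Strategy values clamped to the theoretical [0, 1] range of an AUROC-like metric.
    threshold = StandardDeviationThreshold(std_lower_multiplier=3, std_upper_multiplier=3)
    lower, upper = threshold.calculate(data, lower_limit=0.0, upper_limit=1.0)
    print(lower, upper)  # 0.0 1.0 -- both raw values fall outside the limits

    constant = ConstantThreshold(lower=0.45)
    print(constant.calculate(data))  # (0.45, None)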
dataeval/detectors/linters/outliers.py
@@ -13,31 +13,31 @@ from dataeval.metrics.stats._imagestats import imagestats
 from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
-from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX
+from dataeval.outputs._stats import BASE_ATTRS
 from dataeval.typing import ArrayLike, Dataset
 
 
 def _get_outlier_mask(
     values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
 ) -> NDArray:
+    values = values.astype(np.float64)
     if method == "zscore":
         threshold = threshold if threshold else 3.0
         std = np.std(values)
         abs_diff = np.abs(values - np.mean(values))
         return std != 0 and (abs_diff / std) > threshold
-    elif method == "modzscore":
+    if method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
         med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
         mod_z_score = 0.6745 * abs_diff / med_abs_diff
         return mod_z_score > threshold
-    elif method == "iqr":
+    if method == "iqr":
         threshold = threshold if threshold else 1.5
         qrt = np.percentile(values, q=(25, 75), method="midpoint")
         iqr = (qrt[1] - qrt[0]) * threshold
         return (values < (qrt[0] - iqr)) | (values > (qrt[1] + iqr))
-    else:
-        raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
+    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
 
 
 class Outliers:
@@ -103,7 +103,7 @@ class Outliers:
         use_visual: bool = True,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
-    ):
+    ) -> None:
         self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
@@ -114,7 +114,7 @@ class Outliers:
     def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
         flagged_images: dict[int, dict[str, float]] = {}
         for stat, values in stats.items():
-            if stat in (SOURCE_INDEX, BOX_COUNT):
+            if stat in BASE_ATTRS:
                 continue
             if values.ndim == 1:
                 mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
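
A standalone sketch of the modzscore branch above (toy values; the real function receives per-image stat arrays). The new values.astype(np.float64) cast up front guards the arithmetic against integer inputs:

    import numpy as np

    values = np.array([1, 2, 2, 3, 2, 100], dtype=np.uint8).astype(np.float64)
    abs_diff = np.abs(values - np.median(values))
    med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
    mod_z_score = 0.6745 * abs_diff / med_abs_diff
    print(mod_z_score > 3.5)  # [False False False False False  True] -- only 100 is flagged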
dataeval/metrics/bias/_parity.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
+from collections import defaultdict
 from typing import Any
 
 import numpy as np
@@ -246,7 +247,7 @@ def parity(metadata: Metadata) -> ParityOutput:
 
     chi_scores = np.zeros(metadata.discrete_data.shape[1])
     p_values = np.zeros_like(chi_scores)
-    insufficient_data = {}
+    insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
     for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
@@ -261,26 +262,22 @@ def parity(metadata: Metadata) -> ParityOutput:
         for int_factor, int_class in zip(counts[0], counts[1]):
             if contingency_matrix[int_factor, int_class] > 0:
                 factor_category = unique_factor_values[int_factor].item()
-                if current_factor_name not in insufficient_data:
-                    insufficient_data[current_factor_name] = {}
-                if factor_category not in insufficient_data[current_factor_name]:
-                    insufficient_data[current_factor_name][factor_category] = {}
                 class_name = metadata.class_names[int_class]
                 class_count = contingency_matrix[int_factor, int_class].item()
                 insufficient_data[current_factor_name][factor_category][class_name] = class_count
 
         # This deletes rows containing only zeros,
        # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
-        rowsums = np.sum(contingency_matrix, axis=1)
-        rowmask = np.nonzero(rowsums)[0]
-        contingency_matrix = contingency_matrix[rowmask]
+        contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]
 
-        chi2, p, _, _ = chi2_contingency(contingency_matrix)
-
-        chi_scores[i] = chi2
-        p_values[i] = p
+        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]
 
     if insufficient_data:
         warnings.warn("Some factors did not meet the recommended 5 occurrences for each value-label combination.")
 
-    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names, insufficient_data)
+    return ParityOutput(
+        score=chi_scores,
+        p_value=p_values,
+        factor_names=metadata.discrete_factor_names,
+        insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
+    )
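
The defaultdict change above removes the manual key-existence checks that were deleted: nested levels now materialize on first access, and the comprehension in the return converts back to plain dicts for ParityOutput. A minimal illustration (hypothetical factor and class names):

    from collections import defaultdict

    insufficient_data = defaultdict(lambda: defaultdict(dict))
    insufficient_data["time_of_day"][0]["class_a"] = 3  # no setup needed
    print({k: dict(v) for k, v in insufficient_data.items()})
    # {'time_of_day': {0: {'class_a': 3}}}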
dataeval/metrics/estimators/_divergence.py
@@ -38,8 +38,7 @@ def divergence_mst(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     """
     mst = minimum_spanning_tree(data).toarray()
     edgelist = np.transpose(np.nonzero(mst))
-    errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
-    return errors
+    return np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
 
 
 def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
@@ -59,8 +58,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     Number of label errors when finding nearest neighbors
     """
     nn_indices = compute_neighbors(data, data)
-    errors = np.sum(np.abs(labels[nn_indices] - labels))
-    return errors
+    return np.sum(np.abs(labels[nn_indices] - labels))
 
 
 _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
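
A toy check of the MST error count refactored above, using scipy's minimum_spanning_tree as a stand-in for the package's internal helper and assuming data is a pairwise-distance matrix:

    import numpy as np
    from scipy.sparse.csgraph import minimum_spanning_tree

    data = np.array([[0.0, 1.0, 5.0], [1.0, 0.0, 4.0], [5.0, 4.0, 0.0]])
    labels = np.array([0, 0, 1])
    mst = minimum_spanning_tree(data).toarray()
    edgelist = np.transpose(np.nonzero(mst))  # MST edges: (0, 1) and (1, 2)
    print(np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]]))  # 1 cross-label edge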
dataeval/metrics/stats/_base.py
@@ -10,23 +10,86 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
 
 import numpy as np
 import tqdm
 from numpy.typing import NDArray
 
 from dataeval.config import get_max_processes
-from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
+from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
 from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
 from dataeval.utils._array import as_numpy, to_numpy
-from dataeval.utils._image import normalize_image_shape, rescale
+from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-BoundingBox = tuple[float, float, float, float]
 TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
+_S = TypeVar("_S")
+_T = TypeVar("_T")
+
+
+@dataclass
+class BoundingBox:
+    x0: float
+    y0: float
+    x1: float
+    y1: float
+
+    def __post_init__(self) -> None:
+        # Test for invalid coordinates
+        x_swap = self.x0 > self.x1
+        y_swap = self.y0 > self.y1
+        if x_swap or y_swap:
+            warnings.warn(f"Invalid bounding box coordinates: {self} - swapping invalid coordinates.")
+        if x_swap:
+            self.x0, self.x1 = self.x1, self.x0
+        if y_swap:
+            self.y0, self.y1 = self.y1, self.y0
+
+    @property
+    def width(self) -> float:
+        return self.x1 - self.x0
+
+    @property
+    def height(self) -> float:
+        return self.y1 - self.y0
+
+    def to_int(self) -> tuple[int, int, int, int]:
+        """
+        Returns the bounding box as a tuple of integers.
+        """
+        x0_int = math.floor(self.x0)
+        y0_int = math.floor(self.y0)
+        x1_int = math.ceil(self.x1)
+        y1_int = math.ceil(self.y1)
+        return x0_int, y0_int, x1_int, y1_int
+
+
+class PoolWrapper:
+    """
+    Wraps `multiprocessing.Pool` to allow for easy switching between
+    multiprocessing and single-threaded execution.
+
+    This helps with debugging and profiling, as well as usage with Jupyter notebooks
+    in VS Code, which does not support subprocess debugging.
+    """
+
+    def __init__(self, processes: int | None) -> None:
+        self.pool = Pool(processes) if processes is not None and processes > 1 else None
+
+    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
+        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
+
+    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if self.pool is not None:
+            self.pool.close()
+            self.pool.join()
+
 
 class StatsProcessor(Generic[TStatsOutput]):
     output_class: type[TStatsOutput]
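
A quick behavior sketch for the two additions above, assuming only what the hunk shows: BoundingBox swaps reversed coordinates (with a warning) and to_int() rounds outward, while PoolWrapper degrades to a plain in-process map() when processes is None or 1:

    box = BoundingBox(10.5, 4.0, 2.5, 8.0)  # x0 > x1: warns and swaps
    print(box.x0, box.x1)  # 2.5 10.5
    print(box.to_int())    # (2, 4, 11, 8) -- floor the mins, ceil the maxes

    with PoolWrapper(processes=None) as p:
        print(list(p.imap(lambda x: x * x, range(4))))  # [0, 1, 4, 9], single-threaded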
@@ -34,32 +97,26 @@ class StatsProcessor(Generic[TStatsOutput]):
     image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
     channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
 
-    def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
+    def __init__(self, image: NDArray[Any], box: BoundingBox | Iterable[Any] | None, per_channel: bool) -> None:
         self.raw = image
         self.width: int = image.shape[-1]
         self.height: int = image.shape[-2]
-        box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
-        # Clip the bounding box to image
-        x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
-        x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
-        self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
+        box = (0, 0, self.width, self.height) if box is None else box
+        self.box = box if isinstance(box, BoundingBox) else BoundingBox(*box)
         self._per_channel = per_channel
         self._image = None
         self._shape = None
         self._scaled = None
         self._cache = {}
         self._fn_map = self.channel_function_map if per_channel else self.image_function_map
-        self._is_valid_slice = box is None or bool(
-            box[0] >= 0 and box[1] >= 0 and box[2] <= image.shape[-1] and box[3] <= image.shape[-2]
-        )
+        self._is_valid_box = is_valid_box(clip_box(image, self.box.to_int()))
 
     def get(self, fn_key: str) -> NDArray[Any]:
         if fn_key in self.cache_keys:
             if fn_key not in self._cache:
                 self._cache[fn_key] = self._fn_map[fn_key](self)
             return self._cache[fn_key]
-        else:
-            return self._fn_map[fn_key](self)
+        return self._fn_map[fn_key](self)
 
     def process(self) -> dict[str, Any]:
         return {k: self._fn_map[k](self) for k in self._fn_map}
@@ -67,11 +124,7 @@ class StatsProcessor(Generic[TStatsOutput]):
     @property
     def image(self) -> NDArray[Any]:
         if self._image is None:
-            if self._is_valid_slice:
-                norm = normalize_image_shape(self.raw)
-                self._image = norm[:, self.box[1] : self.box[3], self.box[0] : self.box[2]]
-            else:
-                self._image = np.zeros((self.raw.shape[0], self.box[3] - self.box[1], self.box[2] - self.box[0]))
+            self._image = clip_and_pad(normalize_image_shape(self.raw), self.box.to_int())
         return self._image
 
     @property
@@ -90,9 +143,9 @@ class StatsProcessor(Generic[TStatsOutput]):
 
     @classmethod
     def convert_output(
-        cls, source: dict[str, Any], source_index: list[SourceIndex], box_count: list[int]
+        cls, source: dict[str, Any], source_index: list[SourceIndex], object_count: list[int], image_count: int
     ) -> TStatsOutput:
-        output = {}
+        output: dict[str, Any] = {}
         attrs = dict(ChainMap(*(getattr(c, "__annotations__", {}) for c in cls.output_class.__mro__)))
         for key in (key for key in source if key in attrs):
             stat_type: str = attrs[key]
@@ -101,14 +154,17 @@ class StatsProcessor(Generic[TStatsOutput]):
                 output[key] = np.asarray(source[key], dtype=np.dtype(dtype_match.group(1)))
             else:
                 output[key] = source[key]
-        return cls.output_class(**output, source_index=source_index, box_count=np.asarray(box_count, dtype=np.uint16))
+        base_attrs: dict[str, Any] = dict(
+            zip(BASE_ATTRS, (source_index, np.asarray(object_count, dtype=np.uint16), image_count))
+        )
+        return cls.output_class(**output, **base_attrs)
 
 
 @dataclass
 class StatsProcessorOutput:
     results: list[dict[str, Any]]
     source_indices: list[SourceIndex]
-    box_counts: list[int]
+    object_counts: list[int]
     warnings_list: list[str]
 
 
@@ -119,18 +175,18 @@ def process_stats(
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-    image = to_numpy(image)
+    np_image = to_numpy(image)
     results_list: list[dict[str, Any]] = []
     source_indices: list[SourceIndex] = []
     box_counts: list[int] = []
     warnings_list: list[str] = []
     for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
-        processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
-        if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
-            warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
+        processor_list = [p(np_image, box, per_channel) for p in stats_processor_cls]
+        if any(not p._is_valid_box for p in processor_list) and i_b is not None and box is not None:
+            warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} for image shape {np_image.shape} is invalid.")
         results_list.append({k: v for p in processor_list for k, v in p.process().items()})
         if per_channel:
-            source_indices.extend([SourceIndex(i, i_b, c) for c in range(image.shape[-3])])
+            source_indices.extend([SourceIndex(i, i_b, c) for c in range(np_image.shape[-3])])
         else:
            source_indices.append(SourceIndex(i, i_b, None))
     box_counts.append(0 if boxes is None else len(boxes))
@@ -145,13 +201,18 @@ def process_stats_unpack(
     return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
 
 
-def _enumerate(dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool):
+def _enumerate(
+    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool
+) -> Iterator[tuple[int, ArrayLike, Any]]:
     for i in range(len(dataset)):
         d = dataset[i]
         image = d[0] if isinstance(d, tuple) else d
         if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
-            boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
-            target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
+            try:
+                boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
+                target = [BoundingBox(*(float(box[i]) for i in range(4))) for box in boxes]
+            except (ValueError, IndexError):
+                raise ValueError(f"Invalid bounding box format for image {i}: {d[1].boxes}")
         else:
             target = None
 
@@ -199,12 +260,13 @@ def run_stats(
     """
     results_list: list[dict[str, NDArray[np.float64]]] = []
     source_index: list[SourceIndex] = []
-    box_count: list[int] = []
+    object_count: list[int] = []
+    image_count: int = len(dataset)
 
     warning_list = []
     stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-    with Pool(processes=get_max_processes()) as p:
+    with PoolWrapper(processes=get_max_processes()) as p:
         for r in tqdm.tqdm(
             p.imap(
                 partial(
@@ -214,14 +276,12 @@ def run_stats(
                 ),
                 _enumerate(dataset, per_box),
             ),
-            total=len(dataset),
+            total=image_count,
         ):
             results_list.extend(r.results)
             source_index.extend(r.source_indices)
-            box_count.extend(r.box_counts)
+            object_count.extend(r.object_counts)
             warning_list.extend(r.warnings_list)
-        p.close()
-        p.join()
 
     # warnings are not emitted while in multiprocessing pools so we emit after gathering all warnings
     for w in warning_list:
@@ -235,8 +295,7 @@ def run_stats(
         else:
             output.setdefault(stat, []).append(result.tolist() if isinstance(result, np.ndarray) else result)
 
-    outputs = [s.convert_output(output, source_index, box_count) for s in stats_processor_cls]
-    return outputs
+    return [s.convert_output(output, source_index, object_count, image_count) for s in stats_processor_cls]
 
 
 def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
@@ -246,10 +305,12 @@ def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
     sum_dict = deepcopy(a.data())
 
     for k in sum_dict:
-        if isinstance(sum_dict[k], list):
+        if isinstance(sum_dict[k], Sequence):
             sum_dict[k].extend(b.data()[k])
-        else:
+        elif isinstance(sum_dict[k], Array):
             sum_dict[k] = np.concatenate((sum_dict[k], b.data()[k]))
+        else:
+            sum_dict[k] += b.data()[k]
 
     return type(a)(**sum_dict)
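
For the add_stats() branching just above, a plain-dict illustration (the real code dispatches on Sequence and Array field types): list-like fields extend, array fields concatenate, and bare scalars such as the new image_count simply add:

    import numpy as np

    a = {"source_index": [0, 1], "mean": np.array([0.1, 0.2]), "image_count": 2}
    b = {"source_index": [2], "mean": np.array([0.3]), "image_count": 1}

    merged = {}
    for k in a:
        if isinstance(a[k], list):          # Sequence branch: extend
            merged[k] = a[k] + b[k]
        elif isinstance(a[k], np.ndarray):  # Array branch: concatenate
            merged[k] = np.concatenate((a[k], b[k]))
        else:                               # scalar branch: sum
            merged[k] = a[k] + b[k]
    print(merged["image_count"])  # 3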