dataeval-0.86.0-py3-none-any.whl → dataeval-0.86.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +48 -37
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metrics/bias/_parity.py +10 -13
  22. dataeval/metrics/estimators/_divergence.py +2 -4
  23. dataeval/metrics/stats/_base.py +103 -42
  24. dataeval/metrics/stats/_boxratiostats.py +21 -19
  25. dataeval/metrics/stats/_dimensionstats.py +14 -10
  26. dataeval/metrics/stats/_hashstats.py +1 -1
  27. dataeval/metrics/stats/_pixelstats.py +6 -6
  28. dataeval/metrics/stats/_visualstats.py +3 -3
  29. dataeval/outputs/_base.py +22 -7
  30. dataeval/outputs/_bias.py +26 -28
  31. dataeval/outputs/_drift.py +1 -9
  32. dataeval/outputs/_linters.py +11 -11
  33. dataeval/outputs/_stats.py +82 -23
  34. dataeval/outputs/_workflows.py +2 -2
  35. dataeval/utils/_array.py +6 -9
  36. dataeval/utils/_bin.py +1 -2
  37. dataeval/utils/_clusterer.py +7 -4
  38. dataeval/utils/_fast_mst.py +27 -13
  39. dataeval/utils/_image.py +65 -11
  40. dataeval/utils/_mst.py +1 -3
  41. dataeval/utils/_plot.py +15 -10
  42. dataeval/utils/data/_dataset.py +32 -20
  43. dataeval/utils/data/metadata.py +104 -82
  44. dataeval/utils/datasets/__init__.py +2 -0
  45. dataeval/utils/datasets/_antiuav.py +189 -0
  46. dataeval/utils/datasets/_base.py +11 -8
  47. dataeval/utils/datasets/_cifar10.py +104 -45
  48. dataeval/utils/datasets/_fileio.py +21 -47
  49. dataeval/utils/datasets/_milco.py +19 -11
  50. dataeval/utils/datasets/_mixin.py +2 -4
  51. dataeval/utils/datasets/_mnist.py +3 -4
  52. dataeval/utils/datasets/_ships.py +14 -7
  53. dataeval/utils/datasets/_voc.py +229 -42
  54. dataeval/utils/torch/models.py +5 -10
  55. dataeval/utils/torch/trainer.py +3 -3
  56. dataeval/workflows/sufficiency.py +2 -2
  57. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +1 -1
  58. dataeval-0.86.1.dist-info/RECORD +114 -0
  59. dataeval/detectors/ood/vae.py +0 -74
  60. dataeval-0.86.0.dist-info/RECORD +0 -114
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  62. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
@@ -8,8 +8,9 @@ from typing import Any, Callable, Generic, TypeVar, cast
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
10
10
 
11
+ from dataeval.config import EPSILON
11
12
  from dataeval.outputs._base import set_metadata
12
- from dataeval.outputs._stats import BOX_COUNT, SOURCE_INDEX, BaseStatsOutput, DimensionStatsOutput
13
+ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput
13
14
 
14
15
  TStatOutput = TypeVar("TStatOutput", bound=BaseStatsOutput, contravariant=True)
15
16
  ArraySlice = tuple[int, int]
@@ -40,15 +41,19 @@ class BoxImageStatsOutputSlice(Generic[TStatOutput]):
40
41
  self.img = self.StatSlicer(img_stats, img_slice)
41
42
 
42
43
 
43
- RATIOSTATS_OVERRIDE_MAP: dict[type, dict[str, Callable[..., NDArray[Any]]]] = {
44
- DimensionStatsOutput: dict[str, Callable[[BoxImageStatsOutputSlice[DimensionStatsOutput]], NDArray[Any]]](
45
- {
46
- "left": lambda x: x.box["left"] / x.img["width"],
47
- "top": lambda x: x.box["top"] / x.img["height"],
48
- "channels": lambda x: x.box["channels"],
49
- "depth": lambda x: x.box["depth"],
50
- "distance": lambda x: x.box["distance"],
51
- }
44
+ RATIOSTATS_OVERRIDE_MAP: dict[str, Callable[[BoxImageStatsOutputSlice[Any]], NDArray[Any]]] = {
45
+ "offset_x": lambda x: x.box["offset_x"] / x.img["width"],
46
+ "offset_y": lambda x: x.box["offset_y"] / x.img["height"],
47
+ "channels": lambda x: x.box["channels"],
48
+ "depth": lambda x: x.box["depth"],
49
+ "distance_center": lambda x: x.box["distance_center"]
50
+ / (np.sqrt(np.square(x.img["width"]) + np.square(x.img["height"])) / 2),
51
+ "distance_edge": lambda x: x.box["distance_edge"]
52
+ / (
53
+ x.img["width"]
54
+ if np.min([np.abs(x.box["offset_x"]), np.abs((x.box["width"] + x.box["offset_x"]) - x.img["width"])])
55
+ < np.min([np.abs(x.box["offset_y"]), np.abs((x.box["height"] + x.box["offset_y"]) - x.img["height"])])
56
+ else x.img["height"]
52
57
  ),
53
58
  }
54
59
 
@@ -69,11 +74,9 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsO
69
74
 
70
75
  stats = getattr(box_stats, key)
71
76
 
72
- # Copy over stats index maps and box counts
73
- if key in (SOURCE_INDEX):
77
+ # Copy over base attributes
78
+ if key in BASE_ATTRS:
74
79
  return copy.deepcopy(stats)
75
- elif key == BOX_COUNT:
76
- return np.copy(stats)
77
80
 
78
81
  # Calculate ratios for each stat
79
82
  out_stats: np.ndarray = np.copy(stats).astype(np.float64)
@@ -84,10 +87,9 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsO
84
87
  box_j = len(box_stats) if i == len(box_map) - 1 else box_map[i + 1]
85
88
  img_j = len(img_stats) if i == len(img_map) - 1 else img_map[i + 1]
86
89
  stats = BoxImageStatsOutputSlice(box_stats, (box_i, box_j), img_stats, (img_i, img_j))
87
- out_type = type(box_stats)
88
- use_override = out_type in RATIOSTATS_OVERRIDE_MAP and key in RATIOSTATS_OVERRIDE_MAP[out_type]
90
+ use_override = key in RATIOSTATS_OVERRIDE_MAP
89
91
  with np.errstate(divide="ignore", invalid="ignore"):
90
- ratio = RATIOSTATS_OVERRIDE_MAP[out_type][key](stats) if use_override else stats.box[key] / stats.img[key]
92
+ ratio = RATIOSTATS_OVERRIDE_MAP[key](stats) if use_override else stats.box[key] / (stats.img[key] + EPSILON)
91
93
  out_stats[box_i:box_j] = ratio.reshape(-1, *out_stats[box_i].shape)
92
94
  return out_stats
93
95
 
@@ -141,8 +143,8 @@ def boxratiostats(
141
143
  output_cls = type(boxstats)
142
144
  if type(boxstats) is not type(imgstats):
143
145
  raise TypeError("Must provide stats outputs of the same type.")
144
- if boxstats.source_index[-1].image != imgstats.source_index[-1].image:
145
- raise ValueError("Stats index_map length mismatch. Check if the correct box and image stats were provided.")
146
+ if boxstats.image_count != imgstats.image_count:
147
+ raise ValueError("Stats image count length mismatch. Check if the correct box and image stats were provided.")
146
148
  if any(src_idx.box is None for src_idx in boxstats.source_index):
147
149
  raise ValueError("Input for boxstats must contain box information.")
148
150
  if any(src_idx.box is not None for src_idx in imgstats.source_index):
@@ -6,6 +6,7 @@ from typing import Any, Callable
6
6
 
7
7
  import numpy as np
8
8
 
9
+ from dataeval.config import EPSILON
9
10
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
11
  from dataeval.outputs import DimensionStatsOutput
11
12
  from dataeval.outputs._base import set_metadata
@@ -16,18 +17,21 @@ from dataeval.utils._image import get_bitdepth
16
17
  class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
17
18
  output_class: type = DimensionStatsOutput
18
19
  image_function_map: dict[str, Callable[[StatsProcessor[DimensionStatsOutput]], Any]] = {
19
- "left": lambda x: x.box[0],
20
- "top": lambda x: x.box[1],
21
- "width": lambda x: x.box[2] - x.box[0],
22
- "height": lambda x: x.box[3] - x.box[1],
20
+ "offset_x": lambda x: x.box.x0,
21
+ "offset_y": lambda x: x.box.y0,
22
+ "width": lambda x: x.box.width,
23
+ "height": lambda x: x.box.height,
23
24
  "channels": lambda x: x.shape[-3],
24
- "size": lambda x: (x.box[2] - x.box[0]) * (x.box[3] - x.box[1]),
25
- "aspect_ratio": lambda x: (x.box[2] - x.box[0]) / (x.box[3] - x.box[1]),
25
+ "size": lambda x: x.box.width * x.box.height,
26
+ "aspect_ratio": lambda x: x.box.width / (x.box.height + EPSILON),
26
27
  "depth": lambda x: get_bitdepth(x.image).depth,
27
- "center": lambda x: np.asarray([(x.box[0] + x.box[2]) / 2, (x.box[1] + x.box[3]) / 2]),
28
- "distance": lambda x: np.sqrt(
29
- np.square(((x.box[0] + x.box[2]) / 2) - (x.shape[-1] / 2))
30
- + np.square(((x.box[1] + x.box[3]) / 2) - (x.shape[-2] / 2))
28
+ "center": lambda x: np.asarray([(x.box.x0 + x.box.x1) / 2, (x.box.y0 + x.box.y1) / 2]),
29
+ "distance_center": lambda x: np.sqrt(
30
+ np.square(((x.box.x0 + x.box.x1) / 2) - (x.raw.shape[-1] / 2))
31
+ + np.square(((x.box.y0 + x.box.y1) / 2) - (x.raw.shape[-2] / 2))
32
+ ),
33
+ "distance_edge": lambda x: np.min(
34
+ [np.abs(x.box.x0), np.abs(x.box.y0), np.abs(x.box.x1 - x.raw.shape[-1]), np.abs(x.box.y1 - x.raw.shape[-2])]
31
35
  ),
32
36
  }
33
37
 
@@ -137,7 +137,7 @@ def hashstats(
137
137
 
138
138
  >>> results = hashstats(dataset)
139
139
  >>> print(results.xxhash[:5])
140
- ['66a93f556577c086', 'd8b686fb405c4105', '7ffdb4990ad44ac6', '42cd4c34c80f6006', 'c5519e36ac1f8839']
140
+ ['69b50a5f06af238c', '5a861d7a23d1afe7', '7ffdb4990ad44ac6', '4f0c366a3298ceac', 'c5519e36ac1f8839']
141
141
  >>> print(results.pchash[:5])
142
142
  ['e666999999266666', 'e666999999266666', 'e666999966666299', 'e666999999266666', '96e91656e91616e9']
143
143
  """
@@ -16,18 +16,18 @@ from dataeval.typing import ArrayLike, Dataset
16
16
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
17
17
  output_class: type = PixelStatsOutput
18
18
  image_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
19
- "mean": lambda x: np.mean(x.scaled),
20
- "std": lambda x: np.std(x.scaled),
21
- "var": lambda x: np.var(x.scaled),
19
+ "mean": lambda x: np.nanmean(x.scaled),
20
+ "std": lambda x: np.nanstd(x.scaled),
21
+ "var": lambda x: np.nanvar(x.scaled),
22
22
  "skew": lambda x: np.nan_to_num(skew(x.scaled.ravel())),
23
23
  "kurtosis": lambda x: np.nan_to_num(kurtosis(x.scaled.ravel())),
24
24
  "histogram": lambda x: np.histogram(x.scaled, 256, (0, 1))[0],
25
25
  "entropy": lambda x: entropy(x.get("histogram")),
26
26
  }
27
27
  channel_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
28
- "mean": lambda x: np.mean(x.scaled, axis=1),
29
- "std": lambda x: np.std(x.scaled, axis=1),
30
- "var": lambda x: np.var(x.scaled, axis=1),
28
+ "mean": lambda x: np.nanmean(x.scaled, axis=1),
29
+ "std": lambda x: np.nanstd(x.scaled, axis=1),
30
+ "var": lambda x: np.nanvar(x.scaled, axis=1),
31
31
  "skew": lambda x: np.nan_to_num(skew(x.scaled, axis=1)),
32
32
  "kurtosis": lambda x: np.nan_to_num(kurtosis(x.scaled, axis=1)),
33
33
  "histogram": lambda x: np.apply_along_axis(lambda y: np.histogram(y, 256, (0, 1))[0], 1, x.scaled),
@@ -24,8 +24,8 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
24
24
  else (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles")),
25
25
  "darkness": lambda x: x.get("percentiles")[-2],
26
26
  "missing": lambda x: np.count_nonzero(np.isnan(np.sum(x.image, axis=0))) / np.prod(x.shape[-2:]),
27
- "sharpness": lambda x: np.std(edge_filter(np.mean(x.image, axis=0))),
28
- "zeros": lambda x: np.count_nonzero(np.sum(x.image, axis=0) == 0) / np.prod(x.shape[-2:]),
27
+ "sharpness": lambda x: np.nanstd(edge_filter(np.mean(x.image, axis=0))),
28
+ "zeros": lambda x: np.count_nonzero(np.nansum(x.image, axis=0) == 0) / np.prod(x.shape[-2:]),
29
29
  "percentiles": lambda x: np.nanpercentile(x.scaled, q=QUARTILES),
30
30
  }
31
31
  channel_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
@@ -36,7 +36,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
36
36
  ),
37
37
  "darkness": lambda x: x.get("percentiles")[:, -2],
38
38
  "missing": lambda x: np.count_nonzero(np.isnan(x.image), axis=(1, 2)) / np.prod(x.shape[-2:]),
39
- "sharpness": lambda x: np.std(np.vectorize(edge_filter, signature="(m,n)->(m,n)")(x.image), axis=(1, 2)),
39
+ "sharpness": lambda x: np.nanstd(np.vectorize(edge_filter, signature="(m,n)->(m,n)")(x.image), axis=(1, 2)),
40
40
  "zeros": lambda x: np.count_nonzero(x.image == 0, axis=(1, 2)) / np.prod(x.shape[-2:]),
41
41
  "percentiles": lambda x: np.nanpercentile(x.scaled, q=QUARTILES, axis=1).T,
42
42
  }
dataeval/outputs/_base.py CHANGED
@@ -66,25 +66,40 @@ class GenericOutput(Generic[T]):
66
66
  def meta(self) -> ExecutionMetadata:
67
67
  """
68
68
  Metadata about the execution of the function or method for the Output class.
69
+
70
+ Returns
71
+ -------
72
+ ExecutionMetadata
69
73
  """
70
74
  return self._meta or ExecutionMetadata.empty()
71
75
 
72
76
 
73
77
  class Output(GenericOutput[dict[str, Any]]):
74
78
  def data(self) -> dict[str, Any]:
75
- return {k: v for k, v in self.__dict__.items() if k != "_meta"}
79
+ """
80
+ The output data as a dictionary.
76
81
 
77
- def __repr__(self) -> str:
78
- return str(self)
82
+ Returns
83
+ -------
84
+ dict[str, Any]
85
+ """
86
+ return {k: v for k, v in self.__dict__.items() if k != "_meta"}
79
87
 
80
88
  def __str__(self) -> str:
81
- return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
89
+ return str(self.data())
82
90
 
83
91
 
84
92
  class BaseCollectionMixin(Collection[Any]):
85
93
  __slots__ = ["_data"]
86
94
 
87
95
  def data(self) -> Any:
96
+ """
97
+ The output data as a collection.
98
+
99
+ Returns
100
+ -------
101
+ Collection
102
+ """
88
103
  return self._data
89
104
 
90
105
  def __len__(self) -> int:
@@ -102,7 +117,7 @@ TValue = TypeVar("TValue")
102
117
 
103
118
 
104
119
  class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
105
- def __init__(self, data: Mapping[TKey, TValue]):
120
+ def __init__(self, data: Mapping[TKey, TValue]) -> None:
106
121
  self._data = data
107
122
 
108
123
  def __getitem__(self, key: TKey) -> TValue:
@@ -113,7 +128,7 @@ class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Ma
113
128
 
114
129
 
115
130
  class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
116
- def __init__(self, data: Sequence[TValue]):
131
+ def __init__(self, data: Sequence[TValue]) -> None:
117
132
  self._data = data
118
133
 
119
134
  @overload
@@ -140,7 +155,7 @@ def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None =
140
155
 
141
156
  @wraps(fn)
142
157
  def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
143
- def fmt(v):
158
+ def fmt(v: Any) -> Any:
144
159
  if np.isscalar(v):
145
160
  return v
146
161
  if hasattr(v, "shape"):
dataeval/outputs/_bias.py CHANGED
@@ -128,33 +128,30 @@ class CoverageOutput(Output):
128
128
 
129
129
  import matplotlib.pyplot as plt
130
130
 
131
+ images = Images(images) if isinstance(images, Dataset) else images
132
+ if np.max(self.uncovered_indices) > len(images):
133
+ raise ValueError(
134
+ f"Uncovered indices {self.uncovered_indices} specify images "
135
+ f"unavailable in the provided number of images {len(images)}."
136
+ )
137
+
131
138
  # Determine which images to plot
132
139
  selected_indices = self.uncovered_indices[:top_k]
133
140
 
134
- images = Images(images) if isinstance(images, Dataset) else images
135
-
136
141
  # Plot the images
137
142
  num_images = min(top_k, len(selected_indices))
138
143
 
139
144
  rows = int(np.ceil(num_images / 3))
140
- fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
141
-
142
- if rows == 1:
143
- for j in range(3):
144
- if j >= len(selected_indices):
145
- continue
146
- image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
147
- axs[j].imshow(image)
148
- axs[j].axis("off")
149
- else:
150
- for i in range(rows):
151
- for j in range(3):
152
- i_j = i * 3 + j
153
- if i_j >= len(selected_indices):
154
- continue
155
- image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
156
- axs[i, j].imshow(image)
157
- axs[i, j].axis("off")
145
+ cols = min(3, num_images)
146
+ fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
147
+
148
+ for image, ax in zip(images[:num_images], axs.flat):
149
+ image = channels_first_to_last(as_numpy(image))
150
+ ax.imshow(image)
151
+ ax.axis("off")
152
+
153
+ for ax in axs.flat[num_images:]:
154
+ ax.axis("off")
158
155
 
159
156
  fig.tight_layout()
160
157
  return fig
@@ -233,14 +230,15 @@ class BalanceOutput(Output):
233
230
  # return the masked attribute
234
231
  if attr == "factor_names":
235
232
  return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
236
- else:
237
- factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
238
- if attr == "factors":
239
- return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
240
- elif attr == "balance":
241
- return self.balance[factor_type_mask]
242
- elif attr == "classwise":
243
- return self.classwise[:, factor_type_mask]
233
+ factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
234
+ if attr == "factors":
235
+ return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
236
+ if attr == "balance":
237
+ return self.balance[factor_type_mask]
238
+ if attr == "classwise":
239
+ return self.classwise[:, factor_type_mask]
240
+
241
+ raise ValueError(f"Unknown attr {attr} specified.")
244
242
 
245
243
  def plot(
246
244
  self,
@@ -103,19 +103,13 @@ class DriftMVDCOutput(PerMetricResult):
103
103
  metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
104
104
  super().__init__(results_data, [metric])
105
105
 
106
- def plot(self, showme: bool = True) -> Figure:
106
+ def plot(self) -> Figure:
107
107
  """
108
108
  Render the roc_auc metric over the train/test data in relation to the threshold.
109
109
 
110
- Parameters
111
- ----------
112
- showme : bool, default True
113
- Option to display the figure.
114
-
115
110
  Returns
116
111
  -------
117
112
  matplotlib.figure.Figure
118
-
119
113
  """
120
114
  import matplotlib.pyplot as plt
121
115
 
@@ -146,6 +140,4 @@ class DriftMVDCOutput(PerMetricResult):
146
140
  ax.set_ylabel("ROC AUC", fontsize=7)
147
141
  ax.set_xlabel("Chunk Index", fontsize=7)
148
142
  ax.set_ylim((0.0, 1.1))
149
- if showme:
150
- plt.show()
151
143
  return fig
@@ -43,10 +43,12 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
43
43
  near: list[TIndexCollection]
44
44
 
45
45
 
46
- def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOutput):
46
+ def _reorganize_by_class_and_metric(
47
+ result: IndexIssueMap, lstats: LabelStatsOutput
48
+ ) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
47
49
  """Flip result from grouping by image to grouping by class and metric"""
48
- metrics = {}
49
- class_wise = {label: {} for label in lstats.class_names}
50
+ metrics: dict[str, list[int]] = {}
51
+ class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
50
52
 
51
53
  # Group metrics and calculate class-wise counts
52
54
  for img, group in result.items():
@@ -59,7 +61,7 @@ def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOut
59
61
  return metrics, class_wise
60
62
 
61
63
 
62
- def _create_table(metrics, class_wise):
64
+ def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
63
65
  """Create table for displaying the results"""
64
66
  max_class_length = max(len(str(label)) for label in class_wise) + 2
65
67
  max_total = max(len(metrics[group]) for group in metrics) + 2
@@ -69,7 +71,7 @@ def _create_table(metrics, class_wise):
69
71
  + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
70
72
  + [f"{'Total':<{max_total}}"]
71
73
  )
72
- table_rows = []
74
+ table_rows: list[str] = []
73
75
 
74
76
  for class_cat, results in class_wise.items():
75
77
  table_value = [f"{class_cat:>{max_class_length}}"]
@@ -81,15 +83,14 @@ def _create_table(metrics, class_wise):
81
83
  table_value.append(f"{total:^{max_total}}")
82
84
  table_rows.append(" | ".join(table_value))
83
85
 
84
- table = [table_header] + table_rows
85
- return table
86
+ return [table_header] + table_rows
86
87
 
87
88
 
88
- def _create_pandas_dataframe(class_wise):
89
+ def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
89
90
  """Create data for pandas dataframe"""
90
91
  data = []
91
92
  for label, metrics_dict in class_wise.items():
92
- row = {"Class": label}
93
+ row: dict[str, str | int] = {"Class": label}
93
94
  total = sum(metrics_dict.values())
94
95
  row.update(metrics_dict) # Add metric counts
95
96
  row["Total"] = total
@@ -118,8 +119,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
118
119
  def __len__(self) -> int:
119
120
  if isinstance(self.issues, dict):
120
121
  return len(self.issues)
121
- else:
122
- return sum(len(d) for d in self.issues)
122
+ return sum(len(d) for d in self.issues)
123
123
 
124
124
  def to_table(self, labelstats: LabelStatsOutput) -> str:
125
125
  """
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from dataclasses import dataclass
6
- from typing import Any, Iterable, NamedTuple, Optional, Union
6
+ from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
7
7
 
8
8
  import numpy as np
9
9
  import pandas as pd
@@ -13,10 +13,16 @@ from typing_extensions import TypeAlias
13
13
  from dataeval.outputs._base import Output
14
14
  from dataeval.utils._plot import channel_histogram_plot, histogram_plot
15
15
 
16
+ if TYPE_CHECKING:
17
+ from matplotlib.figure import Figure
18
+
16
19
  OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
17
20
 
18
21
  SOURCE_INDEX = "source_index"
19
- BOX_COUNT = "box_count"
22
+ OBJECT_COUNT = "object_count"
23
+ IMAGE_COUNT = "image_count"
24
+
25
+ BASE_ATTRS = (SOURCE_INDEX, OBJECT_COUNT, IMAGE_COUNT)
20
26
 
21
27
 
22
28
  class SourceIndex(NamedTuple):
@@ -51,17 +57,24 @@ class BaseStatsOutput(Output):
51
57
  ----------
52
58
  source_index : List[SourceIndex]
53
59
  Mapping from statistic to source image, box and channel index
54
- box_count : NDArray[np.uint16]
60
+ object_count : NDArray[np.uint16]
61
+ The number of detected objects in each image
55
62
  """
56
63
 
57
64
  source_index: list[SourceIndex]
58
- box_count: NDArray[np.uint16]
65
+ object_count: NDArray[np.uint16]
66
+ image_count: int
59
67
 
60
68
  def __post_init__(self) -> None:
61
- length = len(self.source_index)
62
- bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
63
- if bad:
64
- raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
69
+ si_length = len(self.source_index)
70
+ mismatch = {k: len(v) for k, v in self.data().items() if k not in BASE_ATTRS and len(v) != si_length}
71
+ if mismatch:
72
+ raise ValueError(f"All values must have the same length as source_index. Bad values: {str(mismatch)}.")
73
+ oc_length = len(self.object_count)
74
+ if oc_length != self.image_count:
75
+ raise ValueError(
76
+ f"Total object counts per image does not match image count. {oc_length} != {self.image_count}."
77
+ )
65
78
 
66
79
  def get_channel_mask(
67
80
  self,
@@ -123,21 +136,64 @@ class BaseStatsOutput(Output):
123
136
 
124
137
  return max_channels, ch_mask
125
138
 
126
- def factors(self) -> dict[str, NDArray[Any]]:
139
+ def factors(
140
+ self,
141
+ filter: str | Sequence[str] | None = None, # noqa: A002
142
+ exclude_constant: bool = False,
143
+ ) -> dict[str, NDArray[Any]]:
144
+ """
145
+ Returns all 1-dimensional data as a dictionary of numpy arrays.
146
+
147
+ Parameters
148
+ ----------
149
+ filter : str, Sequence[str] or None, default None:
150
+ If provided, only returns keys that match the filter.
151
+ exclude_constant : bool, default False
152
+ If True, exclude arrays that contain only a single unique value.
153
+
154
+ Returns
155
+ -------
156
+ dict[str, NDArray[Any]]
157
+ """
158
+ filter_ = [filter] if isinstance(filter, str) else filter
127
159
  return {
128
160
  k: v
129
161
  for k, v in self.data().items()
130
- if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
162
+ if k not in BASE_ATTRS
163
+ and (filter_ is None or k in filter_)
164
+ and isinstance(v, np.ndarray)
165
+ and v.ndim == 1
166
+ and (not exclude_constant or len(np.unique(v)) > 1)
131
167
  }
132
168
 
133
169
  def plot(
134
170
  self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
135
- ) -> None:
171
+ ) -> Figure:
172
+ """
173
+ Plots the statistics as a set of histograms.
174
+
175
+ Parameters
176
+ ----------
177
+ log : bool
178
+ If True, plots the histograms on a logarithmic scale.
179
+ channel_limit : int or None
180
+ The maximum number of channels to plot. If None, all channels are plotted.
181
+ channel_index : int, Iterable[int] or None
182
+ The index or indices of the channels to plot. If None, all channels are plotted.
183
+
184
+ Returns
185
+ -------
186
+ matplotlib.Figure
187
+ """
188
+ from matplotlib.figure import Figure
189
+
136
190
  max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
191
+ factors = self.factors(exclude_constant=True)
192
+ if not factors:
193
+ return Figure()
137
194
  if max_channels == 1:
138
- histogram_plot(self.factors(), log)
139
- else:
140
- channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
195
+ return histogram_plot(factors, log)
196
+ return channel_histogram_plot(factors, log, max_channels, ch_mask)
141
197
 
142
198
 
143
199
  @dataclass(frozen=True)
@@ -147,9 +203,9 @@ class DimensionStatsOutput(BaseStatsOutput):
147
203
 
148
204
  Attributes
149
205
  ----------
150
- left : NDArray[np.int32]
206
+ offset_x : NDArray[np.int32]
151
207
  Offsets from the left edge of images in pixels
152
- top : NDArray[np.int32]
208
+ offset_y : NDArray[np.int32]
153
209
  Offsets from the top edge of images in pixels
154
210
  width : NDArray[np.uint32]
155
211
  Width of the images in pixels
@@ -160,25 +216,28 @@ class DimensionStatsOutput(BaseStatsOutput):
160
216
  size : NDArray[np.uint32]
161
217
  Size of the images in pixels
162
218
  aspect_ratio : NDArray[np.float16]
163
- :term:`ASspect Ratio<Aspect Ratio>` of the images (width/height)
219
+ :term:`Aspect Ratio<Aspect Ratio>` of the images (width/height)
164
220
  depth : NDArray[np.uint8]
165
221
  Color depth of the images in bits
166
- center : NDArray[np.uint16]
222
+ center : NDArray[np.uint32]
167
223
  Offset from center in [x,y] coordinates of the images in pixels
168
- distance : NDArray[np.float16]
224
+ distance_center : NDArray[np.float32]
169
225
  Distance in pixels from center
226
+ distance_edge : NDArray[np.uint32]
227
+ Distance in pixels from nearest edge
170
228
  """
171
229
 
172
- left: NDArray[np.int32]
173
- top: NDArray[np.int32]
230
+ offset_x: NDArray[np.int32]
231
+ offset_y: NDArray[np.int32]
174
232
  width: NDArray[np.uint32]
175
233
  height: NDArray[np.uint32]
176
234
  channels: NDArray[np.uint8]
177
235
  size: NDArray[np.uint32]
178
236
  aspect_ratio: NDArray[np.float16]
179
237
  depth: NDArray[np.uint8]
180
- center: NDArray[np.int16]
181
- distance: NDArray[np.float16]
238
+ center: NDArray[np.int32]
239
+ distance_center: NDArray[np.float32]
240
+ distance_edge: NDArray[np.uint32]
182
241
 
183
242
 
184
243
  @dataclass(frozen=True)
@@ -154,10 +154,10 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
154
154
  Array of parameters to recreate line of best fit
155
155
  """
156
156
 
157
- def is_valid(f_new, x_new, f_old, x_old):
157
+ def is_valid(f_new, x_new, f_old, x_old) -> bool: # noqa: ANN001
158
158
  return f_new != np.nan
159
159
 
160
- def f(x):
160
+ def f(x) -> float: # noqa: ANN001
161
161
  try:
162
162
  return np.sum(np.square(p_i - f_out(n_i, x)))
163
163
  except RuntimeWarning:
dataeval/utils/_array.py CHANGED
@@ -23,7 +23,7 @@ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
23
23
  _np_dtype = TypeVar("_np_dtype", bound=np.generic)
24
24
 
25
25
 
26
- def _try_import(module_name) -> ModuleType | None:
26
+ def _try_import(module_name: str) -> ModuleType | None:
27
27
  if module_name in _MODULE_CACHE:
28
28
  return _MODULE_CACHE[module_name]
29
29
 
@@ -148,8 +148,7 @@ def ensure_embeddings(
148
148
 
149
149
  if dtype is None:
150
150
  return embeddings
151
- else:
152
- return arr
151
+ return arr
153
152
 
154
153
 
155
154
  @overload
@@ -174,10 +173,9 @@ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
174
173
  if isinstance(array, np.ndarray):
175
174
  nparr = as_numpy(array)
176
175
  return nparr.reshape((nparr.shape[0], -1))
177
- elif isinstance(array, torch.Tensor):
176
+ if isinstance(array, torch.Tensor):
178
177
  return torch.flatten(array, start_dim=1)
179
- else:
180
- raise TypeError(f"Unsupported array type {type(array)}.")
178
+ raise TypeError(f"Unsupported array type {type(array)}.")
181
179
 
182
180
 
183
181
  _TArray = TypeVar("_TArray", bound=Array)
@@ -199,7 +197,6 @@ def channels_first_to_last(array: _TArray) -> _TArray:
199
197
  """
200
198
  if isinstance(array, np.ndarray):
201
199
  return np.transpose(array, (1, 2, 0))
202
- elif isinstance(array, torch.Tensor):
200
+ if isinstance(array, torch.Tensor):
203
201
  return torch.permute(array, (1, 2, 0))
204
- else:
205
- raise TypeError(f"Unsupported array type {type(array)}.")
202
+ raise TypeError(f"Unsupported array type {type(array)}.")
dataeval/utils/_bin.py CHANGED
@@ -195,5 +195,4 @@ def bin_by_clusters(data: NDArray[np.number[Any]]) -> NDArray[np.float64]:
195
195
  if extend_bins:
196
196
  bin_edges = np.concatenate([bin_edges, extend_bins])
197
197
 
198
- bin_edges = np.sort(bin_edges)
199
- return bin_edges
198
+ return np.sort(bin_edges)