dataeval 0.85.0__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +65 -42
  7. dataeval/data/_selection.py +2 -3
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +6 -8
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/__init__.py +4 -1
  14. dataeval/detectors/drift/_base.py +4 -5
  15. dataeval/detectors/drift/_mmd.py +3 -6
  16. dataeval/detectors/drift/_mvdc.py +92 -0
  17. dataeval/detectors/drift/_nml/__init__.py +6 -0
  18. dataeval/detectors/drift/_nml/_base.py +70 -0
  19. dataeval/detectors/drift/_nml/_chunk.py +396 -0
  20. dataeval/detectors/drift/_nml/_domainclassifier.py +181 -0
  21. dataeval/detectors/drift/_nml/_result.py +97 -0
  22. dataeval/detectors/drift/_nml/_thresholds.py +269 -0
  23. dataeval/detectors/linters/outliers.py +7 -7
  24. dataeval/metrics/bias/_parity.py +10 -13
  25. dataeval/metrics/estimators/_divergence.py +2 -4
  26. dataeval/metrics/stats/_base.py +103 -42
  27. dataeval/metrics/stats/_boxratiostats.py +21 -19
  28. dataeval/metrics/stats/_dimensionstats.py +14 -10
  29. dataeval/metrics/stats/_hashstats.py +1 -1
  30. dataeval/metrics/stats/_pixelstats.py +6 -6
  31. dataeval/metrics/stats/_visualstats.py +3 -3
  32. dataeval/outputs/__init__.py +2 -1
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +27 -31
  35. dataeval/outputs/_drift.py +60 -0
  36. dataeval/outputs/_linters.py +12 -17
  37. dataeval/outputs/_stats.py +83 -29
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +32 -20
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +19 -11
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +3 -2
  62. dataeval-0.86.1.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.85.0.dist-info/RECORD +0 -107
  65. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
@@ -2,11 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ import contextlib
5
6
  from dataclasses import dataclass
6
7
 
7
8
  import numpy as np
9
+ import pandas as pd
8
10
  from numpy.typing import NDArray
9
11
 
12
+ with contextlib.suppress(ImportError):
13
+ from matplotlib.figure import Figure
14
+
15
+ from dataeval.detectors.drift._nml._result import Metric, PerMetricResult
10
16
  from dataeval.outputs._base import Output
11
17
 
12
18
 
@@ -81,3 +87,57 @@ class DriftOutput(DriftBaseOutput):
81
87
  feature_threshold: float
82
88
  p_vals: NDArray[np.float32]
83
89
  distances: NDArray[np.float32]
90
+
91
+
92
class DriftMVDCOutput(PerMetricResult):
    """Class wrapping the results of the classifier for drift detection and providing plotting functionality."""

    def __init__(self, results_data: pd.DataFrame) -> None:
        """Initialize a DomainClassifierCalculator results object.

        Parameters
        ----------
        results_data : pd.DataFrame
            Results data returned by a DomainClassifierCalculator.
        """
        metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
        super().__init__(results_data, [metric])

    def plot(self) -> Figure:
        """
        Render the roc_auc metric over the train/test data in relation to the threshold.

        The upper and lower thresholds and the reference ("train") and analysis ("test")
        AUROC values are always drawn. Chunks whose ``alert`` flag is set are additionally
        marked with magenta diamonds.

        Returns
        -------
        matplotlib.figure.Figure
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(dpi=300)
        resdf = self.to_df()
        auroc = resdf["domain_classifier_auroc"]
        xticks = np.arange(resdf.shape[0])
        trndf = resdf[resdf["chunk"]["period"] == "reference"]
        tstdf = resdf[resdf["chunk"]["period"] == "analysis"]
        # Positional indices of alerting chunks, used to place the drift markers
        driftx = np.flatnonzero(auroc["alert"].to_numpy())

        # Thresholds and metric series do not depend on how many chunks alerted, so they
        # are drawn unconditionally rather than only when more than two chunks alerted.
        ax.plot(resdf.index, auroc["upper_threshold"], "r--", label="thr_up")
        ax.plot(resdf.index, auroc["lower_threshold"], "r--", label="thr_low")
        ax.plot(trndf.index, trndf["domain_classifier_auroc"]["value"], "b", label="train")
        ax.plot(tstdf.index, tstdf["domain_classifier_auroc"]["value"], "g", label="test")
        # Only add the drift series when something alerted, so the legend has no empty entry
        if driftx.size > 0:
            ax.plot(
                resdf.index.to_numpy()[driftx],
                auroc["value"].to_numpy()[driftx],
                "dm",
                markersize=3,
                label="drift",
            )
        ax.set_xticks(xticks)
        ax.tick_params(axis="x", labelsize=6)
        ax.tick_params(axis="y", labelsize=6)
        ax.legend(loc="lower left", fontsize=6)
        ax.set_title("Domain Classifier, Drift Detection", fontsize=8)
        ax.set_ylabel("ROC AUC", fontsize=7)
        ax.set_xlabel("Chunk Index", fontsize=7)
        ax.set_ylim((0.0, 1.1))
        return fig
@@ -2,15 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import contextlib
6
5
  from dataclasses import dataclass
7
6
  from typing import Generic, TypeVar, Union
8
7
 
8
+ import pandas as pd
9
9
  from typing_extensions import TypeAlias
10
10
 
11
- with contextlib.suppress(ImportError):
12
- import pandas as pd
13
-
14
11
  from dataeval.outputs._base import Output
15
12
  from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
16
13
 
@@ -46,10 +43,12 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
46
43
  near: list[TIndexCollection]
47
44
 
48
45
 
49
- def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOutput):
46
+ def _reorganize_by_class_and_metric(
47
+ result: IndexIssueMap, lstats: LabelStatsOutput
48
+ ) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
50
49
  """Flip result from grouping by image to grouping by class and metric"""
51
- metrics = {}
52
- class_wise = {label: {} for label in lstats.class_names}
50
+ metrics: dict[str, list[int]] = {}
51
+ class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
53
52
 
54
53
  # Group metrics and calculate class-wise counts
55
54
  for img, group in result.items():
@@ -62,7 +61,7 @@ def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOut
62
61
  return metrics, class_wise
63
62
 
64
63
 
65
- def _create_table(metrics, class_wise):
64
+ def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
66
65
  """Create table for displaying the results"""
67
66
  max_class_length = max(len(str(label)) for label in class_wise) + 2
68
67
  max_total = max(len(metrics[group]) for group in metrics) + 2
@@ -72,7 +71,7 @@ def _create_table(metrics, class_wise):
72
71
  + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
73
72
  + [f"{'Total':<{max_total}}"]
74
73
  )
75
- table_rows = []
74
+ table_rows: list[str] = []
76
75
 
77
76
  for class_cat, results in class_wise.items():
78
77
  table_value = [f"{class_cat:>{max_class_length}}"]
@@ -84,15 +83,14 @@ def _create_table(metrics, class_wise):
84
83
  table_value.append(f"{total:^{max_total}}")
85
84
  table_rows.append(" | ".join(table_value))
86
85
 
87
- table = [table_header] + table_rows
88
- return table
86
+ return [table_header] + table_rows
89
87
 
90
88
 
91
- def _create_pandas_dataframe(class_wise):
89
+ def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
92
90
  """Create data for pandas dataframe"""
93
91
  data = []
94
92
  for label, metrics_dict in class_wise.items():
95
- row = {"Class": label}
93
+ row: dict[str, str | int] = {"Class": label}
96
94
  total = sum(metrics_dict.values())
97
95
  row.update(metrics_dict) # Add metric counts
98
96
  row["Total"] = total
@@ -121,8 +119,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
121
119
  def __len__(self) -> int:
122
120
  if isinstance(self.issues, dict):
123
121
  return len(self.issues)
124
- else:
125
- return sum(len(d) for d in self.issues)
122
+ return sum(len(d) for d in self.issues)
126
123
 
127
124
  def to_table(self, labelstats: LabelStatsOutput) -> str:
128
125
  """
@@ -168,8 +165,6 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
168
165
  -----
169
166
  This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
170
167
  """
171
- import pandas as pd
172
-
173
168
  if isinstance(self.issues, dict):
174
169
  _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
175
170
  data = _create_pandas_dataframe(classwise)
@@ -2,24 +2,27 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- import contextlib
6
5
  from dataclasses import dataclass
7
- from typing import Any, Iterable, NamedTuple, Optional, Union
6
+ from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
8
7
 
9
8
  import numpy as np
9
+ import pandas as pd
10
10
  from numpy.typing import NDArray
11
11
  from typing_extensions import TypeAlias
12
12
 
13
- with contextlib.suppress(ImportError):
14
- import pandas as pd
15
-
16
13
  from dataeval.outputs._base import Output
17
14
  from dataeval.utils._plot import channel_histogram_plot, histogram_plot
18
15
 
16
+ if TYPE_CHECKING:
17
+ from matplotlib.figure import Figure
18
+
19
19
  OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
20
20
 
21
21
  SOURCE_INDEX = "source_index"
22
- BOX_COUNT = "box_count"
22
+ OBJECT_COUNT = "object_count"
23
+ IMAGE_COUNT = "image_count"
24
+
25
+ BASE_ATTRS = (SOURCE_INDEX, OBJECT_COUNT, IMAGE_COUNT)
23
26
 
24
27
 
25
28
  class SourceIndex(NamedTuple):
@@ -54,17 +57,24 @@ class BaseStatsOutput(Output):
54
57
  ----------
55
58
  source_index : List[SourceIndex]
56
59
  Mapping from statistic to source image, box and channel index
57
- box_count : NDArray[np.uint16]
60
+ object_count : NDArray[np.uint16]
61
+ The number of detected objects in each image
58
62
  """
59
63
 
60
64
  source_index: list[SourceIndex]
61
- box_count: NDArray[np.uint16]
65
+ object_count: NDArray[np.uint16]
66
+ image_count: int
62
67
 
63
68
  def __post_init__(self) -> None:
64
- length = len(self.source_index)
65
- bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
66
- if bad:
67
- raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
69
+ si_length = len(self.source_index)
70
+ mismatch = {k: len(v) for k, v in self.data().items() if k not in BASE_ATTRS and len(v) != si_length}
71
+ if mismatch:
72
+ raise ValueError(f"All values must have the same length as source_index. Bad values: {str(mismatch)}.")
73
+ oc_length = len(self.object_count)
74
+ if oc_length != self.image_count:
75
+ raise ValueError(
76
+ f"Total object counts per image does not match image count. {oc_length} != {self.image_count}."
77
+ )
68
78
 
69
79
  def get_channel_mask(
70
80
  self,
@@ -126,21 +136,64 @@ class BaseStatsOutput(Output):
126
136
 
127
137
  return max_channels, ch_mask
128
138
 
129
- def factors(self) -> dict[str, NDArray[Any]]:
139
+ def factors(
140
+ self,
141
+ filter: str | Sequence[str] | None = None, # noqa: A002
142
+ exclude_constant: bool = False,
143
+ ) -> dict[str, NDArray[Any]]:
144
+ """
145
+ Returns all 1-dimensional data as a dictionary of numpy arrays.
146
+
147
+ Parameters
148
+ ----------
149
+ filter : str, Sequence[str] or None, default None:
150
+ If provided, only returns keys that match the filter.
151
+ exclude_constant : bool, default False
152
+ If True, exclude arrays that contain only a single unique value.
153
+
154
+ Returns
155
+ -------
156
+ dict[str, NDArray[Any]]
157
+ """
158
+ filter_ = [filter] if isinstance(filter, str) else filter
130
159
  return {
131
160
  k: v
132
161
  for k, v in self.data().items()
133
- if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
162
+ if k not in BASE_ATTRS
163
+ and (filter_ is None or k in filter_)
164
+ and isinstance(v, np.ndarray)
165
+ and v.ndim == 1
166
+ and (not exclude_constant or len(np.unique(v)) > 1)
134
167
  }
135
168
 
136
169
  def plot(
137
170
  self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
138
- ) -> None:
171
+ ) -> Figure:
172
+ """
173
+ Plots the statistics as a set of histograms.
174
+
175
+ Parameters
176
+ ----------
177
+ log : bool
178
+ If True, plots the histograms on a logarithmic scale.
179
+ channel_limit : int or None
180
+ The maximum number of channels to plot. If None, all channels are plotted.
181
+ channel_index : int, Iterable[int] or None
182
+ The index or indices of the channels to plot. If None, all channels are plotted.
183
+
184
+ Returns
185
+ -------
186
+ matplotlib.Figure
187
+ """
188
+ from matplotlib.figure import Figure
189
+
139
190
  max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
191
+ factors = self.factors(exclude_constant=True)
192
+ if not factors:
193
+ return Figure()
140
194
  if max_channels == 1:
141
- histogram_plot(self.factors(), log)
142
- else:
143
- channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
195
+ return histogram_plot(factors, log)
196
+ return channel_histogram_plot(factors, log, max_channels, ch_mask)
144
197
 
145
198
 
146
199
  @dataclass(frozen=True)
@@ -150,9 +203,9 @@ class DimensionStatsOutput(BaseStatsOutput):
150
203
 
151
204
  Attributes
152
205
  ----------
153
- left : NDArray[np.int32]
206
+ offset_x : NDArray[np.int32]
154
207
  Offsets from the left edge of images in pixels
155
- top : NDArray[np.int32]
208
+ offset_y : NDArray[np.int32]
156
209
  Offsets from the top edge of images in pixels
157
210
  width : NDArray[np.uint32]
158
211
  Width of the images in pixels
@@ -163,25 +216,28 @@ class DimensionStatsOutput(BaseStatsOutput):
163
216
  size : NDArray[np.uint32]
164
217
  Size of the images in pixels
165
218
  aspect_ratio : NDArray[np.float16]
166
- :term:`ASspect Ratio<Aspect Ratio>` of the images (width/height)
219
+ :term:`Aspect Ratio<Aspect Ratio>` of the images (width/height)
167
220
  depth : NDArray[np.uint8]
168
221
  Color depth of the images in bits
169
- center : NDArray[np.uint16]
222
+ center : NDArray[np.uint32]
170
223
  Offset from center in [x,y] coordinates of the images in pixels
171
- distance : NDArray[np.float16]
224
+ distance_center : NDArray[np.float32]
172
225
  Distance in pixels from center
226
+ distance_edge : NDArray[np.uint32]
227
+ Distance in pixels from nearest edge
173
228
  """
174
229
 
175
- left: NDArray[np.int32]
176
- top: NDArray[np.int32]
230
+ offset_x: NDArray[np.int32]
231
+ offset_y: NDArray[np.int32]
177
232
  width: NDArray[np.uint32]
178
233
  height: NDArray[np.uint32]
179
234
  channels: NDArray[np.uint8]
180
235
  size: NDArray[np.uint32]
181
236
  aspect_ratio: NDArray[np.float16]
182
237
  depth: NDArray[np.uint8]
183
- center: NDArray[np.int16]
184
- distance: NDArray[np.float16]
238
+ center: NDArray[np.int32]
239
+ distance_center: NDArray[np.float32]
240
+ distance_edge: NDArray[np.uint32]
185
241
 
186
242
 
187
243
  @dataclass(frozen=True)
@@ -281,8 +337,6 @@ class LabelStatsOutput(Output):
281
337
  -------
282
338
  pd.DataFrame
283
339
  """
284
- import pandas as pd
285
-
286
340
  total_count = []
287
341
  image_count = []
288
342
  for cls in range(len(self.class_names)):
@@ -154,10 +154,10 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
154
154
  Array of parameters to recreate line of best fit
155
155
  """
156
156
 
157
def is_valid(f_new, x_new, f_old, x_old) -> bool:  # noqa: ANN001
    """Accept a candidate step only if its objective value is not NaN."""
    # ``f_new != np.nan`` is always True because NaN compares unequal to everything,
    # so NaN objectives were never rejected; ``np.isnan`` performs the intended check.
    return not np.isnan(f_new)
159
159
 
160
- def f(x):
160
+ def f(x) -> float: # noqa: ANN001
161
161
  try:
162
162
  return np.sum(np.square(p_i - f_out(n_i, x)))
163
163
  except RuntimeWarning:
dataeval/utils/_array.py CHANGED
@@ -23,7 +23,7 @@ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
23
23
  _np_dtype = TypeVar("_np_dtype", bound=np.generic)
24
24
 
25
25
 
26
- def _try_import(module_name) -> ModuleType | None:
26
+ def _try_import(module_name: str) -> ModuleType | None:
27
27
  if module_name in _MODULE_CACHE:
28
28
  return _MODULE_CACHE[module_name]
29
29
 
@@ -148,8 +148,7 @@ def ensure_embeddings(
148
148
 
149
149
  if dtype is None:
150
150
  return embeddings
151
- else:
152
- return arr
151
+ return arr
153
152
 
154
153
 
155
154
  @overload
@@ -174,10 +173,9 @@ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
174
173
  if isinstance(array, np.ndarray):
175
174
  nparr = as_numpy(array)
176
175
  return nparr.reshape((nparr.shape[0], -1))
177
- elif isinstance(array, torch.Tensor):
176
+ if isinstance(array, torch.Tensor):
178
177
  return torch.flatten(array, start_dim=1)
179
- else:
180
- raise TypeError(f"Unsupported array type {type(array)}.")
178
+ raise TypeError(f"Unsupported array type {type(array)}.")
181
179
 
182
180
 
183
181
  _TArray = TypeVar("_TArray", bound=Array)
@@ -199,7 +197,6 @@ def channels_first_to_last(array: _TArray) -> _TArray:
199
197
  """
200
198
  if isinstance(array, np.ndarray):
201
199
  return np.transpose(array, (1, 2, 0))
202
- elif isinstance(array, torch.Tensor):
200
+ if isinstance(array, torch.Tensor):
203
201
  return torch.permute(array, (1, 2, 0))
204
- else:
205
- raise TypeError(f"Unsupported array type {type(array)}.")
202
+ raise TypeError(f"Unsupported array type {type(array)}.")
dataeval/utils/_bin.py CHANGED
@@ -195,5 +195,4 @@ def bin_by_clusters(data: NDArray[np.number[Any]]) -> NDArray[np.float64]:
195
195
  if extend_bins:
196
196
  bin_edges = np.concatenate([bin_edges, extend_bins])
197
197
 
198
- bin_edges = np.sort(bin_edges)
199
- return bin_edges
198
+ return np.sort(bin_edges)
@@ -4,6 +4,7 @@ __all__ = []
4
4
 
5
5
  import warnings
6
6
  from dataclasses import dataclass
7
+ from typing import Any
7
8
 
8
9
  import numba
9
10
  import numpy as np
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
30
31
 
31
32
 
32
33
  @numba.njit(parallel=True, locals={"i": numba.types.int32})
33
- def compare_links_to_cluster_std(mst, clusters):
34
+ def compare_links_to_cluster_std(
35
+ mst: NDArray[np.float32], clusters: NDArray[np.intp]
36
+ ) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
34
37
  cluster_ids = np.unique(clusters)
35
38
  cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
36
39
 
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
79
82
  cluster_selection_epsilon = 0.0
80
83
  # cluster_selection_method = "eom"
81
84
 
82
- x = flatten(to_numpy(data))
85
+ x: NDArray[Any] = flatten(to_numpy(data))
83
86
  samples, features = x.shape # Due to flatten(), we know shape has a length of 2
84
87
  if samples < 2:
85
88
  raise ValueError(f"Data should have at least 2 samples; got {samples}")
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
125
128
  return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
126
129
 
127
130
 
128
- def sorted_union_find(index_groups):
131
+ def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
129
132
  """Merges and sorts groups of indices that share any common index"""
130
- groups = [[np.int32(x) for x in range(0)] for y in range(0)]
133
+ groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
131
134
  uniques, inverse = np.unique(index_groups, return_inverse=True)
132
135
  inverse = inverse.flatten()
133
136
  disjoint_set = ds_rank_create(uniques.size)
@@ -6,9 +6,11 @@
6
6
  __all__ = []
7
7
 
8
8
  import warnings
9
+ from typing import Any
9
10
 
10
11
  import numba
11
12
  import numpy as np
13
+ from numpy.typing import NDArray
12
14
  from sklearn.neighbors import NearestNeighbors
13
15
 
14
16
  with warnings.catch_warnings():
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
17
19
 
18
20
 
19
21
  @numba.njit()
20
def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
    """
    Merge the sets containing ``point`` and ``nbr`` using union by rank.

    ``disjoint_set`` is indexed as ``(parent, rank)``. Returns 1 when two distinct
    sets were merged and 0 when both elements already share a root.
    """
    parent = disjoint_set[0]
    rank = disjoint_set[1]

    # Roots are resolved in the original order: ``point`` first, then ``nbr``.
    root_point = ds_find(disjoint_set, point)
    root_nbr = ds_find(disjoint_set, nbr)
    if root_nbr == root_point:
        return 0

    # Hang the lower-ranked root beneath the higher-ranked one.
    if rank[root_nbr] < rank[root_point]:
        root_nbr, root_point = root_point, root_nbr
    parent[root_point] = root_nbr
    if rank[root_nbr] == rank[root_point]:
        rank[root_nbr] += 1
    return 1
34
36
 
35
37
 
36
38
  @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
37
- def _init_tree(n_neighbors, n_distance):
39
+ def _init_tree(
40
+ n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
41
+ ) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
38
42
  # Initial graph to hold tree connections
39
43
  tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
40
44
  disjoint_set = ds_rank_create(n_neighbors.size)
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
56
60
 
57
61
 
58
62
  @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
59
- def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distance):
63
+ def _update_tree_by_distance(
64
+ tree: NDArray[np.float32],
65
+ int_tree: int,
66
+ disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
67
+ n_neighbors: NDArray[np.uint32],
68
+ n_distance: NDArray[np.float32],
69
+ ) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
60
70
  cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
61
71
  sort_dist = np.argsort(n_distance)
62
72
  dist_sorted = n_distance[sort_dist]
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
80
90
 
81
91
 
82
92
  @numba.njit(locals={"i": numba.types.uint32})
83
- def _cluster_edges(tracker, last_idx, cluster_distances):
93
+ def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
84
94
  cluster_ids = np.unique(tracker)
85
- edge_points = []
95
+ edge_points: list[NDArray[np.intp]] = []
86
96
  for idx in range(cluster_ids.size):
87
97
  cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
88
98
  cluster_size = cluster_points.size
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
102
112
  return edge_points
103
113
 
104
114
 
105
def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
    """
    Find ``k`` nearest neighbors in ``dataA`` for each row of ``dataB`` by brute force.

    ``k + 1`` neighbors are queried and the first column is dropped. This presumably
    discards each point's match with itself when ``dataB`` rows also appear in
    ``dataA`` (TODO: confirm against callers).

    Returns
    -------
    tuple[NDArray[np.int32], NDArray[np.float32]]
        Neighbor indices and neighbor distances, each with ``k`` columns per query row.
    """
    searcher = NearestNeighbors(n_neighbors=k + 1, algorithm="brute")
    dist_all, idx_all = searcher.fit(dataA).kneighbors(dataB)
    keep = slice(1, k + 1)
    return np.array(idx_all[:, keep], dtype=np.int32), np.array(dist_all[:, keep], dtype=np.float32)
110
120
 
111
121
 
112
- def _calculate_cluster_neighbors(data, groups, point_array):
122
+ def _calculate_cluster_neighbors(
123
+ data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
124
+ ) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
113
125
  """Rerun nearest neighbor based on clusters"""
114
126
  cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
115
127
  cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
126
138
  return cluster_neighbors, cluster_nbr_distances
127
139
 
128
140
 
129
- def minimum_spanning_tree(data, neighbors, distances):
141
+ def minimum_spanning_tree(
142
+ data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
143
+ ) -> NDArray[np.float32]:
130
144
  # Transpose arrays to get number of samples along a row
131
145
  k_neighbors = neighbors.T.astype(np.uint32).copy()
132
146
  k_distances = distances.T.astype(np.float32).copy()
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
168
182
  return tree
169
183
 
170
184
 
171
- def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
185
+ def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
172
186
  # Have the potential to add in other distance calculations - supported calculations:
173
187
  # https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
174
188
  try:
dataeval/utils/_image.py CHANGED
@@ -12,6 +12,9 @@ from scipy.signal import convolve2d
12
12
  EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
13
13
  BIT_DEPTH = (1, 8, 12, 16, 32)
14
14
 
15
+ Box = tuple[int, int, int, int]
16
+ """Bounding box as tuple of integers in x0, y0, x1, y1 format."""
17
+
15
18
 
16
19
  @dataclass
17
20
  class BitDepth:
@@ -25,12 +28,11 @@ def get_bitdepth(image: NDArray[Any]) -> BitDepth:
25
28
  Approximates the bit depth of the image using the
26
29
  min and max pixel values.
27
30
  """
28
- pmin, pmax = np.min(image), np.max(image)
31
+ pmin, pmax = np.nanmin(image), np.nanmax(image)
29
32
  if pmin < 0:
30
33
  return BitDepth(0, pmin, pmax)
31
- else:
32
- depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
33
- return BitDepth(depth, 0, 2**depth - 1)
34
+ depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
35
+ return BitDepth(depth, 0, 2**depth - 1)
34
36
 
35
37
 
36
38
  def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
@@ -40,9 +42,8 @@ def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
40
42
  bitdepth = get_bitdepth(image)
41
43
  if bitdepth.depth == depth:
42
44
  return image
43
- else:
44
- normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
45
- return normalized * (2**depth - 1)
45
+ normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
46
+ return normalized * (2**depth - 1)
46
47
 
47
48
 
48
49
  def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
@@ -52,13 +53,12 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
52
53
  ndim = image.ndim
53
54
  if ndim == 2:
54
55
  return np.expand_dims(image, axis=0)
55
- elif ndim == 3:
56
+ if ndim == 3:
56
57
  return image
57
- elif ndim > 3:
58
+ if ndim > 3:
58
59
  # Slice all but the last 3 dimensions
59
60
  return image[(0,) * (ndim - 3)]
60
- else:
61
- raise ValueError("Images must have 2 or more dimensions.")
61
+ raise ValueError("Images must have 2 or more dimensions.")
62
62
 
63
63
 
64
64
  def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
@@ -71,3 +71,57 @@ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
71
71
  edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
72
72
  np.clip(edges, 0, 255, edges)
73
73
  return edges
74
+
75
+
76
def clip_box(image: NDArray[Any], box: Box) -> Box:
    """
    Clamp an (x0, y0, x1, y1) box to the spatial extent of ``image``.

    The last two axes of ``image`` are taken as height and width. Coordinates are
    only clamped, so the result may still describe an empty region.
    """
    height, width = image.shape[-2:]
    left, top, right, bottom = box
    return max(0, left), max(0, top), min(width, right), min(height, bottom)
84
+
85
+
86
def is_valid_box(box: Box) -> bool:
    """
    Return whether an (x0, y0, x1, y1) box has positive width and positive height.
    """
    return box[0] < box[2] and box[1] < box[3]
91
+
92
+
93
def clip_and_pad(image: NDArray[Any], box: Box) -> NDArray[Any]:
    """
    Extract a region from an image based on a bounding box, clipping to image boundaries
    and padding out-of-bounds areas with np.nan.

    Parameters
    ----------
    image : NDArray[Any]
        Input image array in format C, H, W (channels first). A 2D (H, W) array is
        treated as a single-channel image.
    box : Box
        Bounding box coordinates as (x0, y0, x1, y1) where (x0, y0) is top-left and (x1, y1) is bottom-right

    Returns
    -------
    NDArray[Any]
        The extracted float64 region with out-of-bounds areas padded with np.nan,
        shaped (channels, max(1, y1 - y0), max(1, x1 - x0)).
    """
    # Give a 2D (H, W) image a leading channel axis so the 3-index slice below is valid.
    # Without this, the output was sized for one channel but the copy raised IndexError.
    chw = image[np.newaxis] if image.ndim == 2 else image

    # Create output array filled with NaN with a minimum size of 1x1
    bw, bh = max(1, box[2] - box[0]), max(1, box[3] - box[1])
    output = np.full((chw.shape[-3], bh, bw), np.nan)

    # Source box: the part of the requested box that lies inside the image
    sbox = clip_box(chw, box)

    # Destination box: where that part lands inside the output array
    x0, y0 = sbox[0] - box[0], sbox[1] - box[1]
    x1, y1 = x0 + (sbox[2] - sbox[0]), y0 + (sbox[3] - sbox[1])

    # Copy only when the clipped region is non-empty (the box may lie fully outside the image)
    if is_valid_box(sbox):
        output[:, y0:y1, x0:x1] = chw[:, sbox[1] : sbox[3], sbox[0] : sbox[2]]

    return output