dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +188 -178
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +4 -5
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metadata/_distance.py +10 -7
  22. dataeval/metadata/_ood.py +11 -103
  23. dataeval/metrics/bias/_balance.py +23 -33
  24. dataeval/metrics/bias/_diversity.py +16 -14
  25. dataeval/metrics/bias/_parity.py +18 -18
  26. dataeval/metrics/estimators/_divergence.py +2 -4
  27. dataeval/metrics/stats/_base.py +103 -42
  28. dataeval/metrics/stats/_boxratiostats.py +21 -19
  29. dataeval/metrics/stats/_dimensionstats.py +14 -10
  30. dataeval/metrics/stats/_hashstats.py +1 -1
  31. dataeval/metrics/stats/_pixelstats.py +6 -6
  32. dataeval/metrics/stats/_visualstats.py +3 -3
  33. dataeval/outputs/_base.py +22 -7
  34. dataeval/outputs/_bias.py +24 -70
  35. dataeval/outputs/_drift.py +1 -9
  36. dataeval/outputs/_linters.py +11 -11
  37. dataeval/outputs/_stats.py +82 -23
  38. dataeval/outputs/_workflows.py +2 -2
  39. dataeval/utils/_array.py +6 -9
  40. dataeval/utils/_bin.py +1 -2
  41. dataeval/utils/_clusterer.py +7 -4
  42. dataeval/utils/_fast_mst.py +27 -13
  43. dataeval/utils/_image.py +65 -11
  44. dataeval/utils/_mst.py +1 -3
  45. dataeval/utils/_plot.py +15 -10
  46. dataeval/utils/data/_dataset.py +54 -28
  47. dataeval/utils/data/metadata.py +104 -82
  48. dataeval/utils/datasets/__init__.py +2 -0
  49. dataeval/utils/datasets/_antiuav.py +189 -0
  50. dataeval/utils/datasets/_base.py +11 -8
  51. dataeval/utils/datasets/_cifar10.py +104 -45
  52. dataeval/utils/datasets/_fileio.py +21 -47
  53. dataeval/utils/datasets/_milco.py +22 -12
  54. dataeval/utils/datasets/_mixin.py +2 -4
  55. dataeval/utils/datasets/_mnist.py +3 -4
  56. dataeval/utils/datasets/_ships.py +14 -7
  57. dataeval/utils/datasets/_voc.py +229 -42
  58. dataeval/utils/torch/models.py +5 -10
  59. dataeval/utils/torch/trainer.py +3 -3
  60. dataeval/workflows/sufficiency.py +2 -2
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
  62. dataeval-0.86.2.dist-info/RECORD +114 -0
  63. dataeval/detectors/ood/vae.py +0 -74
  64. dataeval-0.86.0.dist-info/RECORD +0 -114
  65. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
  66. {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/outputs/_bias.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import asdict, dataclass
-from typing import Any, Literal, TypeVar, overload
+from typing import Any, TypeVar
 
 import numpy as np
 import pandas as pd
@@ -128,33 +128,30 @@ class CoverageOutput(Output):
 
         import matplotlib.pyplot as plt
 
+        images = Images(images) if isinstance(images, Dataset) else images
+        if np.max(self.uncovered_indices) > len(images):
+            raise ValueError(
+                f"Uncovered indices {self.uncovered_indices} specify images "
+                f"unavailable in the provided number of images {len(images)}."
+            )
+
         # Determine which images to plot
         selected_indices = self.uncovered_indices[:top_k]
 
-        images = Images(images) if isinstance(images, Dataset) else images
-
         # Plot the images
         num_images = min(top_k, len(selected_indices))
 
         rows = int(np.ceil(num_images / 3))
-        fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
-
-        if rows == 1:
-            for j in range(3):
-                if j >= len(selected_indices):
-                    continue
-                image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
-                axs[j].imshow(image)
-                axs[j].axis("off")
-        else:
-            for i in range(rows):
-                for j in range(3):
-                    i_j = i * 3 + j
-                    if i_j >= len(selected_indices):
-                        continue
-                    image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
-                    axs[i, j].imshow(image)
-                    axs[i, j].axis("off")
+        cols = min(3, num_images)
+        fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
+
+        for image, ax in zip(images[:num_images], axs.flat):
+            image = channels_first_to_last(as_numpy(image))
+            ax.imshow(image)
+            ax.axis("off")
+
+        for ax in axs.flat[num_images:]:
+            ax.axis("off")
 
         fig.tight_layout()
         return fig
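Note: the rewritten grid logic above sizes the figure from `cols = min(3, num_images)` and iterates `axs.flat`, so the former special-cased single-row branch is gone. A minimal usage sketch (the `coverage` call, `embeddings`, and `dataset` names are illustrative, not part of this diff):

    from dataeval.metrics.bias import coverage

    result = coverage(embeddings)        # CoverageOutput with uncovered_indices
    fig = result.plot(dataset, top_k=6)  # 6 images -> 2 rows x 3 cols, figsize (9, 6)
    fig.savefig("uncovered_images.png")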
@@ -202,52 +199,11 @@ class BalanceOutput(Output):
     factor_names: list[str]
     class_names: list[str]
 
-    @overload
-    def _by_factor_type(
-        self,
-        attr: Literal["factor_names"],
-        factor_type: Literal["discrete", "continuous", "both"],
-    ) -> list[str]: ...
-
-    @overload
-    def _by_factor_type(
-        self,
-        attr: Literal["balance", "factors", "classwise"],
-        factor_type: Literal["discrete", "continuous", "both"],
-    ) -> NDArray[np.float64]: ...
-
-    def _by_factor_type(
-        self,
-        attr: Literal["balance", "factors", "classwise", "factor_names"],
-        factor_type: Literal["discrete", "continuous", "both"],
-    ) -> NDArray[np.float64] | list[str]:
-        # if not filtering by factor_type then just return the requested attribute without mask
-        if factor_type == "both":
-            return getattr(self, attr)
-
-        # create the mask for the selected factor_type
-        mask_lambda = (
-            (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
-        )
-
-        # return the masked attribute
-        if attr == "factor_names":
-            return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
-        else:
-            factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
-            if attr == "factors":
-                return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
-            elif attr == "balance":
-                return self.balance[factor_type_mask]
-            elif attr == "classwise":
-                return self.classwise[:, factor_type_mask]
-
     def plot(
         self,
         row_labels: list[Any] | NDArray[Any] | None = None,
         col_labels: list[Any] | NDArray[Any] | None = None,
         plot_classwise: bool = False,
-        factor_type: Literal["discrete", "continuous", "both"] = "discrete",
     ) -> Figure:
         """
         Plot a heatmap of balance information.
@@ -260,8 +216,6 @@ class BalanceOutput(Output):
             List/Array containing the labels for columns in the histogram
         plot_classwise : bool, default False
             Whether to plot per-class balance instead of global balance
-        factor_type : "discrete", "continuous", or "both", default "discrete"
-            Whether to plot discretized values, continuous values, or to include both
 
         Returns
         -------
@@ -275,10 +229,10 @@ class BalanceOutput(Output):
             if row_labels is None:
                 row_labels = self.class_names
             if col_labels is None:
-                col_labels = self._by_factor_type("factor_names", factor_type)
+                col_labels = self.factor_names
 
             fig = heatmap(
-                self._by_factor_type("classwise", factor_type),
+                self.classwise,
                 row_labels,
                 col_labels,
                 xlabel="Factors",
@@ -289,8 +243,8 @@ class BalanceOutput(Output):
             # Combine balance and factors results
             data = np.concatenate(
                 [
-                    self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
-                    self._by_factor_type("factors", factor_type),
+                    self.balance[np.newaxis, 1:],
+                    self.factors,
                 ],
                 axis=0,
             )
@@ -299,7 +253,7 @@ class BalanceOutput(Output):
             # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
             heat_data = np.where(mask, np.nan, data)[:-1]
             # Creating label array for heat map axes
-            heat_labels = self._by_factor_type("factor_names", factor_type)
+            heat_labels = self.factor_names
 
             if row_labels is None:
                 row_labels = heat_labels[:-1]
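Note: with `_by_factor_type` and the `factor_type` parameter removed, `BalanceOutput.plot` always renders all factors. A hedged migration sketch (assumes `bal` is a BalanceOutput):

    # 0.86.0:
    #   fig = bal.plot(plot_classwise=True, factor_type="continuous")
    # 0.86.2 -- factor_type no longer exists; all factors are plotted:
    fig = bal.plot(plot_classwise=True)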
@@ -379,7 +333,7 @@ class DiversityOutput(Output):
         import matplotlib.pyplot as plt
 
         fig, ax = plt.subplots(figsize=(8, 8))
-        heat_labels = np.concatenate((["class"], self.factor_names))
+        heat_labels = ["class_labels"] + self.factor_names
         ax.bar(heat_labels, self.diversity_index)
         ax.set_xlabel("Factors")
         plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
dataeval/outputs/_drift.py CHANGED
@@ -103,19 +103,13 @@ class DriftMVDCOutput(PerMetricResult):
         metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
         super().__init__(results_data, [metric])
 
-    def plot(self, showme: bool = True) -> Figure:
+    def plot(self) -> Figure:
         """
         Render the roc_auc metric over the train/test data in relation to the threshold.
 
-        Parameters
-        ----------
-        showme : bool, default True
-            Option to display the figure.
-
         Returns
         -------
         matplotlib.figure.Figure
-
         """
         import matplotlib.pyplot as plt
 
@@ -146,6 +140,4 @@ class DriftMVDCOutput(PerMetricResult):
         ax.set_ylabel("ROC AUC", fontsize=7)
         ax.set_xlabel("Chunk Index", fontsize=7)
         ax.set_ylim((0.0, 1.1))
-        if showme:
-            plt.show()
         return fig
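Note: `DriftMVDCOutput.plot` no longer displays the figure itself; showing it is now the caller's responsibility. A minimal sketch (assumes `result` is a DriftMVDCOutput):

    import matplotlib.pyplot as plt

    fig = result.plot()  # returns the Figure; no implicit plt.show() in 0.86.2
    plt.show()           # display explicitly when running interactively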
dataeval/outputs/_linters.py CHANGED
@@ -43,10 +43,12 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
     near: list[TIndexCollection]
 
 
-def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOutput):
+def _reorganize_by_class_and_metric(
+    result: IndexIssueMap, lstats: LabelStatsOutput
+) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
     """Flip result from grouping by image to grouping by class and metric"""
-    metrics = {}
-    class_wise = {label: {} for label in lstats.class_names}
+    metrics: dict[str, list[int]] = {}
+    class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
 
     # Group metrics and calculate class-wise counts
     for img, group in result.items():
@@ -59,7 +61,7 @@ def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOut
     return metrics, class_wise
 
 
-def _create_table(metrics, class_wise):
+def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
     """Create table for displaying the results"""
     max_class_length = max(len(str(label)) for label in class_wise) + 2
     max_total = max(len(metrics[group]) for group in metrics) + 2
@@ -69,7 +71,7 @@ def _create_table(metrics, class_wise):
         + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
         + [f"{'Total':<{max_total}}"]
     )
-    table_rows = []
+    table_rows: list[str] = []
 
     for class_cat, results in class_wise.items():
         table_value = [f"{class_cat:>{max_class_length}}"]
@@ -81,15 +83,14 @@ def _create_table(metrics, class_wise):
         table_value.append(f"{total:^{max_total}}")
         table_rows.append(" | ".join(table_value))
 
-    table = [table_header] + table_rows
-    return table
+    return [table_header] + table_rows
 
 
-def _create_pandas_dataframe(class_wise):
+def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
     """Create data for pandas dataframe"""
     data = []
     for label, metrics_dict in class_wise.items():
-        row = {"Class": label}
+        row: dict[str, str | int] = {"Class": label}
         total = sum(metrics_dict.values())
         row.update(metrics_dict)  # Add metric counts
         row["Total"] = total
@@ -118,8 +119,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
     def __len__(self) -> int:
         if isinstance(self.issues, dict):
             return len(self.issues)
-        else:
-            return sum(len(d) for d in self.issues)
+        return sum(len(d) for d in self.issues)
 
     def to_table(self, labelstats: LabelStatsOutput) -> str:
         """
dataeval/outputs/_stats.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 from dataclasses import dataclass
-from typing import Any, Iterable, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
 
 import numpy as np
 import pandas as pd
@@ -13,10 +13,16 @@ from typing_extensions import TypeAlias
 from dataeval.outputs._base import Output
 from dataeval.utils._plot import channel_histogram_plot, histogram_plot
 
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
 OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
 
 SOURCE_INDEX = "source_index"
-BOX_COUNT = "box_count"
+OBJECT_COUNT = "object_count"
+IMAGE_COUNT = "image_count"
+
+BASE_ATTRS = (SOURCE_INDEX, OBJECT_COUNT, IMAGE_COUNT)
 
 
 class SourceIndex(NamedTuple):
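Note: the `box_count` key is replaced by `object_count`, and a dataset-level `image_count` is added alongside it in `BASE_ATTRS`. A hedged migration sketch (assumes `stats` is an instance of a BaseStatsOutput subclass):

    # 0.86.0:
    #   boxes_per_image = stats.box_count
    # 0.86.2:
    objects_per_image = stats.object_count  # NDArray[np.uint16], one entry per image
    total_images = stats.image_count        # int, validated against len(object_count)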
@@ -51,17 +57,24 @@ class BaseStatsOutput(Output):
     ----------
     source_index : List[SourceIndex]
         Mapping from statistic to source image, box and channel index
-    box_count : NDArray[np.uint16]
+    object_count : NDArray[np.uint16]
+        The number of detected objects in each image
     """
 
     source_index: list[SourceIndex]
-    box_count: NDArray[np.uint16]
+    object_count: NDArray[np.uint16]
+    image_count: int
 
     def __post_init__(self) -> None:
-        length = len(self.source_index)
-        bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
-        if bad:
-            raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
+        si_length = len(self.source_index)
+        mismatch = {k: len(v) for k, v in self.data().items() if k not in BASE_ATTRS and len(v) != si_length}
+        if mismatch:
+            raise ValueError(f"All values must have the same length as source_index. Bad values: {str(mismatch)}.")
+        oc_length = len(self.object_count)
+        if oc_length != self.image_count:
+            raise ValueError(
+                f"Total object counts per image does not match image count. {oc_length} != {self.image_count}."
+            )
 
     def get_channel_mask(
         self,
@@ -123,21 +136,64 @@ class BaseStatsOutput(Output):
 
         return max_channels, ch_mask
 
-    def factors(self) -> dict[str, NDArray[Any]]:
+    def factors(
+        self,
+        filter: str | Sequence[str] | None = None,  # noqa: A002
+        exclude_constant: bool = False,
+    ) -> dict[str, NDArray[Any]]:
+        """
+        Returns all 1-dimensional data as a dictionary of numpy arrays.
+
+        Parameters
+        ----------
+        filter : str, Sequence[str] or None, default None:
+            If provided, only returns keys that match the filter.
+        exclude_constant : bool, default False
+            If True, exclude arrays that contain only a single unique value.
+
+        Returns
+        -------
+        dict[str, NDArray[Any]]
+        """
+        filter_ = [filter] if isinstance(filter, str) else filter
         return {
             k: v
             for k, v in self.data().items()
-            if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
+            if k not in BASE_ATTRS
+            and (filter_ is None or k in filter_)
+            and isinstance(v, np.ndarray)
+            and v.ndim == 1
+            and (not exclude_constant or len(np.unique(v)) > 1)
         }
 
     def plot(
         self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
-    ) -> None:
+    ) -> Figure:
+        """
+        Plots the statistics as a set of histograms.
+
+        Parameters
+        ----------
+        log : bool
+            If True, plots the histograms on a logarithmic scale.
+        channel_limit : int or None
+            The maximum number of channels to plot. If None, all channels are plotted.
+        channel_index : int, Iterable[int] or None
+            The index or indices of the channels to plot. If None, all channels are plotted.
+
+        Returns
+        -------
+        matplotlib.Figure
+        """
+        from matplotlib.figure import Figure
+
         max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
+        factors = self.factors(exclude_constant=True)
+        if not factors:
+            return Figure()
         if max_channels == 1:
-            histogram_plot(self.factors(), log)
-        else:
-            channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
+            return histogram_plot(factors, log)
+        return channel_histogram_plot(factors, log, max_channels, ch_mask)
 
 
 @dataclass(frozen=True)
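Note: `factors()` now accepts `filter` and `exclude_constant`, and `plot()` returns the Figure (an empty one when every factor is constant) instead of None. A minimal sketch (assumes `stats` is a populated stats output):

    subset = stats.factors(filter=["width", "height"], exclude_constant=True)
    fig = stats.plot(log=True)  # a matplotlib Figure is now returned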
@@ -147,9 +203,9 @@ class DimensionStatsOutput(BaseStatsOutput):
 
     Attributes
     ----------
-    left : NDArray[np.int32]
+    offset_x : NDArray[np.int32]
         Offsets from the left edge of images in pixels
-    top : NDArray[np.int32]
+    offset_y : NDArray[np.int32]
         Offsets from the top edge of images in pixels
     width : NDArray[np.uint32]
         Width of the images in pixels
@@ -160,25 +216,28 @@ class DimensionStatsOutput(BaseStatsOutput):
     size : NDArray[np.uint32]
         Size of the images in pixels
     aspect_ratio : NDArray[np.float16]
-        :term:`ASspect Ratio<Aspect Ratio>` of the images (width/height)
+        :term:`Aspect Ratio<Aspect Ratio>` of the images (width/height)
     depth : NDArray[np.uint8]
         Color depth of the images in bits
-    center : NDArray[np.uint16]
+    center : NDArray[np.uint32]
         Offset from center in [x,y] coordinates of the images in pixels
-    distance : NDArray[np.float16]
+    distance_center : NDArray[np.float32]
         Distance in pixels from center
+    distance_edge : NDArray[np.uint32]
+        Distance in pixels from nearest edge
     """
 
-    left: NDArray[np.int32]
-    top: NDArray[np.int32]
+    offset_x: NDArray[np.int32]
+    offset_y: NDArray[np.int32]
     width: NDArray[np.uint32]
     height: NDArray[np.uint32]
     channels: NDArray[np.uint8]
     size: NDArray[np.uint32]
     aspect_ratio: NDArray[np.float16]
     depth: NDArray[np.uint8]
-    center: NDArray[np.int16]
-    distance: NDArray[np.float16]
+    center: NDArray[np.int32]
+    distance_center: NDArray[np.float32]
+    distance_edge: NDArray[np.uint32]
 
 
 @dataclass(frozen=True)
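Note: DimensionStatsOutput renames its positional fields and adds an edge-distance statistic. A hedged mapping for downstream code (assumes `dims` is a DimensionStatsOutput):

    x0 = dims.offset_x          # was dims.left
    y0 = dims.offset_y          # was dims.top
    dc = dims.distance_center   # was dims.distance, now float32
    de = dims.distance_edge     # new in 0.86.2: pixels from the nearest edge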
dataeval/outputs/_workflows.py CHANGED
@@ -154,10 +154,10 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
         Array of parameters to recreate line of best fit
     """
 
-    def is_valid(f_new, x_new, f_old, x_old):
+    def is_valid(f_new, x_new, f_old, x_old) -> bool:  # noqa: ANN001
         return f_new != np.nan
 
-    def f(x):
+    def f(x) -> float:  # noqa: ANN001
         try:
             return np.sum(np.square(p_i - f_out(n_i, x)))
         except RuntimeWarning:
dataeval/utils/_array.py CHANGED
@@ -23,7 +23,7 @@ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
 _np_dtype = TypeVar("_np_dtype", bound=np.generic)
 
 
-def _try_import(module_name) -> ModuleType | None:
+def _try_import(module_name: str) -> ModuleType | None:
     if module_name in _MODULE_CACHE:
         return _MODULE_CACHE[module_name]
 
@@ -148,8 +148,7 @@ def ensure_embeddings(
 
     if dtype is None:
         return embeddings
-    else:
-        return arr
+    return arr
 
 
 @overload
@@ -174,10 +173,9 @@ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
     if isinstance(array, np.ndarray):
         nparr = as_numpy(array)
         return nparr.reshape((nparr.shape[0], -1))
-    elif isinstance(array, torch.Tensor):
+    if isinstance(array, torch.Tensor):
         return torch.flatten(array, start_dim=1)
-    else:
-        raise TypeError(f"Unsupported array type {type(array)}.")
+    raise TypeError(f"Unsupported array type {type(array)}.")
 
 
 _TArray = TypeVar("_TArray", bound=Array)
@@ -199,7 +197,6 @@ def channels_first_to_last(array: _TArray) -> _TArray:
     """
     if isinstance(array, np.ndarray):
         return np.transpose(array, (1, 2, 0))
-    elif isinstance(array, torch.Tensor):
+    if isinstance(array, torch.Tensor):
         return torch.permute(array, (1, 2, 0))
-    else:
-        raise TypeError(f"Unsupported array type {type(array)}.")
+    raise TypeError(f"Unsupported array type {type(array)}.")
dataeval/utils/_bin.py CHANGED
@@ -195,5 +195,4 @@ def bin_by_clusters(data: NDArray[np.number[Any]]) -> NDArray[np.float64]:
     if extend_bins:
         bin_edges = np.concatenate([bin_edges, extend_bins])
 
-    bin_edges = np.sort(bin_edges)
-    return bin_edges
+    return np.sort(bin_edges)
dataeval/utils/_clusterer.py CHANGED
@@ -4,6 +4,7 @@ __all__ = []
 
 import warnings
 from dataclasses import dataclass
+from typing import Any
 
 import numba
 import numpy as np
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
 
 
 @numba.njit(parallel=True, locals={"i": numba.types.int32})
-def compare_links_to_cluster_std(mst, clusters):
+def compare_links_to_cluster_std(
+    mst: NDArray[np.float32], clusters: NDArray[np.intp]
+) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
     cluster_ids = np.unique(clusters)
     cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
     cluster_selection_epsilon = 0.0
     # cluster_selection_method = "eom"
 
-    x = flatten(to_numpy(data))
+    x: NDArray[Any] = flatten(to_numpy(data))
     samples, features = x.shape  # Due to flatten(), we know shape has a length of 2
     if samples < 2:
         raise ValueError(f"Data should have at least 2 samples; got {samples}")
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
     return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
 
 
-def sorted_union_find(index_groups):
+def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
     """Merges and sorts groups of indices that share any common index"""
-    groups = [[np.int32(x) for x in range(0)] for y in range(0)]
+    groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
     uniques, inverse = np.unique(index_groups, return_inverse=True)
     inverse = inverse.flatten()
     disjoint_set = ds_rank_create(uniques.size)
dataeval/utils/_fast_mst.py CHANGED
@@ -6,9 +6,11 @@
 __all__ = []
 
 import warnings
+from typing import Any
 
 import numba
 import numpy as np
+from numpy.typing import NDArray
 from sklearn.neighbors import NearestNeighbors
 
 with warnings.catch_warnings():
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
 
 
 @numba.njit()
-def _ds_union_by_rank(disjoint_set, point, nbr):
+def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
     y = ds_find(disjoint_set, point)
     x = ds_find(disjoint_set, nbr)
 
     if x == y:
         return 0
 
-    if disjoint_set.rank[x] < disjoint_set.rank[y]:
+    if disjoint_set[1][x] < disjoint_set[1][y]:
         x, y = y, x
 
-    disjoint_set.parent[y] = x
-    if disjoint_set.rank[x] == disjoint_set.rank[y]:
-        disjoint_set.rank[x] += 1
+    disjoint_set[0][y] = x
+    if disjoint_set[1][x] == disjoint_set[1][y]:
+        disjoint_set[1][x] += 1
     return 1
 
 
 @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
-def _init_tree(n_neighbors, n_distance):
+def _init_tree(
+    n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
+) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
     # Initial graph to hold tree connections
     tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
     disjoint_set = ds_rank_create(n_neighbors.size)
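Note: the switch from `disjoint_set.parent` / `disjoint_set.rank` to `disjoint_set[0]` / `disjoint_set[1]` implies `ds_rank_create` now yields a plain `(parent, rank)` tuple of arrays. A sketch of that assumed layout, not the actual implementation:

    import numpy as np
    from numpy.typing import NDArray

    def ds_rank_create(n: int) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
        parent = np.arange(n, dtype=np.int32)  # each node starts as its own root
        rank = np.zeros(n, dtype=np.int32)     # rank bounds the tree height
        return parent, rank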
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
 
 
 @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
-def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distance):
+def _update_tree_by_distance(
+    tree: NDArray[np.float32],
+    int_tree: int,
+    disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
+    n_neighbors: NDArray[np.uint32],
+    n_distance: NDArray[np.float32],
+) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
     cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
     sort_dist = np.argsort(n_distance)
     dist_sorted = n_distance[sort_dist]
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
 
 
 @numba.njit(locals={"i": numba.types.uint32})
-def _cluster_edges(tracker, last_idx, cluster_distances):
+def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
     cluster_ids = np.unique(tracker)
-    edge_points = []
+    edge_points: list[NDArray[np.intp]] = []
     for idx in range(cluster_ids.size):
         cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
         cluster_size = cluster_points.size
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
     return edge_points
 
 
-def _compute_nn(dataA, dataB, k):
+def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
     distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA).kneighbors(dataB)
     neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
     distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
     return neighbors, distances
 
 
-def _calculate_cluster_neighbors(data, groups, point_array):
+def _calculate_cluster_neighbors(
+    data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
+) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
     """Rerun nearest neighbor based on clusters"""
     cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
     cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
     return cluster_neighbors, cluster_nbr_distances
 
 
-def minimum_spanning_tree(data, neighbors, distances):
+def minimum_spanning_tree(
+    data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
+) -> NDArray[np.float32]:
     # Transpose arrays to get number of samples along a row
     k_neighbors = neighbors.T.astype(np.uint32).copy()
     k_distances = distances.T.astype(np.float32).copy()
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
     return tree
 
 
-def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
+def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
     # Have the potential to add in other distance calculations - supported calculations:
     # https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
     try:
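Note: `_compute_nn` above queries k + 1 neighbors and keeps columns 1..k because, when the query set equals the fit set, column 0 is each point's zero-distance match with itself. A standalone sketch of the pattern:

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    data = np.random.default_rng(0).random((100, 8)).astype(np.float32)
    k = 10
    distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(data).kneighbors(data)
    neighbors = neighbors[:, 1 : k + 1].astype(np.int32)   # drop the self-match column
    distances = distances[:, 1 : k + 1].astype(np.float32)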