dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +188 -178
- dataeval/data/_selection.py +1 -2
- dataeval/data/_split.py +4 -5
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +2 -5
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_nml/_base.py +4 -2
- dataeval/detectors/drift/_nml/_chunk.py +11 -19
- dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
- dataeval/detectors/drift/_nml/_result.py +8 -9
- dataeval/detectors/drift/_nml/_thresholds.py +66 -77
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metadata/_distance.py +10 -7
- dataeval/metadata/_ood.py +11 -103
- dataeval/metrics/bias/_balance.py +23 -33
- dataeval/metrics/bias/_diversity.py +16 -14
- dataeval/metrics/bias/_parity.py +18 -18
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +24 -70
- dataeval/outputs/_drift.py +1 -9
- dataeval/outputs/_linters.py +11 -11
- dataeval/outputs/_stats.py +82 -23
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +54 -28
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +22 -12
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
- dataeval-0.86.2.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.86.0.dist-info/RECORD +0 -114
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/outputs/_bias.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import contextlib
|
6
6
|
from dataclasses import asdict, dataclass
|
7
|
-
from typing import Any,
|
7
|
+
from typing import Any, TypeVar
|
8
8
|
|
9
9
|
import numpy as np
|
10
10
|
import pandas as pd
|
@@ -128,33 +128,30 @@ class CoverageOutput(Output):
|
|
128
128
|
|
129
129
|
import matplotlib.pyplot as plt
|
130
130
|
|
131
|
+
images = Images(images) if isinstance(images, Dataset) else images
|
132
|
+
if np.max(self.uncovered_indices) > len(images):
|
133
|
+
raise ValueError(
|
134
|
+
f"Uncovered indices {self.uncovered_indices} specify images "
|
135
|
+
f"unavailable in the provided number of images {len(images)}."
|
136
|
+
)
|
137
|
+
|
131
138
|
# Determine which images to plot
|
132
139
|
selected_indices = self.uncovered_indices[:top_k]
|
133
140
|
|
134
|
-
images = Images(images) if isinstance(images, Dataset) else images
|
135
|
-
|
136
141
|
# Plot the images
|
137
142
|
num_images = min(top_k, len(selected_indices))
|
138
143
|
|
139
144
|
rows = int(np.ceil(num_images / 3))
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
for i in range(rows):
|
151
|
-
for j in range(3):
|
152
|
-
i_j = i * 3 + j
|
153
|
-
if i_j >= len(selected_indices):
|
154
|
-
continue
|
155
|
-
image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
|
156
|
-
axs[i, j].imshow(image)
|
157
|
-
axs[i, j].axis("off")
|
145
|
+
cols = min(3, num_images)
|
146
|
+
fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
|
147
|
+
|
148
|
+
for image, ax in zip(images[:num_images], axs.flat):
|
149
|
+
image = channels_first_to_last(as_numpy(image))
|
150
|
+
ax.imshow(image)
|
151
|
+
ax.axis("off")
|
152
|
+
|
153
|
+
for ax in axs.flat[num_images:]:
|
154
|
+
ax.axis("off")
|
158
155
|
|
159
156
|
fig.tight_layout()
|
160
157
|
return fig
|
@@ -202,52 +199,11 @@ class BalanceOutput(Output):
|
|
202
199
|
factor_names: list[str]
|
203
200
|
class_names: list[str]
|
204
201
|
|
205
|
-
@overload
|
206
|
-
def _by_factor_type(
|
207
|
-
self,
|
208
|
-
attr: Literal["factor_names"],
|
209
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
210
|
-
) -> list[str]: ...
|
211
|
-
|
212
|
-
@overload
|
213
|
-
def _by_factor_type(
|
214
|
-
self,
|
215
|
-
attr: Literal["balance", "factors", "classwise"],
|
216
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
217
|
-
) -> NDArray[np.float64]: ...
|
218
|
-
|
219
|
-
def _by_factor_type(
|
220
|
-
self,
|
221
|
-
attr: Literal["balance", "factors", "classwise", "factor_names"],
|
222
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
223
|
-
) -> NDArray[np.float64] | list[str]:
|
224
|
-
# if not filtering by factor_type then just return the requested attribute without mask
|
225
|
-
if factor_type == "both":
|
226
|
-
return getattr(self, attr)
|
227
|
-
|
228
|
-
# create the mask for the selected factor_type
|
229
|
-
mask_lambda = (
|
230
|
-
(lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
|
231
|
-
)
|
232
|
-
|
233
|
-
# return the masked attribute
|
234
|
-
if attr == "factor_names":
|
235
|
-
return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
|
236
|
-
else:
|
237
|
-
factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
|
238
|
-
if attr == "factors":
|
239
|
-
return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
|
240
|
-
elif attr == "balance":
|
241
|
-
return self.balance[factor_type_mask]
|
242
|
-
elif attr == "classwise":
|
243
|
-
return self.classwise[:, factor_type_mask]
|
244
|
-
|
245
202
|
def plot(
|
246
203
|
self,
|
247
204
|
row_labels: list[Any] | NDArray[Any] | None = None,
|
248
205
|
col_labels: list[Any] | NDArray[Any] | None = None,
|
249
206
|
plot_classwise: bool = False,
|
250
|
-
factor_type: Literal["discrete", "continuous", "both"] = "discrete",
|
251
207
|
) -> Figure:
|
252
208
|
"""
|
253
209
|
Plot a heatmap of balance information.
|
@@ -260,8 +216,6 @@ class BalanceOutput(Output):
|
|
260
216
|
List/Array containing the labels for columns in the histogram
|
261
217
|
plot_classwise : bool, default False
|
262
218
|
Whether to plot per-class balance instead of global balance
|
263
|
-
factor_type : "discrete", "continuous", or "both", default "discrete"
|
264
|
-
Whether to plot discretized values, continuous values, or to include both
|
265
219
|
|
266
220
|
Returns
|
267
221
|
-------
|
@@ -275,10 +229,10 @@ class BalanceOutput(Output):
|
|
275
229
|
if row_labels is None:
|
276
230
|
row_labels = self.class_names
|
277
231
|
if col_labels is None:
|
278
|
-
col_labels = self.
|
232
|
+
col_labels = self.factor_names
|
279
233
|
|
280
234
|
fig = heatmap(
|
281
|
-
self.
|
235
|
+
self.classwise,
|
282
236
|
row_labels,
|
283
237
|
col_labels,
|
284
238
|
xlabel="Factors",
|
@@ -289,8 +243,8 @@ class BalanceOutput(Output):
|
|
289
243
|
# Combine balance and factors results
|
290
244
|
data = np.concatenate(
|
291
245
|
[
|
292
|
-
self.
|
293
|
-
self.
|
246
|
+
self.balance[np.newaxis, 1:],
|
247
|
+
self.factors,
|
294
248
|
],
|
295
249
|
axis=0,
|
296
250
|
)
|
@@ -299,7 +253,7 @@ class BalanceOutput(Output):
|
|
299
253
|
# Finalize the data for the plot, last row is last factor x last factor so it gets dropped
|
300
254
|
heat_data = np.where(mask, np.nan, data)[:-1]
|
301
255
|
# Creating label array for heat map axes
|
302
|
-
heat_labels = self.
|
256
|
+
heat_labels = self.factor_names
|
303
257
|
|
304
258
|
if row_labels is None:
|
305
259
|
row_labels = heat_labels[:-1]
|
@@ -379,7 +333,7 @@ class DiversityOutput(Output):
|
|
379
333
|
import matplotlib.pyplot as plt
|
380
334
|
|
381
335
|
fig, ax = plt.subplots(figsize=(8, 8))
|
382
|
-
heat_labels =
|
336
|
+
heat_labels = ["class_labels"] + self.factor_names
|
383
337
|
ax.bar(heat_labels, self.diversity_index)
|
384
338
|
ax.set_xlabel("Factors")
|
385
339
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
dataeval/outputs/_drift.py
CHANGED
@@ -103,19 +103,13 @@ class DriftMVDCOutput(PerMetricResult):
|
|
103
103
|
metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
|
104
104
|
super().__init__(results_data, [metric])
|
105
105
|
|
106
|
-
def plot(self
|
106
|
+
def plot(self) -> Figure:
|
107
107
|
"""
|
108
108
|
Render the roc_auc metric over the train/test data in relation to the threshold.
|
109
109
|
|
110
|
-
Parameters
|
111
|
-
----------
|
112
|
-
showme : bool, default True
|
113
|
-
Option to display the figure.
|
114
|
-
|
115
110
|
Returns
|
116
111
|
-------
|
117
112
|
matplotlib.figure.Figure
|
118
|
-
|
119
113
|
"""
|
120
114
|
import matplotlib.pyplot as plt
|
121
115
|
|
@@ -146,6 +140,4 @@ class DriftMVDCOutput(PerMetricResult):
|
|
146
140
|
ax.set_ylabel("ROC AUC", fontsize=7)
|
147
141
|
ax.set_xlabel("Chunk Index", fontsize=7)
|
148
142
|
ax.set_ylim((0.0, 1.1))
|
149
|
-
if showme:
|
150
|
-
plt.show()
|
151
143
|
return fig
|
dataeval/outputs/_linters.py
CHANGED
@@ -43,10 +43,12 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
|
|
43
43
|
near: list[TIndexCollection]
|
44
44
|
|
45
45
|
|
46
|
-
def _reorganize_by_class_and_metric(
|
46
|
+
def _reorganize_by_class_and_metric(
|
47
|
+
result: IndexIssueMap, lstats: LabelStatsOutput
|
48
|
+
) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
|
47
49
|
"""Flip result from grouping by image to grouping by class and metric"""
|
48
|
-
metrics = {}
|
49
|
-
class_wise = {label: {} for label in lstats.class_names}
|
50
|
+
metrics: dict[str, list[int]] = {}
|
51
|
+
class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
|
50
52
|
|
51
53
|
# Group metrics and calculate class-wise counts
|
52
54
|
for img, group in result.items():
|
@@ -59,7 +61,7 @@ def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOut
|
|
59
61
|
return metrics, class_wise
|
60
62
|
|
61
63
|
|
62
|
-
def _create_table(metrics, class_wise):
|
64
|
+
def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
|
63
65
|
"""Create table for displaying the results"""
|
64
66
|
max_class_length = max(len(str(label)) for label in class_wise) + 2
|
65
67
|
max_total = max(len(metrics[group]) for group in metrics) + 2
|
@@ -69,7 +71,7 @@ def _create_table(metrics, class_wise):
|
|
69
71
|
+ [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
|
70
72
|
+ [f"{'Total':<{max_total}}"]
|
71
73
|
)
|
72
|
-
table_rows = []
|
74
|
+
table_rows: list[str] = []
|
73
75
|
|
74
76
|
for class_cat, results in class_wise.items():
|
75
77
|
table_value = [f"{class_cat:>{max_class_length}}"]
|
@@ -81,15 +83,14 @@ def _create_table(metrics, class_wise):
|
|
81
83
|
table_value.append(f"{total:^{max_total}}")
|
82
84
|
table_rows.append(" | ".join(table_value))
|
83
85
|
|
84
|
-
|
85
|
-
return table
|
86
|
+
return [table_header] + table_rows
|
86
87
|
|
87
88
|
|
88
|
-
def _create_pandas_dataframe(class_wise):
|
89
|
+
def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
|
89
90
|
"""Create data for pandas dataframe"""
|
90
91
|
data = []
|
91
92
|
for label, metrics_dict in class_wise.items():
|
92
|
-
row = {"Class": label}
|
93
|
+
row: dict[str, str | int] = {"Class": label}
|
93
94
|
total = sum(metrics_dict.values())
|
94
95
|
row.update(metrics_dict) # Add metric counts
|
95
96
|
row["Total"] = total
|
@@ -118,8 +119,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
|
|
118
119
|
def __len__(self) -> int:
|
119
120
|
if isinstance(self.issues, dict):
|
120
121
|
return len(self.issues)
|
121
|
-
|
122
|
-
return sum(len(d) for d in self.issues)
|
122
|
+
return sum(len(d) for d in self.issues)
|
123
123
|
|
124
124
|
def to_table(self, labelstats: LabelStatsOutput) -> str:
|
125
125
|
"""
|
dataeval/outputs/_stats.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
|
-
from typing import Any, Iterable, NamedTuple, Optional, Union
|
6
|
+
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import pandas as pd
|
@@ -13,10 +13,16 @@ from typing_extensions import TypeAlias
|
|
13
13
|
from dataeval.outputs._base import Output
|
14
14
|
from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
15
15
|
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from matplotlib.figure import Figure
|
18
|
+
|
16
19
|
OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
|
17
20
|
|
18
21
|
SOURCE_INDEX = "source_index"
|
19
|
-
|
22
|
+
OBJECT_COUNT = "object_count"
|
23
|
+
IMAGE_COUNT = "image_count"
|
24
|
+
|
25
|
+
BASE_ATTRS = (SOURCE_INDEX, OBJECT_COUNT, IMAGE_COUNT)
|
20
26
|
|
21
27
|
|
22
28
|
class SourceIndex(NamedTuple):
|
@@ -51,17 +57,24 @@ class BaseStatsOutput(Output):
|
|
51
57
|
----------
|
52
58
|
source_index : List[SourceIndex]
|
53
59
|
Mapping from statistic to source image, box and channel index
|
54
|
-
|
60
|
+
object_count : NDArray[np.uint16]
|
61
|
+
The number of detected objects in each image
|
55
62
|
"""
|
56
63
|
|
57
64
|
source_index: list[SourceIndex]
|
58
|
-
|
65
|
+
object_count: NDArray[np.uint16]
|
66
|
+
image_count: int
|
59
67
|
|
60
68
|
def __post_init__(self) -> None:
|
61
|
-
|
62
|
-
|
63
|
-
if
|
64
|
-
raise ValueError(f"All values must have the same length as source_index. Bad values: {str(
|
69
|
+
si_length = len(self.source_index)
|
70
|
+
mismatch = {k: len(v) for k, v in self.data().items() if k not in BASE_ATTRS and len(v) != si_length}
|
71
|
+
if mismatch:
|
72
|
+
raise ValueError(f"All values must have the same length as source_index. Bad values: {str(mismatch)}.")
|
73
|
+
oc_length = len(self.object_count)
|
74
|
+
if oc_length != self.image_count:
|
75
|
+
raise ValueError(
|
76
|
+
f"Total object counts per image does not match image count. {oc_length} != {self.image_count}."
|
77
|
+
)
|
65
78
|
|
66
79
|
def get_channel_mask(
|
67
80
|
self,
|
@@ -123,21 +136,64 @@ class BaseStatsOutput(Output):
|
|
123
136
|
|
124
137
|
return max_channels, ch_mask
|
125
138
|
|
126
|
-
def factors(
|
139
|
+
def factors(
|
140
|
+
self,
|
141
|
+
filter: str | Sequence[str] | None = None, # noqa: A002
|
142
|
+
exclude_constant: bool = False,
|
143
|
+
) -> dict[str, NDArray[Any]]:
|
144
|
+
"""
|
145
|
+
Returns all 1-dimensional data as a dictionary of numpy arrays.
|
146
|
+
|
147
|
+
Parameters
|
148
|
+
----------
|
149
|
+
filter : str, Sequence[str] or None, default None:
|
150
|
+
If provided, only returns keys that match the filter.
|
151
|
+
exclude_constant : bool, default False
|
152
|
+
If True, exclude arrays that contain only a single unique value.
|
153
|
+
|
154
|
+
Returns
|
155
|
+
-------
|
156
|
+
dict[str, NDArray[Any]]
|
157
|
+
"""
|
158
|
+
filter_ = [filter] if isinstance(filter, str) else filter
|
127
159
|
return {
|
128
160
|
k: v
|
129
161
|
for k, v in self.data().items()
|
130
|
-
if k not in
|
162
|
+
if k not in BASE_ATTRS
|
163
|
+
and (filter_ is None or k in filter_)
|
164
|
+
and isinstance(v, np.ndarray)
|
165
|
+
and v.ndim == 1
|
166
|
+
and (not exclude_constant or len(np.unique(v)) > 1)
|
131
167
|
}
|
132
168
|
|
133
169
|
def plot(
|
134
170
|
self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
|
135
|
-
) ->
|
171
|
+
) -> Figure:
|
172
|
+
"""
|
173
|
+
Plots the statistics as a set of histograms.
|
174
|
+
|
175
|
+
Parameters
|
176
|
+
----------
|
177
|
+
log : bool
|
178
|
+
If True, plots the histograms on a logarithmic scale.
|
179
|
+
channel_limit : int or None
|
180
|
+
The maximum number of channels to plot. If None, all channels are plotted.
|
181
|
+
channel_index : int, Iterable[int] or None
|
182
|
+
The index or indices of the channels to plot. If None, all channels are plotted.
|
183
|
+
|
184
|
+
Returns
|
185
|
+
-------
|
186
|
+
matplotlib.Figure
|
187
|
+
"""
|
188
|
+
from matplotlib.figure import Figure
|
189
|
+
|
136
190
|
max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
|
191
|
+
factors = self.factors(exclude_constant=True)
|
192
|
+
if not factors:
|
193
|
+
return Figure()
|
137
194
|
if max_channels == 1:
|
138
|
-
histogram_plot(
|
139
|
-
|
140
|
-
channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
|
195
|
+
return histogram_plot(factors, log)
|
196
|
+
return channel_histogram_plot(factors, log, max_channels, ch_mask)
|
141
197
|
|
142
198
|
|
143
199
|
@dataclass(frozen=True)
|
@@ -147,9 +203,9 @@ class DimensionStatsOutput(BaseStatsOutput):
|
|
147
203
|
|
148
204
|
Attributes
|
149
205
|
----------
|
150
|
-
|
206
|
+
offset_x : NDArray[np.int32]
|
151
207
|
Offsets from the left edge of images in pixels
|
152
|
-
|
208
|
+
offset_y : NDArray[np.int32]
|
153
209
|
Offsets from the top edge of images in pixels
|
154
210
|
width : NDArray[np.uint32]
|
155
211
|
Width of the images in pixels
|
@@ -160,25 +216,28 @@ class DimensionStatsOutput(BaseStatsOutput):
|
|
160
216
|
size : NDArray[np.uint32]
|
161
217
|
Size of the images in pixels
|
162
218
|
aspect_ratio : NDArray[np.float16]
|
163
|
-
:term:`
|
219
|
+
:term:`Aspect Ratio<Aspect Ratio>` of the images (width/height)
|
164
220
|
depth : NDArray[np.uint8]
|
165
221
|
Color depth of the images in bits
|
166
|
-
center : NDArray[np.
|
222
|
+
center : NDArray[np.uint32]
|
167
223
|
Offset from center in [x,y] coordinates of the images in pixels
|
168
|
-
|
224
|
+
distance_center : NDArray[np.float32]
|
169
225
|
Distance in pixels from center
|
226
|
+
distance_edge : NDArray[np.uint32]
|
227
|
+
Distance in pixels from nearest edge
|
170
228
|
"""
|
171
229
|
|
172
|
-
|
173
|
-
|
230
|
+
offset_x: NDArray[np.int32]
|
231
|
+
offset_y: NDArray[np.int32]
|
174
232
|
width: NDArray[np.uint32]
|
175
233
|
height: NDArray[np.uint32]
|
176
234
|
channels: NDArray[np.uint8]
|
177
235
|
size: NDArray[np.uint32]
|
178
236
|
aspect_ratio: NDArray[np.float16]
|
179
237
|
depth: NDArray[np.uint8]
|
180
|
-
center: NDArray[np.
|
181
|
-
|
238
|
+
center: NDArray[np.int32]
|
239
|
+
distance_center: NDArray[np.float32]
|
240
|
+
distance_edge: NDArray[np.uint32]
|
182
241
|
|
183
242
|
|
184
243
|
@dataclass(frozen=True)
|
dataeval/outputs/_workflows.py
CHANGED
@@ -154,10 +154,10 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
|
|
154
154
|
Array of parameters to recreate line of best fit
|
155
155
|
"""
|
156
156
|
|
157
|
-
def is_valid(f_new, x_new, f_old, x_old):
|
157
|
+
def is_valid(f_new, x_new, f_old, x_old) -> bool: # noqa: ANN001
|
158
158
|
return f_new != np.nan
|
159
159
|
|
160
|
-
def f(x):
|
160
|
+
def f(x) -> float: # noqa: ANN001
|
161
161
|
try:
|
162
162
|
return np.sum(np.square(p_i - f_out(n_i, x)))
|
163
163
|
except RuntimeWarning:
|
dataeval/utils/_array.py
CHANGED
@@ -23,7 +23,7 @@ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
|
|
23
23
|
_np_dtype = TypeVar("_np_dtype", bound=np.generic)
|
24
24
|
|
25
25
|
|
26
|
-
def _try_import(module_name) -> ModuleType | None:
|
26
|
+
def _try_import(module_name: str) -> ModuleType | None:
|
27
27
|
if module_name in _MODULE_CACHE:
|
28
28
|
return _MODULE_CACHE[module_name]
|
29
29
|
|
@@ -148,8 +148,7 @@ def ensure_embeddings(
|
|
148
148
|
|
149
149
|
if dtype is None:
|
150
150
|
return embeddings
|
151
|
-
|
152
|
-
return arr
|
151
|
+
return arr
|
153
152
|
|
154
153
|
|
155
154
|
@overload
|
@@ -174,10 +173,9 @@ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
|
|
174
173
|
if isinstance(array, np.ndarray):
|
175
174
|
nparr = as_numpy(array)
|
176
175
|
return nparr.reshape((nparr.shape[0], -1))
|
177
|
-
|
176
|
+
if isinstance(array, torch.Tensor):
|
178
177
|
return torch.flatten(array, start_dim=1)
|
179
|
-
|
180
|
-
raise TypeError(f"Unsupported array type {type(array)}.")
|
178
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
181
179
|
|
182
180
|
|
183
181
|
_TArray = TypeVar("_TArray", bound=Array)
|
@@ -199,7 +197,6 @@ def channels_first_to_last(array: _TArray) -> _TArray:
|
|
199
197
|
"""
|
200
198
|
if isinstance(array, np.ndarray):
|
201
199
|
return np.transpose(array, (1, 2, 0))
|
202
|
-
|
200
|
+
if isinstance(array, torch.Tensor):
|
203
201
|
return torch.permute(array, (1, 2, 0))
|
204
|
-
|
205
|
-
raise TypeError(f"Unsupported array type {type(array)}.")
|
202
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
dataeval/utils/_bin.py
CHANGED
dataeval/utils/_clusterer.py
CHANGED
@@ -4,6 +4,7 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import warnings
|
6
6
|
from dataclasses import dataclass
|
7
|
+
from typing import Any
|
7
8
|
|
8
9
|
import numba
|
9
10
|
import numpy as np
|
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
|
|
30
31
|
|
31
32
|
|
32
33
|
@numba.njit(parallel=True, locals={"i": numba.types.int32})
|
33
|
-
def compare_links_to_cluster_std(
|
34
|
+
def compare_links_to_cluster_std(
|
35
|
+
mst: NDArray[np.float32], clusters: NDArray[np.intp]
|
36
|
+
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
|
34
37
|
cluster_ids = np.unique(clusters)
|
35
38
|
cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
|
36
39
|
|
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
79
82
|
cluster_selection_epsilon = 0.0
|
80
83
|
# cluster_selection_method = "eom"
|
81
84
|
|
82
|
-
x = flatten(to_numpy(data))
|
85
|
+
x: NDArray[Any] = flatten(to_numpy(data))
|
83
86
|
samples, features = x.shape # Due to flatten(), we know shape has a length of 2
|
84
87
|
if samples < 2:
|
85
88
|
raise ValueError(f"Data should have at least 2 samples; got {samples}")
|
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
125
128
|
return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
|
126
129
|
|
127
130
|
|
128
|
-
def sorted_union_find(index_groups):
|
131
|
+
def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
|
129
132
|
"""Merges and sorts groups of indices that share any common index"""
|
130
|
-
groups = [[np.int32(x) for x in range(0)] for y in range(0)]
|
133
|
+
groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
|
131
134
|
uniques, inverse = np.unique(index_groups, return_inverse=True)
|
132
135
|
inverse = inverse.flatten()
|
133
136
|
disjoint_set = ds_rank_create(uniques.size)
|
dataeval/utils/_fast_mst.py
CHANGED
@@ -6,9 +6,11 @@
|
|
6
6
|
__all__ = []
|
7
7
|
|
8
8
|
import warnings
|
9
|
+
from typing import Any
|
9
10
|
|
10
11
|
import numba
|
11
12
|
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
12
14
|
from sklearn.neighbors import NearestNeighbors
|
13
15
|
|
14
16
|
with warnings.catch_warnings():
|
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
|
|
17
19
|
|
18
20
|
|
19
21
|
@numba.njit()
|
20
|
-
def _ds_union_by_rank(disjoint_set, point, nbr):
|
22
|
+
def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
|
21
23
|
y = ds_find(disjoint_set, point)
|
22
24
|
x = ds_find(disjoint_set, nbr)
|
23
25
|
|
24
26
|
if x == y:
|
25
27
|
return 0
|
26
28
|
|
27
|
-
if disjoint_set
|
29
|
+
if disjoint_set[1][x] < disjoint_set[1][y]:
|
28
30
|
x, y = y, x
|
29
31
|
|
30
|
-
disjoint_set
|
31
|
-
if disjoint_set
|
32
|
-
disjoint_set
|
32
|
+
disjoint_set[0][y] = x
|
33
|
+
if disjoint_set[1][x] == disjoint_set[1][y]:
|
34
|
+
disjoint_set[1][x] += 1
|
33
35
|
return 1
|
34
36
|
|
35
37
|
|
36
38
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
|
37
|
-
def _init_tree(
|
39
|
+
def _init_tree(
|
40
|
+
n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
|
41
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
38
42
|
# Initial graph to hold tree connections
|
39
43
|
tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
|
40
44
|
disjoint_set = ds_rank_create(n_neighbors.size)
|
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
|
|
56
60
|
|
57
61
|
|
58
62
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
|
59
|
-
def _update_tree_by_distance(
|
63
|
+
def _update_tree_by_distance(
|
64
|
+
tree: NDArray[np.float32],
|
65
|
+
int_tree: int,
|
66
|
+
disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
|
67
|
+
n_neighbors: NDArray[np.uint32],
|
68
|
+
n_distance: NDArray[np.float32],
|
69
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
60
70
|
cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
|
61
71
|
sort_dist = np.argsort(n_distance)
|
62
72
|
dist_sorted = n_distance[sort_dist]
|
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
|
|
80
90
|
|
81
91
|
|
82
92
|
@numba.njit(locals={"i": numba.types.uint32})
|
83
|
-
def _cluster_edges(tracker, last_idx, cluster_distances):
|
93
|
+
def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
|
84
94
|
cluster_ids = np.unique(tracker)
|
85
|
-
edge_points = []
|
95
|
+
edge_points: list[NDArray[np.intp]] = []
|
86
96
|
for idx in range(cluster_ids.size):
|
87
97
|
cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
|
88
98
|
cluster_size = cluster_points.size
|
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
|
|
102
112
|
return edge_points
|
103
113
|
|
104
114
|
|
105
|
-
def _compute_nn(dataA, dataB, k):
|
115
|
+
def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
106
116
|
distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA).kneighbors(dataB)
|
107
117
|
neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
|
108
118
|
distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
|
109
119
|
return neighbors, distances
|
110
120
|
|
111
121
|
|
112
|
-
def _calculate_cluster_neighbors(
|
122
|
+
def _calculate_cluster_neighbors(
|
123
|
+
data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
|
124
|
+
) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
|
113
125
|
"""Rerun nearest neighbor based on clusters"""
|
114
126
|
cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
|
115
127
|
cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
|
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
|
|
126
138
|
return cluster_neighbors, cluster_nbr_distances
|
127
139
|
|
128
140
|
|
129
|
-
def minimum_spanning_tree(
|
141
|
+
def minimum_spanning_tree(
|
142
|
+
data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
|
143
|
+
) -> NDArray[np.float32]:
|
130
144
|
# Transpose arrays to get number of samples along a row
|
131
145
|
k_neighbors = neighbors.T.astype(np.uint32).copy()
|
132
146
|
k_distances = distances.T.astype(np.float32).copy()
|
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
|
|
168
182
|
return tree
|
169
183
|
|
170
184
|
|
171
|
-
def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
|
185
|
+
def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
172
186
|
# Have the potential to add in other distance calculations - supported calculations:
|
173
187
|
# https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
|
174
188
|
try:
|